diff --git a/src/main/java/com/redislabs/redisai/DataType.java b/src/main/java/com/redislabs/redisai/DataType.java index 336a10d..ba3309f 100644 --- a/src/main/java/com/redislabs/redisai/DataType.java +++ b/src/main/java/com/redislabs/redisai/DataType.java @@ -109,23 +109,6 @@ protected Object toObject(List data) { raw = SafeEncoder.encode(this.name()); } - static DataType getDataTypefromString(String dtypeRaw) { - DataType dt = null; - if (dtypeRaw.equals(DataType.INT32.name())) { - dt = DataType.INT32; - } - if (dtypeRaw.equals(DataType.INT64.name())) { - dt = DataType.INT64; - } - if (dtypeRaw.equals(DataType.FLOAT.name())) { - dt = DataType.FLOAT; - } - if (dtypeRaw.equals(DataType.DOUBLE.name())) { - dt = DataType.DOUBLE; - } - return dt; - } - /** The class for the data type to which Java object o corresponds. */ public static DataType baseObjType(Object o) { Class c = o.getClass(); diff --git a/src/main/java/com/redislabs/redisai/Tensor.java b/src/main/java/com/redislabs/redisai/Tensor.java index faaac72..92cb14f 100644 --- a/src/main/java/com/redislabs/redisai/Tensor.java +++ b/src/main/java/com/redislabs/redisai/Tensor.java @@ -38,10 +38,7 @@ protected static Tensor createTensorFromRespReply(List reply) { switch (arrayKey) { case "dtype": String dtypeString = SafeEncoder.encode((byte[]) reply.get(i + 1)); - dtype = DataType.getDataTypefromString(dtypeString); - if (dtype == null) { - throw new JRedisAIRunTimeException("Unrecognized datatype: " + dtypeString); - } + dtype = DataType.valueOf(dtypeString); break; case "shape": List shapeResp = (List) reply.get(i + 1); diff --git a/src/test/java/com/redislabs/redisai/DataTypeTest.java b/src/test/java/com/redislabs/redisai/DataTypeTest.java deleted file mode 100644 index 8b307ed..0000000 --- a/src/test/java/com/redislabs/redisai/DataTypeTest.java +++ /dev/null @@ -1,26 +0,0 @@ -package com.redislabs.redisai; - -import org.junit.Assert; -import org.junit.Test; - -public class DataTypeTest { - - @Test - public void getDataTypefromString() { - DataType dtypef = DataType.getDataTypefromString("FLOAT"); - Assert.assertEquals("FLOAT", dtypef.name()); - Assert.assertEquals(DataType.FLOAT.getRaw(), dtypef.getRaw()); - - DataType dtyped = DataType.getDataTypefromString("DOUBLE"); - Assert.assertEquals("DOUBLE", dtyped.name()); - Assert.assertEquals(DataType.DOUBLE.getRaw(), dtyped.getRaw()); - - DataType dtypei = DataType.getDataTypefromString("INT32"); - Assert.assertEquals("INT32", dtypei.name()); - Assert.assertEquals(DataType.INT32.getRaw(), dtypei.getRaw()); - - DataType dtypel = DataType.getDataTypefromString("INT64"); - Assert.assertEquals("INT64", dtypel.name()); - Assert.assertEquals(DataType.INT64.getRaw(), dtypel.getRaw()); - } -} diff --git a/src/test/java/com/redislabs/redisai/RedisAITest.java b/src/test/java/com/redislabs/redisai/RedisAITest.java index 848573a..bad50f0 100644 --- a/src/test/java/com/redislabs/redisai/RedisAITest.java +++ b/src/test/java/com/redislabs/redisai/RedisAITest.java @@ -1,6 +1,7 @@ package com.redislabs.redisai; import java.io.IOException; +import java.net.URISyntaxException; import java.nio.file.Files; import java.nio.file.Paths; import java.util.Map; @@ -175,9 +176,9 @@ public void testSetModelNegative() { @Test public void testSetModelFromModelOnnx() { try { - ClassLoader classLoader = getClass().getClassLoader(); - String modelPath = classLoader.getResource("test_data/mnist.onnx").getFile(); - byte[] blob = Files.readAllBytes(Paths.get(modelPath)); + byte[] blob = + Files.readAllBytes( + Paths.get(getClass().getClassLoader().getResource("test_data/mnist.onnx").toURI())); Model m1 = new Model(Backend.ONNX, Device.CPU, new String[] {}, new String[] {}, blob); Assert.assertTrue(client.setModel("mnist.onnx", m1)); Model m2 = client.getModel("mnist.onnx"); @@ -187,7 +188,7 @@ public void testSetModelFromModelOnnx() { Model m3 = client.getModel("mnist.onnx.m2"); Assert.assertEquals(m2.getDevice(), m3.getDevice()); Assert.assertEquals(m2.getBackend(), m3.getBackend()); - } catch (IOException e) { + } catch (IOException | URISyntaxException e) { e.printStackTrace(); } } @@ -195,9 +196,13 @@ public void testSetModelFromModelOnnx() { @Test public void testSetModelFromModelTFLite() { try { - ClassLoader classLoader = getClass().getClassLoader(); - String modelPath = classLoader.getResource("test_data/mnist_model_quant.tflite").getFile(); - byte[] blob = Files.readAllBytes(Paths.get(modelPath)); + byte[] blob = + Files.readAllBytes( + Paths.get( + getClass() + .getClassLoader() + .getResource("test_data/mnist_model_quant.tflite") + .toURI())); Model m1 = new Model(Backend.TFLITE, Device.CPU, new String[] {}, new String[] {}, blob); Assert.assertTrue(client.setModel("mnist.tflite", m1)); Model m2 = client.getModel("mnist.tflite"); @@ -207,7 +212,7 @@ public void testSetModelFromModelTFLite() { Model m3 = client.getModel("mnist.tflite.m2"); Assert.assertEquals(m2.getDevice(), m3.getDevice()); Assert.assertEquals(m2.getBackend(), m3.getBackend()); - } catch (IOException e) { + } catch (IOException | URISyntaxException e) { e.printStackTrace(); } }