Skip to content

Commit 8c66072

Browse files
author
hhsecond
committed
better error message
1 parent 9de4466 commit 8c66072

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

redisai/command_builder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ def modelset(self, name: AnyStr, backend: str, device: str, data: ByteString,
1515
batch: int, minbatch: int, tag: AnyStr,
1616
inputs: Union[AnyStr, List[AnyStr]],
1717
outputs: Union[AnyStr, List[AnyStr]]) -> Sequence:
18+
if device.upper() not in utils.allowed_devices:
19+
raise ValueError(f"Device not allowed. Use any from {utils.allowed_devices}")
20+
if backend.upper() not in utils.allowed_backends:
21+
raise ValueError(f"Backend not allowed. Use any from {utils.allowed_backends}")
1822
args = ['AI.MODELSET', name, backend, device]
1923

2024
if batch is not None:
@@ -87,6 +91,8 @@ def tensorget(self,
8791
return args
8892

8993
def scriptset(self, name: AnyStr, device: str, script: str, tag: AnyStr = None) -> Sequence:
94+
if device.upper() not in utils.allowed_devices:
95+
raise ValueError(f"Device not allowed. Use any from {utils.allowed_devices}")
9096
args = ['AI.SCRIPTSET', name, device]
9197
if tag:
9298
args += ['TAG', tag]

redisai/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
'uint32': 'UINT32',
1717
'uint64': 'UINT64'}
1818

19+
allowed_devices = {'CPU', 'GPU'}
20+
allowed_backends = {'TF', 'TFLITE', 'TORCH', 'ONNX'}
21+
1922

2023
def numpy2blob(tensor: np.ndarray) -> tuple:
2124
"""Convert the numpy input from user to `Tensor`."""

test/test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,18 @@ def test_numpy_tensor(self):
106106
with self.assertRaises(TypeError):
107107
con.tensorset('trying', stringarr)
108108

109+
def test_modelset_errors(self):
110+
model_path = os.path.join(MODEL_DIR, 'graph.pb')
111+
model_pb = load_model(model_path)
112+
con = self.get_client()
113+
with self.assertRaises(ValueError):
114+
con.modelset('m', 'tf', 'wrongdevice', model_pb,
115+
inputs=['a', 'b'], outputs=['mul'], tag='v1.0')
116+
with self.assertRaises(ValueError):
117+
con.modelset('m', 'wrongbackend', 'cpu', model_pb,
118+
inputs=['a', 'b'], outputs=['mul'], tag='v1.0')
119+
120+
109121
def test_modelget_meta(self):
110122
model_path = os.path.join(MODEL_DIR, 'graph.pb')
111123
model_pb = load_model(model_path)

0 commit comments

Comments
 (0)