Skip to content

Commit 6d4174c

Browse files
Sherin Thomaslantiga
authored andcommitted
ONNX support, bug fix, user friendly APIs (#10)
* bumbed version through version.py file * gitignore for built files * Readme.md linked to example repo * assets for onnx tests * onnx support * user friendly apis * tests for onnx support and new user friendly APIs * minor nit * testing against redisai edge * pandas to test requirments
1 parent ab579ae commit 6d4174c

File tree

13 files changed

+276
-129
lines changed

13 files changed

+276
-129
lines changed

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
build:
99
docker:
1010
- image: circleci/python:3.7.1
11-
- image: redisai/redisai:latest
11+
- image: redisai/redisai:edge
1212

1313
working_directory: ~/repo
1414

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
.project
22
.pydevproject
33
*.pyc
4-
.venv/
4+
.venv/
5+
redisai.egg-info

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@
2020
$ pip install redisai
2121
```
2222

23+
[RedisAI example repo](https://github.com/RedisAI/redisai-examples) shows few examples made using redisai-py under `python_client` section.
2324

2425

redisai/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,22 @@
1+
from .version import __version__
12
from .client import (Client, Tensor, BlobTensor, DType, Device, Backend)
3+
4+
5+
def save_model(*args, **kwargs):
6+
"""
7+
Importing inside to avoid loading the TF/PyTorch/ONNX
8+
into the scope unnecessary. This function wraps the
9+
internal save model utility to make it user friendly
10+
"""
11+
from .model import Model
12+
Model.save(*args, **kwargs)
13+
14+
15+
def load_model(*args, **kwargs):
16+
"""
17+
Importing inside to avoid loading the TF/PyTorch/ONNX
18+
into the scope unnecessary. This function wraps the
19+
internal load model utility to make it user friendly
20+
"""
21+
from .model import Model
22+
return Model.load(*args, **kwargs)

redisai/client.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,31 @@
1515

1616

1717
class Device(Enum):
18-
cpu = 'cpu'
19-
gpu = 'gpu'
18+
cpu = 'CPU'
19+
gpu = 'GPU'
2020

2121

2222
class Backend(Enum):
23-
tf = 'tf'
24-
torch = 'torch'
25-
onnx = 'ort'
23+
tf = 'TF'
24+
torch = 'TORCH'
25+
onnx = 'ONNX'
2626

2727

2828
class DType(Enum):
29-
float = 'float'
30-
double = 'double'
31-
int8 = 'int8'
32-
int16 = 'int16'
33-
int32 = 'int32'
34-
int64 = 'int64'
35-
uint8 = 'uint8'
36-
uint16 = 'uint16'
37-
uint32 = 'uint32'
38-
uint64 = 'uint64'
29+
float = 'FLOAT'
30+
double = 'DOUBLE'
31+
int8 = 'INT8'
32+
int16 = 'INT16'
33+
int32 = 'INT32'
34+
int64 = 'INT64'
35+
uint8 = 'UINT8'
36+
uint16 = 'UINT16'
37+
uint32 = 'UINT32'
38+
uint64 = 'UINT64'
3939

4040
# aliases
41-
float32 = 'float'
42-
float64 = 'double'
41+
float32 = 'FLOAT'
42+
float64 = 'DOUBLE'
4343

4444

4545
def _str_or_strlist(v):
@@ -54,7 +54,7 @@ def _convert_to_num(dt, arr):
5454
if isinstance(obj, list):
5555
_convert_to_num(obj)
5656
else:
57-
if dt in (DType.float, DType.double):
57+
if dt in (DType.float.value, DType.double.value):
5858
arr[ix] = float(obj)
5959
else:
6060
arr[ix] = int(obj)
@@ -159,10 +159,9 @@ def to_numpy(self):
159159

160160
@staticmethod
161161
def _to_numpy_type(t):
162-
t = t.lower()
163162
mm = {
164-
'float': 'float32',
165-
'double': 'float64'
163+
'FLOAT': 'float32',
164+
'DOUBLE': 'float64'
166165
}
167166
if t in mm:
168167
return mm[t]

redisai/model.py

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,34 @@
1-
import pickle
21
import os
32
import warnings
4-
5-
from .client import Device, Backend
3+
import sys
64

75
try:
86
import tensorflow as tf
97
except (ModuleNotFoundError, ImportError):
10-
pass # that's Okey if you don't have TF
8+
pass
119

1210
try:
1311
import torch
1412
except (ModuleNotFoundError, ImportError):
15-
pass # it's Okey if you don't have PT either
13+
pass
1614

15+
try:
16+
import onnx
17+
except (ModuleNotFoundError, ImportError):
18+
pass
19+
20+
try:
21+
import skl2onnx
22+
import sklearn
23+
except (ModuleNotFoundError, ImportError):
24+
pass
1725

1826

1927
class Model:
2028

2129
__slots__ = ['graph', 'backend', 'device', 'inputs', 'outputs']
22-
23-
def __init__(self, path, device=Device.cpu, inputs=None, outputs=None):
30+
31+
def __init__(self, path, device=None, inputs=None, outputs=None):
2432
"""
2533
Declare a model suitable for passing to modelset
2634
:param path: Filepath from where the stored model can be read
@@ -37,9 +45,9 @@ def __init__(self, path, device=Device.cpu, inputs=None, outputs=None):
3745
raise NotImplementedError('Instance creation is not impelemented yet')
3846

3947
@classmethod
40-
def save(cls, obj, path: str, input=None, output=None, as_native=True):
48+
def save(cls, obj, path: str, input=None, output=None, as_native=True, prototype=None):
4149
"""
42-
Infer the backend (TF/PyTorch) by inspecting the class hierarchy
50+
Infer the backend (TF/PyTorch/ONNX) by inspecting the class hierarchy
4351
and calls the appropriate serialization utility. It is essentially a
4452
wrapper over serialization mechanism of each backend
4553
:param path: Path to which the graph/model will be saved
@@ -54,15 +62,25 @@ def save(cls, obj, path: str, input=None, output=None, as_native=True):
5462
mechanism if True. If False, custom saving utility will be called
5563
which saves other informations required for modelset. Defaults to True
5664
"""
57-
if issubclass(type(obj), tf.Session):
65+
if 'tensorflow' in sys.modules and issubclass(type(obj), tf.Session):
5866
cls._save_tf_graph(obj, path, output, as_native)
59-
elif issubclass(type(type(obj)), torch.jit.ScriptMeta):
67+
elif 'torch' in sys.modules and issubclass(
68+
type(type(obj)), torch.jit.ScriptMeta):
6069
# TODO Is there a better way to check this
61-
cls._save_pt_graph(obj, path, as_native)
70+
cls._save_torch_graph(obj, path, as_native)
71+
elif 'onnx' in sys.modules and issubclass(
72+
type(obj), onnx.onnx_ONNX_RELEASE_ml_pb2.ModelProto):
73+
cls._save_onnx_graph(obj, path, as_native)
74+
elif 'skl2onnx' in sys.modules and issubclass(
75+
type(obj), sklearn.base.BaseEstimator):
76+
cls._save_sklearn_graph(obj, path, as_native, prototype)
6277
else:
63-
raise TypeError(('Invalid Object. '
64-
'Need traced graph or scripted graph from PyTorch or '
65-
'Session object from Tensorflow'))
78+
message = ("Could not find the required dependancy to export the graph object. "
79+
"`save_model` relies on serialization mechanism provided by the"
80+
" supported backends such as Tensorflow, PyTorch, ONNX or skl2onnx. "
81+
"Please install package required for serializing your graph. "
82+
"For more information, checkout the redisia-py documentation")
83+
raise RuntimeError(message)
6684

6785
@classmethod
6886
def _save_tf_graph(cls, sess, path, output, as_native):
@@ -81,10 +99,10 @@ def _save_tf_graph(cls, sess, path, output, as_native):
8199
raise NotImplementedError('Saving non-native graph is not supported yet')
82100

83101
@classmethod
84-
def _save_pt_graph(cls, graph, path, as_native):
102+
def _save_torch_graph(cls, graph, path, as_native):
85103
# TODO how to handle the cpu/gpu
86104
if as_native:
87-
if graph.training == True:
105+
if graph.training is True:
88106
warnings.warn(
89107
'Graph is in training mode. Converting to evaluation mode')
90108
graph.eval()
@@ -93,25 +111,33 @@ def _save_pt_graph(cls, graph, path, as_native):
93111
else:
94112
raise NotImplementedError('Saving non-native graph is not supported yet')
95113

96-
@staticmethod
97-
def _get_filled_dict(graph, backend, input=None, output=None):
98-
return {
99-
'graph': graph,
100-
'backend': backend,
101-
'input': input,
102-
'output': output}
114+
@classmethod
115+
def _save_onnx_graph(cls, graph, path, as_native):
116+
if as_native:
117+
with open(path, 'wb') as f:
118+
f.write(graph.SerializeToString())
119+
else:
120+
raise NotImplementedError('Saving non-native graph is not supported yet')
103121

104-
@staticmethod
105-
def _write_custom_model(outdict, path):
106-
with open(path, 'wb') as file:
107-
pickle.dump(outdict, file)
122+
@classmethod
123+
def _save_sklearn_graph(cls, graph, path, as_native, prototype):
124+
if not as_native:
125+
raise NotImplementedError('Saving non-native graph is not supported yet')
126+
if hasattr(prototype, 'shape') and hasattr(prototype, 'dtype'):
127+
datatype = skl2onnx.common.data_types.guess_data_type(prototype)
128+
serialized = skl2onnx.convert_sklearn(graph, initial_types=datatype)
129+
cls._save_onnx_graph(serialized, path, as_native)
130+
else:
131+
raise TypeError(
132+
"Serializing scikit learn model needs to know shape and dtype"
133+
" of input data which will be inferred from `prototype` "
134+
"parameter. It has to be a valid `numpy.ndarray` of shape of your input")
108135

109136
@classmethod
110-
def load(cls, path:str):
137+
def load(cls, path: str):
111138
"""
112139
Return the binary data if saved with `as_native` otherwise return the dict
113-
that contains binary graph/model on `graph` key. Check `_get_filled_dict`
114-
for more details.
140+
that contains binary graph/model on `graph` key (Not implemented yet).
115141
:param path: File path from where the native model or the rai models are saved
116142
"""
117143
with open(path, 'rb') as f:

redisai/version.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Store the version here so:
2+
# 1) we don't load dependencies by storing it in __init__.py
3+
# 2) we can import it in setup.py for the same reason
4+
# 3) we can import it into your module module
5+
__version__ = '0.3.0'

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
#!/usr/bin/env python
33
from setuptools import setup, find_packages
44

5+
exec(open('redisai/version.py').read())
56

67
setup(
78
name='redisai',
8-
version='0.2.0',
9+
version=__version__, # comes from redisai/version.py
910

1011
description='RedisAI Python Client',
1112
url='http://github.com/RedisAI/redisai-py',

test-requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
numpy
22
torch
3-
tensorflow
3+
tensorflow
4+
onnx
5+
skl2onnx
6+
pandas

test/test.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import os.path
44
from redisai import Client, DType, Backend, Device, Tensor, BlobTensor
5+
from redisai import load_model
56
from redis.exceptions import ResponseError
67

78

@@ -44,14 +45,11 @@ def test_numpy_tensor(self):
4445
self.assertEqual([2, 3], values)
4546

4647
def test_run_tf_model(self):
47-
model = os.path.join(MODEL_DIR, 'graph.pb')
48-
bad_model = os.path.join(MODEL_DIR, 'pt-minimal.pt')
48+
model_path = os.path.join(MODEL_DIR, 'graph.pb')
49+
bad_model_path = os.path.join(MODEL_DIR, 'pt-minimal.pt')
4950

50-
with open(model, 'rb') as f:
51-
model_pb = f.read()
52-
53-
with open(bad_model, 'rb') as f:
54-
wrong_model_pb = f.read()
51+
model_pb = load_model(model_path)
52+
wrong_model_pb = load_model(bad_model_path)
5553

5654
con = self.get_client()
5755
con.modelset('m', Backend.tf, Device.cpu, model_pb,
@@ -96,5 +94,28 @@ def bar(a, b):
9694
tensor = con.tensorget('c')
9795
self.assertEqual([4, 6], tensor.value)
9896

97+
def test_run_onnxml_model(self):
98+
mlmodel_path = os.path.join(MODEL_DIR, 'boston.onnx')
99+
onnxml_model = load_model(mlmodel_path)
100+
con = self.get_client()
101+
con.modelset("onnx_model", Backend.onnx, Device.cpu, onnxml_model)
102+
tensor = BlobTensor.from_numpy(np.ones((1, 13), dtype=np.float32))
103+
con.tensorset("input", tensor)
104+
con.modelrun("onnx_model", ["input"], ["output"])
105+
outtensor = con.tensorget("output")
106+
self.assertEqual(int(outtensor.value[0]), 24)
107+
108+
def test_run_onnxdl_model(self):
109+
# A PyTorch model that finds the square
110+
dlmodel_path = os.path.join(MODEL_DIR, 'findsquare.onnx')
111+
onnxdl_model = load_model(dlmodel_path)
112+
con = self.get_client()
113+
con.modelset("onnx_model", Backend.onnx, Device.cpu, onnxdl_model)
114+
tensor = BlobTensor.from_numpy(np.array((2, 3), dtype=np.float32))
115+
con.tensorset("input", tensor)
116+
con.modelrun("onnx_model", ["input"], ["output"])
117+
outtensor = con.tensorget("output")
118+
self.assertEqual(outtensor.value, [4.0, 9.0])
119+
99120

100-
# TODO: image/blob tests; more numpy tests..
121+
# TODO: image/blob tests; more numpy tests..

0 commit comments

Comments
 (0)