24
24
linear_regression ,
25
25
pme ,
26
26
tf_micro ,
27
- nnom
27
+ nnom ,
28
28
)
29
29
from django .core .exceptions import ValidationError
30
30
35
35
"bonsai" ,
36
36
"pme" ,
37
37
"linear_regression" ,
38
- "nnom"
38
+ "nnom" ,
39
39
]
40
40
41
41
CLASSIFER_MAP = {
42
42
"decision tree ensemble" : "decision_tree_ensemble" ,
43
43
"tensorflow lite for microcontrollers" : "tf_micro" ,
44
+ "Neural Network" : "tf_micro" ,
44
45
"nnom" : "nnom" ,
45
46
"pme" : "pme" ,
46
47
"boosted tree ensemble" : "boosted_tree_ensemble" ,
@@ -61,7 +62,7 @@ def get_classifier_type(model_configuration):
61
62
return classifier_type .lower ()
62
63
63
64
64
- #TODO: Make this an interface that returns the object instead of having all of these if statements
65
+ # TODO: Make this an interface that returns the object instead of having all of these if statements
65
66
class ModelGen :
66
67
@staticmethod
67
68
def create_classifier_structures (classifier_type , kb_models ):
@@ -82,10 +83,10 @@ def create_classifier_structures(classifier_type, kb_models):
82
83
83
84
if classifier_type == "linear_regression" :
84
85
return linear_regression .create_classifier_structures (kb_models )
85
-
86
+
86
87
if classifier_type == "nnom" :
87
88
return nnom .create_classifier_structures (kb_models )
88
-
89
+
89
90
return ""
90
91
91
92
@staticmethod
@@ -107,7 +108,7 @@ def create_max_tmp_parameters(classifier_type, kb_models):
107
108
108
109
if classifier_type == "linear_regression" :
109
110
return linear_regression .create_max_tmp_parameters (kb_models )
110
-
111
+
111
112
if classifier_type == "nnom" :
112
113
return nnom .create_max_tmp_parameters (kb_models )
113
114
@@ -151,7 +152,7 @@ def validate_model_parameters(model_parameters, model_configuration):
151
152
152
153
if classifier_type == "linear_regression" :
153
154
return linear_regression .validate_model_parameters (model_parameters )
154
-
155
+
155
156
if classifier_type == "nnom" :
156
157
return nnom .validate_model_parameters (model_parameters )
157
158
@@ -180,8 +181,7 @@ def validate_model_configuration(model_configuration):
180
181
181
182
if classifier_type == "linear_regression" :
182
183
return linear_regression .validate_model_configuration (model_configuration )
183
-
184
-
184
+
185
185
if classifier_type == "nnom" :
186
186
return nnom .validate_model_configuration (model_configuration )
187
187
@@ -204,7 +204,7 @@ def get_output_tensor_size(classifier_type, model):
204
204
205
205
if classifier_type == "linear_regression" :
206
206
return linear_regression .get_output_tensor_size (model )
207
-
207
+
208
208
if classifier_type == "nnom" :
209
209
return nnom .get_output_tensor_size (model )
210
210
@@ -232,7 +232,7 @@ def get_input_feature_type(model):
232
232
233
233
if classifier_type == "linear_regression" :
234
234
return FLOAT
235
-
235
+
236
236
if classifier_type == "nnom" :
237
237
return UINT8_T
238
238
@@ -263,7 +263,7 @@ def get_input_feature_def(model):
263
263
264
264
if classifier_type == "nnom" :
265
265
return UINT8_T
266
-
266
+
267
267
raise ValueError ("No classifier type found" )
268
268
269
269
@staticmethod
@@ -273,7 +273,7 @@ def get_model_type(model):
273
273
CLASSIFICATION = 1
274
274
if classifier_type == "tf_micro" :
275
275
return CLASSIFICATION
276
-
276
+
277
277
if classifier_type == "nnom" :
278
278
return CLASSIFICATION
279
279
0 commit comments