Skip to content

Commit 70004b1

Browse files
1vndeliahu
authored andcommitted
Fix shape inference to TF serving (#319)
(cherry picked from commit ec3a798)
1 parent 8956b63 commit 70004b1

File tree

1 file changed

+8
-2
lines changed
  • pkg/workloads/cortex/tf_api

1 file changed

+8
-2
lines changed

pkg/workloads/cortex/tf_api/api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,17 @@ def create_raw_prediction_request(sample):
150150
prediction_request.model_spec.signature_name = signature_key
151151

152152
for column_name, value in sample.items():
153-
shape = [1]
154153
if util.is_list(value):
155154
shape = [len(value)]
155+
for dim in signature_def[signature_key]["inputs"][column_name]["tensorShape"]["dim"][
156+
1:
157+
]:
158+
shape.append(int(dim["size"]))
159+
else:
160+
shape = [1]
161+
value = [value]
156162
sig_type = signature_def[signature_key]["inputs"][column_name]["dtype"]
157-
tensor_proto = tf.make_tensor_proto([value], dtype=DTYPE_TO_TF_TYPE[sig_type], shape=shape)
163+
tensor_proto = tf.make_tensor_proto(value, dtype=DTYPE_TO_TF_TYPE[sig_type], shape=shape)
158164
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
159165

160166
return prediction_request

0 commit comments

Comments
 (0)