Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 7ab2fe2

Browse files
authored
Allow input overrides in check_correctness.py (#307)
* Allow input overrides in check_correctness.py
* Review comments, style + quality
1 parent aa541b6 commit 7ab2fe2

File tree

1 file changed

+32
-16
lines changed

1 file changed

+32
-16
lines changed

examples/benchmark/check_correctness.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,13 @@
4444

4545
import argparse
4646

47-
import onnxruntime
48-
4947
from deepsparse import compile_model, cpu
48+
from deepsparse.benchmark_model.ort_engine import ORTEngine
5049
from deepsparse.utils import (
5150
generate_random_inputs,
52-
get_input_names,
53-
get_output_names,
54-
override_onnx_batch_size,
51+
model_to_path,
52+
override_onnx_input_shapes,
53+
parse_input_shapes,
5554
verify_outputs,
5655
)
5756

@@ -81,30 +80,47 @@ def parse_args():
8180
help="The batch size to run the analysis for",
8281
)
8382

83+
parser.add_argument(
84+
"-shapes",
85+
"--input_shapes",
86+
type=str,
87+
default="",
88+
help="Override the shapes of the inputs, "
89+
'i.e., -shapes "[1,2,3],[4,5,6],[7,8,9]" results in '
90+
"input0=[1,2,3] input1=[4,5,6] input2=[7,8,9]. ",
91+
)
92+
8493
return parser.parse_args()
8594

8695

8796
def main():
8897
args = parse_args()
89-
onnx_filepath = args.onnx_filepath
98+
onnx_filepath = model_to_path(args.onnx_filepath)
9099
batch_size = args.batch_size
91100

92-
inputs = generate_random_inputs(onnx_filepath, batch_size)
93-
input_names = get_input_names(onnx_filepath)
94-
output_names = get_output_names(onnx_filepath)
95-
inputs_dict = {name: value for name, value in zip(input_names, inputs)}
101+
input_shapes = parse_input_shapes(args.input_shapes)
102+
103+
if input_shapes:
104+
with override_onnx_input_shapes(onnx_filepath, input_shapes) as model_path:
105+
inputs = generate_random_inputs(model_path, args.batch_size)
106+
else:
107+
inputs = generate_random_inputs(onnx_filepath, args.batch_size)
96108

97109
# ONNXRuntime inference
98110
print("Executing model with ONNXRuntime...")
99-
sess_options = onnxruntime.SessionOptions()
100-
with override_onnx_batch_size(onnx_filepath, batch_size) as override_onnx_filepath:
101-
ort_network = onnxruntime.InferenceSession(override_onnx_filepath, sess_options)
102-
103-
ort_outputs = ort_network.run(output_names, inputs_dict)
111+
ort_network = ORTEngine(
112+
model=onnx_filepath,
113+
batch_size=batch_size,
114+
num_cores=None,
115+
input_shapes=input_shapes,
116+
)
117+
ort_outputs = ort_network.run(inputs)
104118

105119
# DeepSparse Engine inference
106120
print("Executing model with DeepSparse Engine...")
107-
dse_network = compile_model(onnx_filepath, batch_size=batch_size)
121+
dse_network = compile_model(
122+
onnx_filepath, batch_size=batch_size, input_shapes=input_shapes
123+
)
108124
dse_outputs = dse_network(inputs)
109125

110126
verify_outputs(dse_outputs, ort_outputs)

0 commit comments

Comments (0)