|
44 | 44 |
|
45 | 45 | import argparse |
46 | 46 |
|
47 | | -import onnxruntime |
48 | | - |
49 | 47 | from deepsparse import compile_model, cpu |
| 48 | +from deepsparse.benchmark_model.ort_engine import ORTEngine |
50 | 49 | from deepsparse.utils import ( |
51 | 50 | generate_random_inputs, |
52 | | - get_input_names, |
53 | | - get_output_names, |
54 | | - override_onnx_batch_size, |
| 51 | + model_to_path, |
| 52 | + override_onnx_input_shapes, |
| 53 | + parse_input_shapes, |
55 | 54 | verify_outputs, |
56 | 55 | ) |
57 | 56 |
|
@@ -81,30 +80,47 @@ def parse_args(): |
81 | 80 | help="The batch size to run the analysis for", |
82 | 81 | ) |
83 | 82 |
|
| 83 | + parser.add_argument( |
| 84 | + "-shapes", |
| 85 | + "--input_shapes", |
| 86 | + type=str, |
| 87 | + default="", |
| 88 | + help="Override the shapes of the inputs, " |
| 89 | + 'i.e., -shapes "[1,2,3],[4,5,6],[7,8,9]" results in ' |
| 90 | + "input0=[1,2,3] input1=[4,5,6] input2=[7,8,9]. ", |
| 91 | + ) |
| 92 | + |
84 | 93 | return parser.parse_args() |
85 | 94 |
|
86 | 95 |
|
87 | 96 | def main(): |
88 | 97 | args = parse_args() |
89 | | - onnx_filepath = args.onnx_filepath |
| 98 | + onnx_filepath = model_to_path(args.onnx_filepath) |
90 | 99 | batch_size = args.batch_size |
91 | 100 |
|
92 | | - inputs = generate_random_inputs(onnx_filepath, batch_size) |
93 | | - input_names = get_input_names(onnx_filepath) |
94 | | - output_names = get_output_names(onnx_filepath) |
95 | | - inputs_dict = {name: value for name, value in zip(input_names, inputs)} |
| 101 | + input_shapes = parse_input_shapes(args.input_shapes) |
| 102 | + |
| 103 | + if input_shapes: |
| 104 | + with override_onnx_input_shapes(onnx_filepath, input_shapes) as model_path: |
| 105 | + inputs = generate_random_inputs(model_path, args.batch_size) |
| 106 | + else: |
| 107 | + inputs = generate_random_inputs(onnx_filepath, args.batch_size) |
96 | 108 |
|
97 | 109 | # ONNXRuntime inference |
98 | 110 | print("Executing model with ONNXRuntime...") |
99 | | - sess_options = onnxruntime.SessionOptions() |
100 | | - with override_onnx_batch_size(onnx_filepath, batch_size) as override_onnx_filepath: |
101 | | - ort_network = onnxruntime.InferenceSession(override_onnx_filepath, sess_options) |
102 | | - |
103 | | - ort_outputs = ort_network.run(output_names, inputs_dict) |
| 111 | + ort_network = ORTEngine( |
| 112 | + model=onnx_filepath, |
| 113 | + batch_size=batch_size, |
| 114 | + num_cores=None, |
| 115 | + input_shapes=input_shapes, |
| 116 | + ) |
| 117 | + ort_outputs = ort_network.run(inputs) |
104 | 118 |
|
105 | 119 | # DeepSparse Engine inference |
106 | 120 | print("Executing model with DeepSparse Engine...") |
107 | | - dse_network = compile_model(onnx_filepath, batch_size=batch_size) |
| 121 | + dse_network = compile_model( |
| 122 | + onnx_filepath, batch_size=batch_size, input_shapes=input_shapes |
| 123 | + ) |
108 | 124 | dse_outputs = dse_network(inputs) |
109 | 125 |
|
110 | 126 | verify_outputs(dse_outputs, ort_outputs) |
|
0 commit comments