3
3
# Please refer to the license found in the LICENSE file in the root directory of the source tree.
4
4
5
5
import argparse
6
+ import collections
6
7
import copy
7
8
8
9
import pathlib
23
24
from executorch .exir import to_edge
24
25
25
26
from executorch .exir .backend .backend_api import to_backend
26
-
27
- from torch .export import export
27
+ from executorch .extension .export_util .utils import save_pte_program
28
28
29
29
REPO_ROOT = pathlib .Path (__file__ ).resolve ().parent .parent .parent .parent .parent
30
30
EXAMPLES_DIR = REPO_ROOT / "examples"
41
41
)
42
42
43
43
44
- def parse_args () -> argparse .ArgumentParser :
44
+ def is_fbcode ():
45
+ return not hasattr (torch .version , "git_version" )
46
+
47
+
48
+ _CAN_RUN_WITH_PYBINDINGS = (sys .platform == "darwin" ) and not is_fbcode ()
49
+ if _CAN_RUN_WITH_PYBINDINGS :
50
+ from executorch .runtime import Runtime
51
+
52
+
53
+ def parse_args () -> argparse .Namespace :
45
54
parser = argparse .ArgumentParser ()
46
55
47
56
parser .add_argument (
@@ -82,9 +91,12 @@ def parse_args() -> argparse.ArgumentParser:
82
91
required = False ,
83
92
default = False ,
84
93
)
94
+ parser .add_argument (
95
+ "--run_with_pybindings" ,
96
+ action = argparse .BooleanOptionalAction ,
97
+ )
85
98
86
99
args = parser .parse_args ()
87
- # pyre-fixme[7]: Expected `ArgumentParser` but got `Namespace`.
88
100
return args
89
101
90
102
@@ -95,7 +107,8 @@ def partition_module_to_coreml(module):
95
107
def lower_module_to_coreml (module , compile_specs , example_inputs ):
96
108
module = module .eval ()
97
109
edge = to_edge (
98
- export (module , example_inputs , strict = True ), compile_config = _EDGE_COMPILE_CONFIG
110
+ torch .export .export (module , example_inputs , strict = True ),
111
+ compile_config = _EDGE_COMPILE_CONFIG ,
99
112
)
100
113
# All of the subsequent calls on the edge_dialect_graph generated above (such as delegation or
101
114
# to_executorch()) are done in place and the graph is also modified in place. For debugging purposes
@@ -115,24 +128,23 @@ def lower_module_to_coreml(module, compile_specs, example_inputs):
115
128
def export_lowered_module_to_executorch_program (lowered_module , example_inputs ):
116
129
lowered_module (* example_inputs )
117
130
exec_prog = to_edge (
118
- export (lowered_module , example_inputs , strict = True ),
131
+ torch . export . export (lowered_module , example_inputs , strict = True ),
119
132
compile_config = _EDGE_COMPILE_CONFIG ,
120
133
).to_executorch (config = exir .ExecutorchBackendConfig (extract_delegate_segments = True ))
121
134
122
135
return exec_prog
123
136
124
137
125
- def save_executorch_program (exec_prog , model_name , compute_unit ):
126
- buffer = exec_prog .buffer
127
- filename = f"{ model_name } _coreml_{ compute_unit } .pte"
128
- print (f"Saving exported program to { filename } " )
129
- with open (filename , "wb" ) as file :
130
- file .write (buffer )
131
- return
138
+ def get_pte_base_name (args : argparse .Namespace ) -> str :
139
+ pte_name = args .model_name
140
+ if args .compile :
141
+ pte_name += "_compiled"
142
+ pte_name = f"{ pte_name } _coreml_{ args .compute_unit } "
143
+ return pte_name
132
144
133
145
134
- def save_processed_bytes (processed_bytes , model_name , compute_unit ):
135
- filename = f"{ model_name } _coreml_ { compute_unit } .bin"
146
+ def save_processed_bytes (processed_bytes , base_name : str ):
147
+ filename = f"{ base_name } .bin"
136
148
print (f"Saving processed bytes to { filename } " )
137
149
with open (filename , "wb" ) as file :
138
150
file .write (processed_bytes )
@@ -154,6 +166,37 @@ def generate_compile_specs_from_args(args):
154
166
)
155
167
156
168
169
+ def run_with_pybindings (executorch_program , eager_reference , example_inputs , precision ):
170
+ if not _CAN_RUN_WITH_PYBINDINGS :
171
+ raise RuntimeError ("Cannot run with pybindings on this platform." )
172
+
173
+ dtype = {
174
+ "float32" : torch .float32 ,
175
+ "float16" : torch .float16 ,
176
+ }[precision ]
177
+
178
+ runtime = Runtime .get ()
179
+ program = runtime .load_program (executorch_program .buffer )
180
+ method = program .load_method ("forward" )
181
+ et_outputs = method .execute (* example_inputs )[0 ]
182
+ eager_outputs = eager_reference (* example_inputs )
183
+ if isinstance (eager_outputs , collections .OrderedDict ):
184
+ eager_outputs = eager_outputs ["out" ]
185
+ if isinstance (eager_outputs , list | tuple ):
186
+ eager_outputs = eager_outputs [0 ]
187
+
188
+ mse = ((et_outputs - eager_outputs ) ** 2 ).mean ().sqrt ()
189
+ print (f"Mean square error: { mse } " )
190
+ assert mse < 0.1 , "Mean square error is too high."
191
+
192
+ if dtype == torch .float32 :
193
+ assert torch .allclose (
194
+ et_outputs , eager_outputs , atol = 1e-02 , rtol = 1e-02
195
+ ), f"""Outputs do not match eager reference:
196
+ \t et_outputs (first 5)={ et_outputs .reshape (- 1 )[0 :5 ]}
197
+ \t eager_outputs (first 5)={ eager_outputs .reshape (- 1 )[0 :5 ]} """
198
+
199
+
157
200
def main ():
158
201
args = parse_args ()
159
202
@@ -170,49 +213,65 @@ def main():
170
213
f"Valid compute units are { valid_compute_units } ."
171
214
)
172
215
173
- model , example_inputs , _ , dynamic_shapes = EagerModelFactory . create_model (
174
- * MODEL_NAME_TO_MODEL [args .model_name ]
216
+ model , example_args , example_kwargs , dynamic_shapes = (
217
+ EagerModelFactory . create_model ( * MODEL_NAME_TO_MODEL [args .model_name ])
175
218
)
176
219
if not args .dynamic_shapes :
177
220
dynamic_shapes = None
178
221
179
222
compile_specs = generate_compile_specs_from_args (args )
180
- lowered_module = None
181
-
223
+ pte_base_name = get_pte_base_name (args )
182
224
if args .use_partitioner :
183
- model .eval ()
184
- exir_program_aten = torch .export .export (
185
- model , example_inputs , dynamic_shapes = dynamic_shapes , strict = True
186
- )
187
-
188
- edge_program_manager = exir .to_edge (exir_program_aten )
189
- edge_copy = copy .deepcopy (edge_program_manager )
190
- partitioner = CoreMLPartitioner (
191
- skip_ops_for_coreml_delegation = None , compile_specs = compile_specs
225
+ model = model .eval ()
226
+ assert not args .generate_etrecord , "ETRecord is not supported with partitioner"
227
+ ep = torch .export .export (
228
+ model ,
229
+ args = example_args ,
230
+ kwargs = example_kwargs ,
231
+ dynamic_shapes = dynamic_shapes ,
192
232
)
193
- delegated_program_manager = edge_program_manager .to_backend (partitioner )
194
- exec_program = delegated_program_manager .to_executorch (
195
- config = exir .ExecutorchBackendConfig (extract_delegate_segments = True )
233
+ print (ep )
234
+ delegated_program = exir .to_edge_transform_and_lower (
235
+ ep ,
236
+ partitioner = [CoreMLPartitioner (compile_specs = compile_specs )],
196
237
)
238
+ exec_program = delegated_program .to_executorch ()
239
+ save_pte_program (exec_program , pte_base_name )
240
+ if args .run_with_pybindings :
241
+ run_with_pybindings (
242
+ executorch_program = exec_program ,
243
+ eager_reference = model ,
244
+ example_inputs = example_args ,
245
+ precision = args .compute_precision ,
246
+ )
197
247
else :
198
248
lowered_module , edge_copy = lower_module_to_coreml (
199
249
module = model ,
200
- example_inputs = example_inputs ,
250
+ example_inputs = example_args ,
201
251
compile_specs = compile_specs ,
202
252
)
203
253
exec_program = export_lowered_module_to_executorch_program (
204
254
lowered_module ,
205
- example_inputs ,
206
- )
207
-
208
- model_name = f"{ args .model_name } _compiled" if args .compile else args .model_name
209
- save_executorch_program (exec_program , model_name , args .compute_unit )
210
- generate_etrecord (f"{ args .model_name } _coreml_etrecord.bin" , edge_copy , exec_program )
211
-
212
- if args .save_processed_bytes and lowered_module is not None :
213
- save_processed_bytes (
214
- lowered_module .processed_bytes , args .model_name , args .compute_unit
255
+ example_args ,
215
256
)
257
+ save_pte_program (exec_program , pte_base_name )
258
+ if args .generate_etrecord :
259
+ generate_etrecord (
260
+ f"{ args .model_name } _coreml_etrecord.bin" , edge_copy , exec_program
261
+ )
262
+
263
+ if args .save_processed_bytes :
264
+ save_processed_bytes (
265
+ lowered_module .processed_bytes ,
266
+ pte_base_name ,
267
+ )
268
+ if args .run_with_pybindings :
269
+ run_with_pybindings (
270
+ executorch_program = exec_program ,
271
+ eager_reference = model ,
272
+ example_inputs = example_args ,
273
+ precision = args .compute_precision ,
274
+ )
216
275
217
276
218
277
if __name__ == "__main__" :
0 commit comments