55# This source code is licensed under the BSD-style license found in the
66# LICENSE file in the root directory of this source tree.
77
8- # pyre-unsafe
9-
108
119from collections import defaultdict
1210
8987 QuantizeOperatorArguments ,
9088 RemoveNoopPass ,
9189 ReplaceInfValues ,
92- ReplaceScalarWithTensorArgPassTOSABI ,
93- ReplaceScalarWithTensorArgPassTOSAMI ,
90+ ReplaceScalarWithTensorByProfilePass ,
9491 RetraceFoldedDtypesPass ,
9592 RewriteConv2dPass ,
9693 RewriteMatmulPass ,
@@ -156,15 +153,15 @@ def _transform(self, graph_module: GraphModule):
156153 with TosaLoweringContext (self .tosa_spec ):
157154 return self (graph_module ).graph_module
158155
159- def _tosa_INT_pipeline (self , exported_program : ExportedProgram ) -> GraphModule :
156+ def _tosa_INT_pipeline (
157+ self , exported_program : ExportedProgram , graph_module : GraphModule
158+ ) -> GraphModule :
160159 self .add_pass (AnnotateOutputDimOrderPass ())
161160 self .add_pass (FuseQuantizedActivationPass ())
162161 self .add_pass (RemoveGetItemPass ())
163162 self .add_pass (ConvertSplitToSlicePass ())
164163 self .add_pass (ConvertMmToBmmPass ())
165- self .add_pass (
166- DecomposeMeanDimPass (exported_program .graph_module , self .tosa_spec )
167- )
164+ self .add_pass (DecomposeMeanDimPass (graph_module , self .tosa_spec ))
168165 self .add_pass (ConvertFullLikeToFullPass ())
169166 self .add_pass (ConvertToClampPass ())
170167 self .add_pass (ConvertMinMaxPass ())
@@ -174,7 +171,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
174171 self .add_pass (CastToInt32Pass ())
175172
176173 self .add_pass (CastBoolToInt8Pass ())
177- self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
174+ self .add_pass (ReplaceScalarWithTensorByProfilePass ())
178175 self .add_pass (AnnotateDecomposedMatmulPass ())
179176 self .add_pass (QuantizeOperatorArguments ())
180177 self .add_pass (ConvertELUParamsPass ())
@@ -194,7 +191,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
194191 self .add_pass (ConvertExpandCopyToRepeatPass ())
195192 self .add_pass (UnsqueezeBeforeRepeatPass ())
196193 self .add_pass (CastInt64BuffersToInt32Pass (exported_program ))
197- self .add_pass (DecomposeSumPass ())
198194 self .add_pass (DecomposeCumsumPass (exported_program ))
199195 self .add_pass (Conv1dUnsqueezePass ())
200196 self .add_pass (DecomposeMaxPool2DPass ())
@@ -215,15 +211,18 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
215211 self .add_pass (RewriteMatmulPass ())
216212 self .add_pass (RewriteUpsamplePass ())
217213 self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
214+ self .add_pass (InsertRescaleInt32Pass ())
215+ self .add_pass (DecomposeSumPass ())
218216 self .add_pass (ToTosaMemoryFormatPass (exported_program ))
219217 self .add_pass (RemoveNoopPass ())
220218 self .add_pass (InsertRescalePass ())
221- self .add_pass (InsertRescaleInt32Pass ())
222219
223220 self .validate_constraints_mandatory ()
224- return self ._transform (exported_program . graph_module )
221+ return self ._transform (graph_module )
225222
226- def _tosa_FP_pipeline (self , exported_program : ExportedProgram ) -> GraphModule :
223+ def _tosa_FP_pipeline (
224+ self , exported_program : ExportedProgram , graph_module : GraphModule
225+ ) -> GraphModule :
227226 self .add_pass (AnnotateOutputDimOrderPass ())
228227 self .add_pass (DecomposeExpm1Pass ())
229228 self .add_pass (DecomposeLogitPass ())
@@ -244,7 +243,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
244243 self .add_pass (DecomposeSinhPass ())
245244 self .add_pass (DecomposeSignPass ())
246245 self .add_pass (DecomposeDivTensorModePass ())
247- self .add_pass (ReplaceScalarWithTensorArgPassTOSAMI ())
246+ self .add_pass (ReplaceScalarWithTensorByProfilePass ())
248247 self .add_pass (DecomposeEmbeddingPass ())
249248 self .add_pass (FuseQuantizedActivationPass ())
250249 self .add_pass (RemoveGetItemPass ())
@@ -258,9 +257,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
258257 self .add_pass (DecomposeLayerNormPass ())
259258 self .add_pass (DecomposeBatchNormNoStatsPass ())
260259 self .add_pass (DecomposeVarPass ())
261- self .add_pass (
262- DecomposeMeanDimPass (exported_program .graph_module , self .tosa_spec )
263- )
260+ self .add_pass (DecomposeMeanDimPass (graph_module , self .tosa_spec ))
264261 self .add_pass (DecomposeNotEqualPass ())
265262 self .add_pass (DecomposeDivPass ())
266263 self .add_pass (DecomposeAddSubAlphaPass ())
@@ -308,14 +305,16 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
308305 self .add_pass (InsertRescalePass ())
309306
310307 self .validate_constraints_mandatory ()
311- return self ._transform (exported_program . graph_module )
308+ return self ._transform (graph_module )
312309
313- def transform_to_backend_pipeline (self , exported_program : ExportedProgram ):
310+ def transform_to_backend_pipeline (
311+ self , exported_program : ExportedProgram , graph_module : GraphModule
312+ ):
314313 """Apply passes before transforming program to backend"""
315314 if self .tosa_spec == TosaSpecification .create_from_string ("TOSA-1.0+FP" ):
316- return self ._tosa_FP_pipeline (exported_program )
315+ return self ._tosa_FP_pipeline (exported_program , graph_module )
317316 elif self .tosa_spec == TosaSpecification .create_from_string ("TOSA-1.0+INT" ):
318- return self ._tosa_INT_pipeline (exported_program )
317+ return self ._tosa_INT_pipeline (exported_program , graph_module )
319318 else :
320319 raise NotImplementedError (
321320 f"No pass pipeline implemented for { self .tosa_spec = } "
@@ -337,7 +336,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
337336 self .add_pass (DecomposeAddmmPass ())
338337 self .add_pass (DecomposeDivTensorModePass ())
339338 self .add_pass (DecomposeAddSubAlphaPass ())
340- self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
339+ self .add_pass (ReplaceScalarWithTensorByProfilePass ())
341340 self .add_pass (ScalarsToAttributePass ())
342341 self .add_pass (DecomposeGroupNormPass ())
343342 self .add_pass (DecomposeLayerNormPass ())
@@ -361,7 +360,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
361360
362361 self .add_pass (ConvertMinMaxPass ())
363362 self .add_pass (ReplaceInfValues ())
364- self .add_pass (DecomposeSumPass ())
365363
366364 if not self .tosa_spec .is_U55_subset :
367365 # Uses where which is not supported on Ethos-U55
0 commit comments