@@ -177,6 +177,29 @@ def __post_init__(self):
177
177
178
178
179
179
class QnnQuantizer (Quantizer ):
180
+ """
181
+ QnnQuantizer is a quantization annotator designed for QNN backends.
182
+ It uses OP_ANNOTATOR, a dictionary mapping OpOverload to annotator functions,
183
+ to determine how each node should be annotated for quantization.
184
+
185
+ Example usage:
186
+ quantizer = QnnQuantizer()
187
+ quantizer.set_default_quant_config(
188
+ quant_dtype=QuantDtype.use_8a8w,
189
+ is_qat=False,
190
+ is_conv_per_channel=True,
191
+ is_linear_per_channel=True,
192
+ act_observer=MovingAverageMinMaxObserver,
193
+ )
194
+ quantizer.set_block_size_map({"conv2d": (1, 128, 1, 1)})
195
+ quantizer.set_submodule_qconfig_list([
196
+ (get_submodule_type_predicate("Add"), ModuleQConfig(quant_dtype=QuantDtype.use_16a4w))
197
+ ])
198
+ quantizer.add_custom_quant_annotations(...)
199
+ quantizer.add_discard_nodes(["conv2d_1"])  # node names to exclude from annotation
200
+ quantizer.add_discard_ops([torch.ops.aten.add.Tensor])  # op overloads to exclude from annotation
201
+ """
202
+
180
203
SUPPORTED_OPS : Set = set (OP_ANNOTATOR .keys ())
181
204
182
205
def __init__ (self ):
@@ -193,6 +216,11 @@ def __init__(self):
193
216
self .discard_nodes : Set [str ] = set ()
194
217
195
218
def _annotate (self , gm : GraphModule ) -> None :
219
+ """
220
+ Annotates the nodes of the provided GraphModule in-place based on user defined quant configs during prepare_pt2e.
221
+
222
+ For each node in the graph, nodes without quant config or those explicitly listed in `self.discard_nodes` are not annotated.
223
+ """
196
224
for node in gm .graph .nodes :
197
225
if node .name in self .discard_nodes :
198
226
continue
@@ -206,18 +234,34 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
206
234
annotation_func (gm )
207
235
208
236
def _get_submodule_qconfig (self , node : torch .fx .Node ):
237
+ """
238
+ Retrieves the `ModuleQConfig` for a given node by matching the first applicable callable function in the `submodule_qconfig_list`.
239
+ You can add submodule-specific quant config using the `set_submodule_qconfig_list` method.
240
+
241
+ Args:
242
+ node (torch.fx.Node): The node for which to retrieve the quant config.
243
+
244
+ Returns:
245
+ ModuleQConfig: The matched submodule config, or the default config if no match is found.
246
+ """
209
247
for func , qconfig in self .submodule_qconfig_list :
210
248
if func (node ):
211
249
return qconfig
212
250
return self .default_quant_config
213
251
214
252
def _get_quant_config (self , node : torch .fx .Node ) -> Optional [QuantizationConfig ]:
215
253
"""
216
- How to pick:
217
- 1. is one of per_block_quant_config
218
- 2. Pick specific submodule config if given.
219
- 3. Pick one if op belongs to use_per_channel_weight_quant_ops
220
- 4. If not 3, pick normal quant config
254
+ Select the quant config for a node based on priority.
255
+
256
+ Priority order:
257
+ 1. Per-block quant config if block_size is set for node.
258
+ 2. Submodule-specific config if predicate matches.
259
+ 3. Per-channel config if op is in per-channel set.
260
+ 4. Default quant config if op is supported.
261
+
262
+ Args:
263
+ node (torch.fx.Node): The node to get quant config for.
264
+
221
265
"""
222
266
op = node .target
223
267
if isinstance (op , str ):
@@ -241,22 +285,49 @@ def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]
241
285
def add_custom_quant_annotations(
    self, custom_quant_annotations: Sequence[Callable]
) -> None:
    """
    Register extra annotation callbacks to run during prepare_pt2e.

    Note: this replaces any previously registered callbacks.

    Args:
        custom_quant_annotations (Sequence[Callable]): Functions, each taking
            a GraphModule, that apply custom quantization annotations.
    """
    self.custom_quant_annotations = custom_quant_annotations
245
295
246
296
def add_discard_nodes(self, nodes: Sequence[str]) -> None:
    """
    Exclude the given node names from quantization annotation.

    Note: this replaces any previously registered discard set rather than
    adding to it.
    """
    self.discard_nodes = set(nodes)
248
301
249
302
def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
    """
    Exclude the given OpOverloads from quantization.

    Each op is removed from ``self.quant_ops`` so it will no longer be
    annotated. Ops that are not currently present (e.g. discarded twice or
    never supported) are ignored instead of raising ``KeyError``.

    Args:
        ops (Sequence[OpOverload]): Op overloads to exclude.
    """
    # difference_update tolerates absent elements, unlike set.remove which
    # raises KeyError when an op was already discarded or never registered.
    self.quant_ops.difference_update(ops)
252
308
253
309
def annotate(self, model: GraphModule) -> GraphModule:
    """
    Annotate *model* for quantization during prepare_pt2e.

    Runs the per-node annotation pass first, then any user-registered
    custom annotation callbacks.

    Args:
        model (GraphModule): The FX GraphModule to annotate.

    Returns:
        GraphModule: The same model, now annotated.
    """
    for annotation_pass in (self._annotate, self._annotate_custom_annotation):
        annotation_pass(model)
    return model
258
323
259
324
def get_supported_ops(self) -> Set[OpOverload]:
    """
    Report every OpOverload this quantizer can annotate.

    Returns:
        Set[OpOverload]: Ops with a registered annotator (``SUPPORTED_OPS``).
    """
    return self.SUPPORTED_OPS
261
332
262
333
def set_default_quant_config (
@@ -267,6 +338,17 @@ def set_default_quant_config(
267
338
is_linear_per_channel = False ,
268
339
act_observer = None ,
269
340
) -> None :
341
+ """
342
+ Set the default quant config for quantizer.
343
+
344
+ Args:
345
+ quant_dtype (QuantDtype): Specifies the quantized data type. By default, 8-bit activations and weights (8a8w) are used.
346
+ is_qat (bool, optional): Enables Quantization-Aware Training (QAT) mode. Defaults to Post-Training Quantization (PTQ) mode.
347
+ is_conv_per_channel (bool, optional): Enables per-channel quantization for convolution operations.
348
+ is_linear_per_channel (bool, optional): Enables per-channel quantization for linear (fully connected) operations.
349
+ act_observer (Optional[UniformQuantizationObserverBase], optional): Custom observer for activation quantization. If not specified, the default observer is determined by `QUANT_CONFIG_DICT`.
350
+
351
+ """
270
352
self .default_quant_config = ModuleQConfig (
271
353
quant_dtype ,
272
354
is_qat ,
@@ -276,6 +358,12 @@ def set_default_quant_config(
276
358
)
277
359
278
360
def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
    """
    Configure block sizes for per-block quantization.

    Args:
        block_size_map (Dict[str, Tuple]): Maps a node name to the block
            size applied when that node is quantized per-block.
    """
    self.block_size_map = block_size_map
280
368
281
369
def set_submodule_qconfig_list (
@@ -288,6 +376,15 @@ def set_submodule_qconfig_list(
288
376
self .submodule_qconfig_list = submodule_qconfig_list
289
377
290
378
def transform_for_annotation(self, model: GraphModule) -> GraphModule:
    """
    Run the QNN-specific pre-annotation transformation pipeline.

    Invoked before annotation during prepare_pt2e.

    Args:
        model (GraphModule): The FX GraphModule to transform.

    Returns:
        GraphModule: The transformed model.
    """
    pass_manager = QnnPassManager()
    return pass_manager.transform_for_annotation_pipeline(model)
292
389
293
390
def validate (self , model : GraphModule ) -> None :
0 commit comments