
Commit 63b1790

sxupobin6 authored and committed
Add prepare_obs_or_fq_callback to quantizer (pytorch#140863)
Test Plan: CI.

Differential Revision: D65982003

Pull Request resolved: pytorch#140863
Approved by: https://github.com/jerryzh168
1 parent 0744d5f commit 63b1790
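The new hook gives a Quantizer a chance to adjust the observers or fake quants that prepare creates, after they have been built for each sharing group but before they are inserted into the graph. A minimal sketch of overriding it in a custom quantizer, based on the API added in this commit (the class name is illustrative and the annotation logic is elided):

import torch
from torch.ao.quantization.quantizer import Quantizer


class MyBackendQuantizer(Quantizer):  # illustrative name
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...  # annotate edges/nodes with QuantizationSpec as before
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    def prepare_obs_or_fq_callback(self, model, edge_or_node_to_obs_or_fq):
        # Called by prepare_pt2e / prepare_qat_pt2e after observers or fake
        # quants are created, but before they are inserted into the graph.
        # Mutating the mapping here overrides what gets inserted, e.g. to pin
        # a model output to a fixed scale / zero point.
        ...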

File tree: 4 files changed, +107 -3 lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 72 additions & 1 deletion
@@ -1,5 +1,5 @@
 # Owner(s): ["oncall: quantization"]
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
 import torch
 from torch import Tensor
@@ -18,6 +18,7 @@
 )
 from torch.ao.quantization.quantizer import (
     DerivedQuantizationSpec,
+    EdgeOrNode,
     FixedQParamsQuantizationSpec,
     QuantizationAnnotation,
     QuantizationSpec,
@@ -2339,6 +2340,76 @@ def forward(self, x):
         m = convert_pt2e(m)
         m(*example_inputs)
 
+    def test_prepare_obs_or_fq_callback(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                x = torch.nn.functional.max_pool2d(x, 2, 2)
+                x = torch.nn.functional.pixel_shuffle(x, 2)
+                return x.permute(0, 2, 3, 1)
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                for node in model.graph.nodes:
+                    if node.op == "call_function" and node.target in (
+                        torch.ops.aten.max_pool2d.default,
+                        torch.ops.aten.permute.default,
+                        torch.ops.aten.pixel_shuffle.default,
+                    ):
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                node.args[0]: act_qspec,
+                            },
+                            output_qspec=SharedQuantizationSpec((node.args[0], node)),
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+            def prepare_obs_or_fq_callback(
+                self,
+                model: torch.fx.GraphModule,
+                edge_or_node_to_obs_or_fq: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+            ) -> None:
+                # hard-code the output quant params by updating the entire sharing group
+                output_node = next(n for n in model.graph.nodes if n.op == "output")
+                output_value = output_node.args[0][0]
+                old_observer = edge_or_node_to_obs_or_fq[output_value]
+                sharing_group = [
+                    k for k, v in edge_or_node_to_obs_or_fq.items() if v is old_observer
+                ]
+                new_observer = observer.FixedQParamsObserver(
+                    scale=0.125,
+                    zero_point=42,
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                )
+                for x in sharing_group:
+                    edge_or_node_to_obs_or_fq[x] = new_observer
+
+        example_inputs = (torch.rand(1, 32, 16, 16),)
+        gm = export_for_training(Model().eval(), example_inputs).module()
+        gm = prepare_pt2e(gm, BackendAQuantizer())
+        gm = convert_pt2e(gm)
+        for n in gm.graph.nodes:
+            if n.op == "call_function" and n.target in (
+                torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            ):
+                # the entire graph shares one qspec, which was overridden by FixedQParamsObserver
+                self.assertEqual(n.args[1], 0.125)
+                self.assertEqual(n.args[2], 42)
+
 
 instantiate_parametrized_tests(TestQuantizePT2E)
 

torch/ao/quantization/pt2e/prepare.py

Lines changed: 3 additions & 0 deletions
@@ -535,6 +535,7 @@ def prepare(
     model: GraphModule,
     node_name_to_scope: Dict[str, Tuple[str, type]],
     is_qat: bool,
+    obs_or_fq_callback=None,
 ) -> GraphModule:
     # Since we are mutating the graph as we go, we iterate over the original
     # nodes before observer insertion, instead of model.graph.nodes.
@@ -549,6 +550,8 @@ def prepare(
     obs_or_fq_map = _get_obs_or_fq_map(
         edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat
    )
+    if obs_or_fq_callback:
+        obs_or_fq_callback(model, obs_or_fq_map)
 
     for node in nodes_before_observation:
         # TODO: simplify logic for inserting observers

torch/ao/quantization/quantize_pt2e.py

Lines changed: 12 additions & 2 deletions
@@ -99,7 +99,12 @@ def calibrate(model, data_loader):
     model = quantizer.transform_for_annotation(model)
     quantizer.annotate(model)
     quantizer.validate(model)
-    model = prepare(model, node_name_to_scope, is_qat=False)
+    model = prepare(
+        model,
+        node_name_to_scope,
+        is_qat=False,
+        obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback,
+    )
     model.meta.update(original_graph_meta)
     model = _disallow_eval_train(model)
     return model
@@ -172,7 +177,12 @@ def train_loop(model, train_data):
     # subgraph that don't need to be quantized
     # TODO: only fuse if conv and bn are both configured to be quantized
     _fuse_conv_bn_qat(model)
-    model = prepare(model, node_name_to_scope, is_qat=True)
+    model = prepare(
+        model,
+        node_name_to_scope,
+        is_qat=True,
+        obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback,
+    )
     model.meta.update(original_graph_meta)
     model = _disallow_eval_train(model)
     return model
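Because both prepare_pt2e and prepare_qat_pt2e now forward quantizer.prepare_obs_or_fq_callback into prepare(), the hook runs automatically during preparation. A rough end-to-end sketch following the flow exercised by the new test; Model and MyBackendQuantizer are placeholders, and export_for_training is assumed to be the torch.export helper used elsewhere in the test suite:

import torch
from torch.export import export_for_training  # assumed import path
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

example_inputs = (torch.rand(1, 32, 16, 16),)
m = export_for_training(Model().eval(), example_inputs).module()
m = prepare_pt2e(m, MyBackendQuantizer())  # prepare_obs_or_fq_callback runs in here
m(*example_inputs)  # calibration
m = convert_pt2e(m)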

torch/ao/quantization/quantizer/quantizer.py

Lines changed: 20 additions & 0 deletions
@@ -159,3 +159,23 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
     @abstractmethod
     def validate(self, model: torch.fx.GraphModule) -> None:
         pass
+
+    def prepare_obs_or_fq_callback(
+        self,
+        model: torch.fx.GraphModule,
+        edge_or_node_to_obs_or_fq: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    ) -> None:
+        """A callback that will be called after the observers or fake quants are created
+        for each sharing group, but before they are inserted into the graph. The
+        callback can be used to make final quantization adjustments, such as enforcing
+        specific scale and zero point on model input or output.
+
+        Args:
+          * `model`: the graph module being prepared.
+          * `edge_or_node_to_obs_or_fq`: a dictionary mapping each annotated edge and
+            node to the corresponding observer or fake quant object. Note that multiple
+            edges and/or nodes can map to the same observer / fake quant instance if
+            they were annotated with SharedQuantizationSpec. This dictionary can be
+            modified by the callback.
+        """
+        return
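One practical note from the docstring and the new test: edges and nodes annotated with SharedQuantizationSpec map to the same observer / fake quant instance, so a callback that overrides one entry should remap every key in its sharing group. A small sketch of that pattern inside the callback (old_observer and new_observer stand for whatever the callback looked up and constructed):

# Remap every edge/node that currently shares old_observer so the whole
# sharing group picks up the overridden quantization parameters.
sharing_group = [
    k for k, v in edge_or_node_to_obs_or_fq.items() if v is old_observer
]
for key in sharing_group:
    edge_or_node_to_obs_or_fq[key] = new_observer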
