From 3b487e3af269e748c145fd814b9e010ad35ba2b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Wed, 6 Aug 2025 13:56:46 +0200 Subject: [PATCH] Let quantizers add custom meta data to GraphModule MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Make the 'custom' meta data field in the GraphModule be propagated with any data that is added by the quantizer. Signed-off-by: Per Åstrand --- test/quantization/pt2e/test_quantize_pt2e.py | 32 ++++++++++++++++++++ torchao/quantization/pt2e/quantize_pt2e.py | 4 +++ 2 files changed, 36 insertions(+) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 19f208a55c..6fa066dd64 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -1578,6 +1578,38 @@ def forward(self, x): } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + def test_custom_meta_transform_for_annotation(self): + class TestQuantizer(Quantizer): + def transform_for_annotation( + self, m: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + # Make a copy of the graph to ensure that we are using the + # return value of this function. 
+ graph = torch.fx.Graph() + graph.graph_copy(m.graph, {}) + model = torch.fx.GraphModule(m, graph) + model.meta["custom"] = {"_test_data": True} + return model + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def forward(self, x): + return x + 3 + + m = M().eval() + quantizer = TestQuantizer() + example_inputs = (torch.randn(1, 2, 3, 3),) + m = torch.export.export(m, example_inputs, strict=True).module() + prepared = prepare_pt2e(m, quantizer) + + custom_meta = prepared.meta.get("custom", {}).get("_test_data", False) + self.assertTrue(custom_meta) + def test_composable_quantizer_transform_for_annotation(self): class TestQuantizer1(Quantizer): def transform_for_annotation( diff --git a/torchao/quantization/pt2e/quantize_pt2e.py b/torchao/quantization/pt2e/quantize_pt2e.py index 5eb385b7de..7b5e0d97a1 100644 --- a/torchao/quantization/pt2e/quantize_pt2e.py +++ b/torchao/quantization/pt2e/quantize_pt2e.py @@ -116,6 +116,8 @@ def calibrate(model, data_loader): model = quantizer.transform_for_annotation(model) quantizer.annotate(model) quantizer.validate(model) + # Store the 'custom' meta data if any added by quantizer + annotated_meta = model.meta.get("custom", {}) model = prepare( model, node_name_to_scope, @@ -123,6 +125,8 @@ def calibrate(model, data_loader): obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback, ) model.meta.update(original_graph_meta) + # Update the 'custom' meta data from quantizer + model.meta.update({"custom": annotated_meta}) model = _disallow_eval_train(model) return model