Skip to content

Let quantizers add custom meta data to GraphModule #2711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions test/quantization/pt2e/test_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,38 @@ def forward(self, x):
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

def test_custom_meta_transform_for_annotation(self):
class TestQuantizer(Quantizer):
def transform_for_annotation(
self, m: torch.fx.GraphModule
) -> torch.fx.GraphModule:
# Make a copy of the graph to ensure that we are using the
# return value of this function.
graph = torch.fx.Graph()
graph.graph_copy(m.graph, {})
model = torch.fx.GraphModule(m, graph)
model.meta["custom"] = {"_test_data": True}
return model

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
return model

def validate(self, model: torch.fx.GraphModule) -> None:
pass

class M(torch.nn.Module):
def forward(self, x):
return x + 3

m = M().eval()
quantizer = TestQuantizer()
example_inputs = (torch.randn(1, 2, 3, 3),)
m = torch.export.export(m, example_inputs, strict=True).module()
prepared = prepare_pt2e(m, quantizer)

custom_meta = prepared.meta.get("custom", {}).get("_test_data", False)
self.assertTrue(custom_meta)

def test_composable_quantizer_transform_for_annotation(self):
class TestQuantizer1(Quantizer):
def transform_for_annotation(
Expand Down
4 changes: 4 additions & 0 deletions torchao/quantization/pt2e/quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,17 @@ def calibrate(model, data_loader):
model = quantizer.transform_for_annotation(model)
quantizer.annotate(model)
quantizer.validate(model)
# Store the 'custom' meta data if any added by quantizer
annotated_meta = model.meta.get("custom", {})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the intended use for this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to pass on a list of the annotated nodes to the partitioner and then on to the exported programs graph_module in order to validate the quantization folding decisions in the backend when we start to mix INT and FP. The ET part of it (passing on the 'custom' meta to the edge program) is also needed, but otherwise the 'custom' field is handled in the re-exports to survive all the way to that point.

model = prepare(
model,
node_name_to_scope,
is_qat=False,
obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback,
)
model.meta.update(original_graph_meta)
# Update the 'custom' meta data from quantizer
model.meta.update({"custom": annotated_meta})
model = _disallow_eval_train(model)
return model

Expand Down
Loading