[pt2e] Avoid getting model device once per node #2695

Open · wants to merge 1 commit into base: main
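
In short, the diff below hoists the model-device lookup out of the per-node helpers: convert() and prepare() now derive the device once and thread it through as an optional model_device argument instead of re-deriving it for every observer node. A minimal sketch of the lookup being hoisted, assuming the usual parameter/buffer-scan semantics (a simplified stand-in, not the torchao implementation):

    from typing import Optional

    import torch


    def get_unique_device(module: torch.nn.Module) -> Optional[torch.device]:
        # Simplified stand-in for torchao's _assert_and_get_unique_device (assumed
        # semantics): scan every parameter and buffer, require that they share a
        # single device, and return it (or None for a module with no tensors).
        devices = {p.device for p in module.parameters()} | {
            b.device for b in module.buffers()
        }
        assert len(devices) <= 1, f"expected at most one device, got {devices}"
        return next(iter(devices)) if devices else None


    # Before this change, per-node helpers such as create_getattr_from_value ran a
    # scan like this for every observer node they rewrote; after it, convert() and
    # prepare() run the scan once, cache the result as model_device, and pass it
    # down through the per-node helpers.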
32 changes: 26 additions & 6 deletions torchao/quantization/pt2e/convert.py
@@ -49,9 +49,7 @@
)
from torch.ao.quantization.fx.utils import (
_get_module,
assert_and_get_unique_device,
collect_producer_nodes,
create_getattr_from_value,
graph_module_from_producer_nodes,
node_arg_is_weight,
)
@@ -73,7 +71,11 @@

from torchao.quantization.pt2e import FROM_NODE_KEY
from torchao.quantization.pt2e.observer import _is_activation_post_process
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
from torchao.quantization.pt2e.utils import create_getattr_from_value
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
_assert_and_get_unique_device,
)

if TORCH_VERSION_AT_LEAST_2_6:
from torch.fx.traceback import NodeSource, NodeSourceAction
@@ -132,6 +134,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
modules: dict[str, torch.nn.Module],
node_name_to_scope: dict[str, tuple[str, type]],
node_name_to_qconfig: dict[str, QConfigAny],
model_device: Optional[torch.device] = None,
) -> None:
"""Replace activation_post_process module call node with quantize and
dequantize node working with decomposed Tensor
@@ -260,7 +263,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node):
# sure that the default overload can be used.
# TODO: maybe need more complex attr name here
qparam_node = create_getattr_from_value(
model, graph, module_path + prefix + key, value_or_node
model,
graph,
module_path + prefix + key,
value_or_node,
model_device,
)
quantize_op_inputs.append(qparam_node)
else:
@@ -407,6 +414,7 @@ def _replace_observer_with_quantize_dequantize_node(
modules: dict[str, torch.nn.Module],
node_name_to_scope: dict[str, tuple[str, type]],
node_name_to_qconfig: dict[str, QConfigAny],
model_device: Optional[torch.device] = None,
) -> None:
"""Replace activation_post_process module call node with quantize and
dequantize node
@@ -487,7 +495,11 @@ def _replace_observer_with_quantize_dequantize_node(
# For scale and zero_point values we register them as buffers in the root module.
# TODO: maybe need more complex attr name here
qparam_node = create_getattr_from_value(
model, graph, module_path + prefix + key, value_or_node
model,
graph,
module_path + prefix + key,
value_or_node,
model_device,
)
quantize_op_inputs.append(qparam_node)
else:
@@ -785,6 +797,7 @@ def convert_weighted_module(
backend_config: BackendConfig,
is_decomposed: bool = False,
is_reference: bool = False,
model_device: Optional[torch.device] = None,
) -> None:
"""Convert a weighted module to reference quantized module in the model
If the QConfig of a QAT module is not set, the module will still be converted to
@@ -873,7 +886,10 @@ def convert_weighted_module(
is_ptq = weight_post_process is None
if is_ptq:
weight_post_process = qconfig.weight() # type: ignore[union-attr, operator]
device = assert_and_get_unique_device(float_module)
if model_device is not None:
device = model_device
else:
device = _assert_and_get_unique_device(float_module)
if device:
weight_post_process.to(device)

@@ -1076,6 +1092,7 @@ def convert(
root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
qat_module_classes = get_qat_module_classes(backend_config)
fused_module_classes = get_fused_module_classes(backend_config)
model_device = _assert_and_get_unique_device(model)

for node in list(model.graph.nodes):
if node.op == "placeholder":
@@ -1123,6 +1140,7 @@
modules,
node_name_to_scope,
node_name_to_qconfig,
model_device,
)
else:
_replace_observer_with_quantize_dequantize_node(
@@ -1131,6 +1149,7 @@
modules,
node_name_to_scope,
node_name_to_qconfig,
model_device,
)
elif isinstance(mod, DeQuantStub):
_replace_observer_or_dequant_stub_with_dequantize_node(
@@ -1160,6 +1179,7 @@
backend_config,
is_decomposed,
is_reference,
model_device,
)

# remove deadcode after converting observers to quant/dequant ops
12 changes: 10 additions & 2 deletions torchao/quantization/pt2e/observer.py
@@ -1915,10 +1915,18 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
else:
scale, zero_point = self.calculate_qparams()
scale_node = create_getattr_from_value(
model, model.graph, "_scale", scale
model,
model.graph,
"_scale",
scale,
scale.device,
)
zero_point_node = create_getattr_from_value(
model, model.graph, "_zero_point", zero_point
model,
model.graph,
"_zero_point",
zero_point,
zero_point.device,
)

q_node = model.graph.call_function(
26 changes: 23 additions & 3 deletions torchao/quantization/pt2e/prepare.py
@@ -38,7 +38,7 @@
SharedQuantizationSpec,
)
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, _assert_and_get_unique_device

# TODO: make pt2e folder private?
__all__ = [
@@ -409,6 +409,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
named_modules: dict[str, torch.nn.Module],
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
model_device: Optional[torch.device] = None,
) -> Argument:
"""
Given a `node` and an `arg`, inserts an input observer between
@@ -427,6 +428,7 @@
named_modules,
obs_or_fq_map,
is_qat,
model_device,
)
new_arg_to_return.append(new_inner_arg)
return type(arg)(new_arg_to_return)
@@ -479,6 +481,7 @@
return maybe_obs_node

assert isinstance(model.graph, Graph)
# TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
new_arg = _insert_obs_or_fq(
arg, input_edge_obs_or_fq, model, named_modules, model.graph
)
@@ -492,6 +495,7 @@ def _maybe_insert_input_observers_for_node(
named_modules: dict[str, torch.nn.Module],
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
model_device: Optional[torch.device] = None,
) -> None:
"""
If needed, inserts observers to the input args and kwargs of `node`.
@@ -518,6 +522,7 @@
named_modules,
obs_or_fq_map,
is_qat,
model_device,
)
new_args.append(new_arg)

@@ -542,9 +547,11 @@ def _maybe_insert_output_observer_for_node(
graph: Graph,
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
model_device: Optional[torch.device] = None,
) -> Optional[Node]:
if node in obs_or_fq_map:
output_act_obs_or_fq = obs_or_fq_map[node]
# TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
new_output = _insert_obs_or_fq(
node, output_act_obs_or_fq, model, named_modules, graph
)
@@ -565,6 +572,7 @@ def _maybe_insert_input_and_output_observers_for_node(
model: torch.fx.GraphModule,
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
model_device: Optional[torch.device] = None,
):
this_node_quantization_annotation = (
node.meta[Q_ANNOTATION_KEY] if Q_ANNOTATION_KEY in node.meta else None
@@ -580,6 +588,7 @@
named_modules,
obs_or_fq_map,
is_qat,
model_device,
)

output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
@@ -588,7 +597,13 @@

# this returns the new observer node if it was needed
maybe_output_obs_node = _maybe_insert_output_observer_for_node(
node, model, named_modules, model.graph, obs_or_fq_map, is_qat
node,
model,
named_modules,
model.graph,
obs_or_fq_map,
is_qat,
model_device,
)

if maybe_output_obs_node is None:
@@ -636,11 +651,16 @@ def prepare(
)
if obs_or_fq_callback:
obs_or_fq_callback(model, obs_or_fq_map)
model_device = _assert_and_get_unique_device(model)

for node in nodes_before_observation:
# TODO: simplify logic for inserting observers
_maybe_insert_input_and_output_observers_for_node(
node, model, obs_or_fq_map, is_qat
node,
model,
obs_or_fq_map,
is_qat,
model_device,
)

model = GraphModule(model, model.graph)
9 changes: 7 additions & 2 deletions torchao/quantization/pt2e/utils.py
@@ -525,15 +525,20 @@ def get_attr_name(i: int):


def create_getattr_from_value(
module: torch.nn.Module, graph: Graph, prefix: str, value: Any
module: torch.nn.Module,
graph: Graph,
prefix: str,
value: Any,
device: Optional[torch.device] = None,
) -> Node:
"""
Given a value of any type, creates a getattr node corresponding to the value and
registers the value as a buffer to the module.
"""
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
attr_name = get_new_attr_name(module)
device = _assert_and_get_unique_device(module)
if device is None:
device = _assert_and_get_unique_device(module)
new_value = (
value.detach().clone()
if isinstance(value, torch.Tensor)
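
A hypothetical usage sketch of the updated helper: only the signature and buffer-registration behavior above come from this diff; the toy module, tensor value, and insertion point are made up for illustration. Passing the cached device lets the helper skip the per-call module scan, while passing None (or omitting the argument) keeps the previous behavior:

    import torch
    import torch.fx as fx

    from torchao.quantization.pt2e.utils import create_getattr_from_value


    class Toy(torch.nn.Module):
        def forward(self, x):
            return x + 1


    gm = fx.symbolic_trace(Toy())
    # In this PR the device is derived once via _assert_and_get_unique_device(model)
    # and reused; a CPU device is hard-coded here only to keep the sketch self-contained.
    model_device = torch.device("cpu")

    scale = torch.tensor(0.02)
    output_node = next(n for n in gm.graph.nodes if n.op == "output")
    with gm.graph.inserting_before(output_node):
        # Registers `scale` as a buffer on `gm` under a name prefixed with "_scale"
        # and returns the get_attr node referring to it, without re-scanning gm's
        # parameters and buffers for its device.
        scale_node = create_getattr_from_value(gm, gm.graph, "_scale", scale, model_device)
    gm.recompile()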