
Commit 9e4bdc1

blaine-rister authored and Chao1Han committed
[Inductor-FX] Support Tensor.item (pytorch#165599)
# Feature

This PR supports compiling `Tensor.item` with Inductor's FX backend. This maps to a custom WrapperCodeGen method called `codegen_dynamic_scalar`.

# Implementation

The implementation is fairly mechanical, following the usual flow for these types of PRs.

1. Introduce a new Wrapper IR line for this, called `DynamicScalarLine`.
2. Split `PythonWrapperCodegen.codegen_dynamic_scalar` into 2 parts: a public method which generates the Wrapper IR line, and a private one generating Python from Wrapper IR.
3. Implement an FX codegen method for the wrapper IR line. This one calls `aten.where.Scalar` to handle code like `1 if x.item() else 0`, which is a bit tricky. It also calls `aten.item.default` to convert tensors to scalars.

# Test plan

Added CI tests mirroring the AOTI ones. They test float, int and bool types, the latter taking a distinct codegen path.

Pull Request resolved: pytorch#165599
Approved by: https://github.com/angelayi, https://github.com/jansel
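For orientation, here is a minimal sketch of the kind of user code this feature targets, mirroring the module in the new test. It is illustrative only and not part of the PR: it assumes `torch._dynamo.config.capture_scalar_outputs` so that `.item()` is captured into the graph rather than causing a graph break, and it uses the plain `torch.compile` entry point rather than the FX-backend test harness (`self.check`) used in CI.

```python
import torch

# Assumed for illustration: allow .item() to be captured into the graph
# instead of triggering a graph break.
torch._dynamo.config.capture_scalar_outputs = True


class M(torch.nn.Module):
    def forward(self, x):
        # Reads one element out as a Python scalar; with this PR the
        # Inductor FX wrapper backend can express this via aten.item.default.
        return x[1].item()


compiled = torch.compile(M())
out = compiled(torch.tensor([1.5] * 10))
print(out)  # a Python float, e.g. 1.5
```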
1 parent aef6d9e commit 9e4bdc1

File tree: 4 files changed, +68 -1 lines changed


test/inductor/test_fxir_backend.py

Lines changed: 16 additions & 0 deletions
@@ -1034,6 +1034,22 @@ def forward(self, x):
         x = torch.randn(7, device=self.device)
         self.check(M(), (x,), dynamic_shapes=({0: Dim.DYNAMIC},))

+    @parametrize("dynamic", (False, True))
+    @parametrize("input_", (1.5, 2, False))
+    def test_item(self, input_, dynamic: bool):
+        """
+        Test calling Tensor.item.
+        """
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x[1].item()
+
+        x = torch.tensor((input_,) * 10)
+        d = Dim("s0", min=1)
+        dynamic_shapes = ({0: 2 * d},) if dynamic else None
+        self.check(M(), (x,), dynamic_shapes=dynamic_shapes)
+
     @parametrize("pred", (False, True))
     def test_mismatched_branch_dynamic(self, pred: bool):
         """

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 1 addition & 1 deletion
@@ -1485,7 +1485,7 @@ def _generate_symbolic_call_arg_helper(
         else:
             self.writeline(f"{arg.inner} = {cexpr(arg.inner_expr)};")

-    def codegen_dynamic_scalar(self, node):
+    def _codegen_dynamic_scalar(self, node):
         (data,) = (t.codegen_reference() for t in node.inputs)
         self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw")

torch/_inductor/codegen/wrapper.py

Lines changed: 16 additions & 0 deletions
@@ -415,6 +415,19 @@ def codegen_fx(converter: FxConverter) -> FxConversionFunc:
         return converter._generate_comment


+@dataclasses.dataclass
+class DynamicScalarLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    node: ir.DynamicScalar
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper._codegen_dynamic_scalar(self.node)
+
+    @staticmethod
+    def codegen_fx(converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_dynamic_scalar
+
+
 @dataclasses.dataclass
 class ExitSubgraphLine(WrapperLine):
     wrapper: PythonWrapperCodegen
@@ -2060,6 +2073,9 @@ def codegen_with_step(start_var, end_var, step):
         self.unbacked_symbol_decls.add(str(node.unbacked_size_symbol))

     def codegen_dynamic_scalar(self, node):
+        self.writeline(DynamicScalarLine(self, node))
+
+    def _codegen_dynamic_scalar(self, node):
         (data,) = (t.codegen_reference() for t in node.inputs)
         if len(node.keypath) == 0:
             self.writeline(f"{node.sym} = {data}.item()")
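For the empty-keypath case above, the Python wrapper path ultimately emits a generated line of the form `{sym} = {buffer}.item()`. A tiny eager-level illustration of what that generated line does, where `buf0` and `u0` are hypothetical buffer/symbol names rather than names produced by this diff:

```python
import torch

# Hypothetical names for illustration: "buf0" stands in for an intermediate
# buffer and "u0" for the unbacked symbol bound by codegen_dynamic_scalar.
buf0 = torch.tensor([3.0])
u0 = buf0.item()  # mirrors the generated `{node.sym} = {data}.item()` line
print(u0)  # 3.0
```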

torch/_inductor/codegen/wrapper_fxir.py

Lines changed: 35 additions & 0 deletions
@@ -29,6 +29,7 @@
 from torch.fx import GraphModule
 from torch.fx.experimental.symbolic_shapes import (
     CallMethodKey,
+    ConvertIntKey,
     DivideByKey,
     free_unbacked_symbols,
 )
@@ -54,6 +55,7 @@
     CommBufferFreeLine,
     CommentLine,
     ConditionalLine,
+    DynamicScalarLine,
     EnterDeviceContextManagerLine,
     EnterSubgraphLine,
     ExitDeviceContextManagerLine,
@@ -738,6 +740,39 @@ def _generate_comment(self, line: WrapperLine) -> None:
         assert isinstance(line, CommentLine)
         # We ignore comments in FX IR.

+    def _generate_dynamic_scalar(self, line: WrapperLine) -> None:
+        assert isinstance(line, DynamicScalarLine)
+
+        ir_node = line.node
+        (input_ir_node,) = ir_node.inputs
+        assert isinstance(input_ir_node, ir.IRNode)
+        input_fx_node = self._generate_buffer(input_ir_node)
+        keypath = ir_node.keypath
+        graph = self.gm.graph
+
+        def generate_item(x: Optional[torch.fx.Node]) -> torch.fx.Node:
+            assert x is not None
+            return graph.call_function(
+                aten.item.default,
+                args=(x,),
+            )
+
+        if len(keypath) == 0:
+            result_fx_node = generate_item(input_fx_node)
+        elif len(keypath) == 1 and isinstance(keypath[0], ConvertIntKey):
+            where_fx_node = graph.call_function(
+                aten.where.Scalar,
+                args=(input_fx_node, 1, 0),
+            )
+            result_fx_node = generate_item(where_fx_node)
+        else:
+            raise NotImplementedError(f"Unsupported keypath: {keypath}")
+
+        result_symbol = ir_node.sym
+        result_buffer = SymbolBuffer(result_symbol)
+        self._record_allocation(result_buffer, result_fx_node)
+        self._generate_size_proxy(result_fx_node, result_symbol)
+
     def _generate_enter_device_context_manager(self, line: WrapperLine) -> None:
         assert isinstance(line, EnterDeviceContextManagerLine)
         # We ignore the device context in FX IR.
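The two supported keypath cases above correspond roughly to the following eager-level operations. This is a sketch for intuition only; the converter emits these as `call_function` nodes in the FX graph rather than eager calls:

```python
import torch

aten = torch.ops.aten

x_float = torch.tensor([1.5])
x_bool = torch.tensor([True])

# Empty keypath: read the scalar out directly via aten.item.
scalar = aten.item.default(x_float)  # -> 1.5

# ConvertIntKey (the bool path, e.g. `1 if x.item() else 0`): materialize the
# predicate as 1/0 with aten.where.Scalar, then read it out as a scalar.
as_int = aten.item.default(aten.where.Scalar(x_bool, 1, 0))  # -> 1
```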
