
Commit 0744d5f

jansel authored and pobin6 committed

[inductor] Refactor MutableBox to make IRNode typing easier (pytorch#140895)

Pull Request resolved: pytorch#140895
Approved by: https://github.com/ezyang, https://github.com/Skylion007

1 parent a59baaa commit 0744d5f

File tree

12 files changed: +303 -142 lines changed

torch/_inductor/codegen/common.py

Lines changed: 0 additions & 2 deletions

@@ -1795,8 +1795,6 @@ def construct_input(inp):
         if isinstance(inp, torch._prims_common.Number):
             return inp
         else:
-            assert hasattr(inp, "dtype")
-
             # construct a tmp tensor to use dtype promotion util function
             return torch.empty([1], dtype=inp.dtype)
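A note on the surviving comment: constructing a one-element tensor is a common way to feed non-tensor values into PyTorch's dtype-promotion helpers. A minimal standalone sketch, not the PR's code (promoted_dtype is a hypothetical helper):

import torch

def promoted_dtype(a, b) -> torch.dtype:
    # Wrap anything carrying a .dtype in a tiny tensor so that
    # torch.result_type can apply the standard promotion rules;
    # plain Python numbers participate in promotion directly.
    def as_promotable(v):
        if isinstance(v, (bool, int, float, complex)):
            return v
        return torch.empty([1], dtype=v.dtype)
    return torch.result_type(as_promotable(a), as_promotable(b))

print(promoted_dtype(torch.ones(2, dtype=torch.int32), 2.0))  # torch.float32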

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 4 additions & 4 deletions

@@ -4,7 +4,7 @@
 import os
 import sys
 from itertools import count
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple

 import sympy
 from sympy import Expr

@@ -1106,8 +1106,8 @@ def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
         # in the abi_compatible mode, outputs are returned via arguments
         return name

-    def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
-        parts = list(map(self.codegen_sizevar, shape))
+    def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str:
+        parts = [*map(self.codegen_sizevar, shape)]
         if len(parts) == 0:
             return "{}"
         if len(parts) == 1:

@@ -1904,7 +1904,7 @@ def generate_fallback_kernel_with_runtime_lookup_jit(
         py_args_var = f"py_args_{next(self.arg_var_id)}"
         # First arg is always the python op name
         lines = f"""
-RAIIPyObject {py_args_var}(PyTuple_New({num_args+1}));
+RAIIPyObject {py_args_var}(PyTuple_New({num_args + 1}));
 if ({py_args_var}.get() == NULL) {{
     throw std::runtime_error("PyTuple_New {py_args_var} failed");
 }}
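The signature change from Tuple[Expr, ...] to Sequence[Expr] is about flexibility at call sites: a List no longer needs converting to a tuple to satisfy the type checker. A hedged illustration with toy functions, not the wrapper's API:

from typing import List, Sequence, Tuple

def takes_tuple(shape: Tuple[int, ...]) -> str:
    return str(shape)

def takes_sequence(shape: Sequence[int]) -> str:
    return str(tuple(shape))

sizes: List[int] = [2, 3]
takes_sequence(sizes)   # OK: List[int] satisfies Sequence[int]
takes_sequence((2, 3))  # OK: tuples are Sequences too
# takes_tuple(sizes)    # mypy error: List[int] is not Tuple[int, ...]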

torch/_inductor/codegen/cpp_wrapper_gpu.py

Lines changed: 2 additions & 1 deletion

@@ -11,6 +11,7 @@
 from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn

 from ..codecache import CudaKernelParamCache
+from ..ir import IRNode
 from ..utils import DeferredLineBase, get_gpu_type
 from ..virtualized import V
 from .aoti_hipify_utils import maybe_hipify_code_wrapper

@@ -261,7 +262,7 @@ def generate_user_defined_triton_kernel(
         ]
         args = [self.val_to_arg_str(v) for v in raw_args]
         arg_types = [
-            arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg)
+            arg.get_dtype() if isinstance(arg, IRNode) else type(arg)
             for arg in raw_args
         ]
         self.generate_kernel_call(
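Replacing hasattr(arg, "get_dtype") with isinstance(arg, IRNode) matters for typing: hasattr tells mypy nothing, while isinstance narrows the type so the get_dtype() call checks. A self-contained sketch with a stand-in class (IRNodeLike is hypothetical, standing in for torch._inductor.ir.IRNode):

from typing import Union

class IRNodeLike:
    def get_dtype(self) -> str:
        return "float32"

def arg_type(arg: Union[IRNodeLike, int, float]) -> object:
    # isinstance narrows the Union, so mypy accepts .get_dtype();
    # a hasattr() check here would leave arg un-narrowed.
    return arg.get_dtype() if isinstance(arg, IRNodeLike) else type(arg)

assert arg_type(IRNodeLike()) == "float32"
assert arg_type(3) is int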

torch/_inductor/codegen/wrapper.py

Lines changed: 6 additions & 5 deletions

@@ -20,6 +20,7 @@
     Iterator,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     TYPE_CHECKING,

@@ -43,7 +44,7 @@

 from .. import async_compile, config, ir
 from ..codecache import output_code_log
-from ..ir import ReinterpretView
+from ..ir import IRNode, ReinterpretView
 from ..runtime import triton_heuristics
 from ..runtime.hints import DeviceProperties
 from ..utils import (

@@ -1016,7 +1017,7 @@ def generate_user_defined_triton_kernel(

         args = [self.val_to_arg_str(v) for v in raw_args]
         arg_types = [
-            arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg)
+            arg.get_dtype() if isinstance(arg, IRNode) else type(arg)
             for arg in raw_args
         ]
         self.generate_kernel_call(

@@ -1306,15 +1307,15 @@ def codegen_sizevar(self, x: Expr) -> str:
     def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
         return f"{basename}[{index}]"

-    def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
-        parts = list(map(self.codegen_python_sizevar, shape))
+    def codegen_python_shape_tuple(self, shape: Sequence[Expr]) -> str:
+        parts = [*map(self.codegen_python_sizevar, shape)]
         if len(parts) == 0:
             return "()"
         if len(parts) == 1:
             return f"({parts[0]}, )"
         return f"({', '.join(parts)})"

-    def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
+    def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str:
         return self.codegen_python_shape_tuple(shape)

     def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
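The codegen_python_shape_tuple logic itself is unchanged; the one-element special case keeps the trailing comma Python needs to parse a 1-tuple. A standalone sketch of the same formatting rule, assuming sizevars already rendered to strings:

def shape_tuple(parts: list) -> str:
    if len(parts) == 0:
        return "()"
    if len(parts) == 1:
        return f"({parts[0]}, )"  # trailing comma: (s0,) is a tuple, (s0) is not
    return f"({', '.join(parts)})"

assert shape_tuple([]) == "()"
assert shape_tuple(["s0"]) == "(s0, )"
assert shape_tuple(["s0", "4"]) == "(s0, 4)"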

torch/_inductor/comms.py

Lines changed: 4 additions & 11 deletions

@@ -272,17 +272,10 @@ def node_summary(snode):
     if isinstance(snode.node, ir.ExternKernelOut):
         detail = f" ({snode.node.python_kernel_name})"
     out_tensor_info = ""
-    if (
-        hasattr(snode.node, "layout")
-        and hasattr(snode.node.layout, "size")
-        and hasattr(snode.node.layout, "stride")
-    ):
-        out_tensor_info = (
-            f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})"
-        )
-    node_name = ""
-    if hasattr(snode.node, "name"):
-        node_name = snode.node.name
+    layout = snode.node.maybe_get_layout()
+    if isinstance(layout, ir.Layout):
+        out_tensor_info = f" (size={layout.size}, stride={layout.stride})"
+    node_name = snode.node.maybe_get_name() or ""
     return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})"
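The hasattr chains give way to typed optional accessors: maybe_get_layout() and maybe_get_name() return None rather than raising, so one isinstance check replaces three hasattr probes. A simplified stand-in (Layout and Node here are toy classes, not the inductor IR):

from typing import Optional

class Layout:
    def __init__(self, size, stride):
        self.size, self.stride = size, stride

class Node:
    layout: Optional[Layout] = None
    name: Optional[str] = None

    def maybe_get_layout(self) -> Optional[Layout]:
        return self.layout

    def maybe_get_name(self) -> Optional[str]:
        return self.name

def node_summary(node: Node) -> str:
    out_tensor_info = ""
    layout = node.maybe_get_layout()
    if isinstance(layout, Layout):
        out_tensor_info = f" (size={layout.size}, stride={layout.stride})"
    return f"{type(node).__name__}{out_tensor_info} ({node.maybe_get_name() or ''})"

print(node_summary(Node()))  # "Node ()"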

torch/_inductor/compile_fx.py

Lines changed: 11 additions & 15 deletions

@@ -93,6 +93,7 @@
 from .fx_passes.post_grad import post_grad_passes, view_to_reshape
 from .fx_passes.pre_grad import pre_grad_passes
 from .graph import GraphLowering
+from .ir import get_device_type, IRNode
 from .utils import (
     align_inputs_from_check_idxs,
     clone_preserve_strides,

@@ -1818,24 +1819,19 @@ def warn_and_skip(device: torch.device) -> Never:
         )
         raise SkipFrame("BF16 is not supported")

-    for inp in graph.graph_inputs.values():
-        device = getattr(inp, "get_device", lambda: torch.device("meta"))()
-        if (not is_gpu(device.type)) or inp.get_dtype() != torch.bfloat16:
+    for node in itertools.chain(graph.graph_inputs.values(), graph.graph_outputs):
+        if not isinstance(node, IRNode):
             continue
-        # Print warning and skip frame if attempting to compile for bfloat16
-        # on device without hardware support for dtype
-        device_interface = get_interface_for_device(device.type)
-        if device_interface.is_bf16_supported(including_emulation=False):
-            return
-        warn_and_skip(device)
-
-    for out in graph.graph_outputs:
-        device = getattr(out, "get_device", lambda: torch.device("meta"))()
-        if (not is_gpu(device.type)) or out.get_dtype() != torch.bfloat16:
+        device_type = get_device_type(node)
+        if (
+            not device_type
+            or not is_gpu(device_type)
+            or node.get_dtype() != torch.bfloat16
+        ):
             continue
         # Print warning and skip frame if attempting to compile for bfloat16
         # on device without hardware support for dtype
-        device_interface = get_interface_for_device(device.type)
+        device_interface = get_interface_for_device(device_type)
         if device_interface.is_bf16_supported(including_emulation=False):
             return
-        warn_and_skip(device)
+        warn_and_skip(node.get_device())
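The two near-identical loops over inputs and outputs collapse into one via itertools.chain, and the getattr-with-lambda-default trick gives way to an isinstance gate. A reduced sketch of just that control-flow pattern (iter_typed_nodes is a hypothetical helper, not the PR's code):

import itertools

def iter_typed_nodes(graph_inputs: dict, graph_outputs: list, node_type: type):
    # One pass over inputs and outputs instead of two copied loops;
    # the isinstance gate skips non-IR entries (e.g. symbolic ints).
    for node in itertools.chain(graph_inputs.values(), graph_outputs):
        if isinstance(node, node_type):
            yield node

class Fake:
    pass

assert len(list(iter_typed_nodes({"a": Fake(), "b": 1}, [Fake()], Fake))) == 2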

torch/_inductor/dependencies.py

Lines changed: 9 additions & 6 deletions

@@ -637,12 +637,15 @@ def extract_input_node_reduction_ranges(
     Otherwise returns (None, None).
     """

-    from .ir import ComputedBuffer, Loops
+    from .ir import ComputedBuffer, ExternKernel, Loops
+
+    size: Optional[List[sympy.Expr]]
+    reduction_size: Optional[List[sympy.Expr]]

     if isinstance(input_node.data, ComputedBuffer):
         # Input node has already been realized. Return its size and reduction_size.
-        size = input_node.get_size()
-        reduction_size = input_node.get_reduction_size()
+        size = [*input_node.get_size()]
+        reduction_size = [*input_node.get_reduction_size()]
         if len(reduction_size) > 0:
             return (size, reduction_size)
         else:

@@ -660,7 +663,7 @@ def extract_input_node_reduction_ranges(
     size = None
     while reduction_size is None and len(reads) > 0:
         seen: OrderedSet[str] = OrderedSet()
-        new_reads = []
+        new_reads: List[Dep] = []
        for read in reads:
            if not isinstance(read, MemoryDep):
                continue

@@ -671,7 +674,7 @@ def extract_input_node_reduction_ranges(
            if buffer is None:
                continue
            op = buffer.get_defining_op()
-           if op is None:
+           if op is None or isinstance(op, ExternKernel):
                continue

            if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0:

@@ -685,7 +688,7 @@ def extract_input_node_reduction_ranges(
        if reads == new_reads:
            return (size, reduction_size)
        else:
-           reads = new_reads
+           reads = OrderedSet(new_reads)
    return (size, reduction_size)
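The two bare annotations (size: Optional[List[sympy.Expr]] with no value) declare a type once so mypy accepts both the list assignment and the later size = None branch; they generate no runtime code. A minimal illustration of the same pattern:

from typing import List, Optional

def pick(values: List[int], want: bool) -> Optional[List[int]]:
    result: Optional[List[int]]  # annotation only, no runtime effect
    if want:
        result = [*values]       # copies, like [*input_node.get_size()]
    else:
        result = None
    return result

assert pick([1, 2], True) == [1, 2]
assert pick([1, 2], False) is None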

torch/_inductor/graph.py

Lines changed: 1 addition & 5 deletions

@@ -1446,11 +1446,7 @@ def debug(msg: str) -> None:
             result.realize()
             strides = n.meta["val"].stride()
             sym_strides = torch._inductor.utils.any_is_symbolic(*strides)
-            if (
-                not hasattr(result, "get_stride")
-                or result.get_stride() != strides
-                and not sym_strides
-            ):
+            if result.maybe_get_stride() != strides and not sym_strides:
                 stride_order = ir.get_stride_order(strides)
                 result = ir.ExternKernel.require_stride_order(result, stride_order)
             if (
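maybe_get_stride() returning None when a node has no stride info lets one comparison stand in for the old hasattr branch: None != strides is truthy, so such nodes still take the reorder path (modulo the sym_strides guard, whose precedence the old `or ... and` expression handled differently). A toy sketch of that None-comparison idea (maybe_get_stride here is a stand-in, not the IR method):

from typing import Optional, Tuple

class Result:
    def __init__(self, stride: Optional[Tuple[int, ...]]):
        self._stride = stride

    def maybe_get_stride(self) -> Optional[Tuple[int, ...]]:
        return self._stride

def needs_reorder(result: Result, strides: Tuple[int, ...]) -> bool:
    # None (no stride info) compares unequal to any concrete strides,
    # folding the old hasattr() branch into one expression.
    return result.maybe_get_stride() != strides

assert needs_reorder(Result(None), (4, 1))       # no layout -> reorder
assert not needs_reorder(Result((4, 1)), (4, 1))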
