Skip to content

Commit 569d040

Browse files
ysiraichi authored and pobin6 committed
[inductor] Don't specialize split on sizes parameter. (pytorch#141077)
Fix: pytorch#139936

This PR modifies the lowering of the `split` operation so that it won't generate guards specializing on the sizes parameter. Instead, it specializes on the number of output tensors being generated (i.e. a function of the size of the base tensor and the sizes parameter). As a result, operations such as `chunk` (whose number of output tensors usually is constant given a static chunk number) won't trigger recompiles when varying the size of the base tensor.

Pull Request resolved: pytorch#141077
Approved by: https://github.com/ezyang
1 parent 1b6d610 commit 569d040

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

test/inductor/test_torchinductor.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11798,6 +11798,37 @@ def f(m, x):
1179811798
f"Ref:\n{ref_grad_list}\nAct:\n{act_grad_list}",
1179911799
)
1180011800

11801+
def test_chunk_recompiles(self):
    """Chunking with a fixed chunk count should not recompile across input
    sizes that yield the same number of output tensors."""

    def fn(x):
        return x.chunk(4)

    counter = CompileCounterWithBackend("inductor")
    compiled = torch.compile(fn, backend=counter, fullgraph=True)

    def check(size):
        arg = torch.randn(size)
        # Eager result is the reference; compiled must match exactly.
        self.assertEqual(fn(arg), compiled(arg))

    # The first run should compile once with static shapes.
    check(4)
    self.assertEqual(counter.frame_count, 1)

    # Varying the input size should trigger a recompilation. Since every
    # size below is a multiple of 4 (i.e. all runs generate 4 output
    # tensors), there should be no further recompilation after that.
    for multiple in range(2, 12):
        check(4 * multiple)
    self.assertEqual(counter.frame_count, 2)

    # Input size 9 yields one less output tensor, which should trigger
    # a recompilation.
    check(9)
    self.assertEqual(counter.frame_count, 3)
11831+
1180111832

1180211833
@dataclasses.dataclass
1180311834
class TestFailure:

torch/_inductor/lowering.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,13 +1670,11 @@ def select(x, dim, idx):
16701670
@register_lowering(aten.split, type_promotion_kind=None)
16711671
def split(x, sizes, dim=0, clamp=True):
16721672
dim = _validate_dim(x, dim, 0)
1673-
if isinstance(sizes, sympy.Expr):
1674-
# TODO: We don't have to guard on sizes per se, but the number
1675-
# of splits must stay constant
1676-
sizes = V.graph.sizevars.evaluate_static_shape(sizes)
1677-
if isinstance(sizes, (int, sympy.Integer)):
1678-
x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
1679-
sizes = [sizes] * ((x_size + sizes - 1) // sizes)
1673+
if not isinstance(sizes, (list, tuple)):
1674+
chunks = V.graph.sizevars.evaluate_static_shape(
1675+
FloorDiv(x.get_size()[dim] + sizes - 1, sizes)
1676+
)
1677+
sizes = [sizes] * chunks
16801678
result = []
16811679
start = 0
16821680
for size in sizes:

0 commit comments

Comments (0)