Commit cd2be98

ysiraichi authored and pobin6 committed
[inductor] Don't clamp on split operation. (pytorch#141078)
This PR turns clamping off for the `split` operation. By doing so, we generate fewer bound guards and reduce the number of recompilations when varying the input size.

```python
@torch.compile(dynamic=True)
def f(x):
    return x.chunk(4)

>>> f(torch.arange(12))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))

>>> f(torch.arange(11))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10]))

>>> f(torch.arange(10))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9]))
```

Pull Request resolved: pytorch#141078
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#141077
1 parent 569d040 commit cd2be98
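
As an aside, one way to observe the recompilation behavior described in the commit message is to count compiled frames, much as the test added below does. The sketch is a hedged illustration only: it assumes the `CompileCounterWithBackend` helper from `torch._dynamo.testing` (which the test suite uses) and does not assert exact frame counts, since those depend on the input sizes exercised.

```python
# Hedged sketch (not part of the commit): counting how often torch.compile
# recompiles f when the size of a chunked input varies. Assumes the
# CompileCounterWithBackend helper from torch._dynamo.testing.
import torch
from torch._dynamo.testing import CompileCounterWithBackend

cnts = CompileCounterWithBackend("inductor")

@torch.compile(backend=cnts, dynamic=True)
def f(x):
    return x.chunk(4)

f(torch.arange(12))  # first call: compiles
f(torch.arange(11))  # still 4 chunks, last chunk larger than 1
f(torch.arange(10))  # last chunk has size 1 (0/1 specialization)
print(cnts.frame_count)  # fewer frames once split stops emitting bound guards
```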

File tree

2 files changed: +34, -7 lines


test/inductor/test_torchinductor.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -11802,6 +11802,8 @@ def test_chunk_recompiles(self):
         def f(x):
             return x.chunk(4)

+        # Runs f and its torch.compile-d version with a fresh 1D tensor
+        # of a specific size, and checks that the result is correct.
         def run(size):
             input = torch.randn(size)
             expected_out = f(input)
@@ -11823,11 +11825,24 @@ def run(size):
             run(4 * i)
         self.assertEqual(cnts.frame_count, 2)

+        # Input size: 11
+        # Not a multiple of 4, but still generates 4 output tensors,
+        # where the last one has size > 1.
+        run(11)
+        self.assertEqual(cnts.frame_count, 2)
+
+        # Input size: 10
+        # Even though it still generates 4 output tensors, the last
+        # one has size 1, falling into our 0/1 specialization. Thus,
+        # this one also triggers recompilation.
+        run(10)
+        self.assertEqual(cnts.frame_count, 3)
+
         # Input size: 9
         # Yields one less output tensor, which should trigger a
         # recompilation.
         run(9)
-        self.assertEqual(cnts.frame_count, 3)
+        self.assertEqual(cnts.frame_count, 4)


 @dataclasses.dataclass
```
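
The frame counts asserted above follow directly from how `chunk(4)` sizes its outputs. Below is a standalone sketch of that arithmetic; the helper `chunk_sizes` is invented here for illustration and is not a PyTorch API. Sizes 12 and 11 both yield four chunks whose last chunk is larger than 1, size 10 yields a last chunk of exactly 1 (hitting the 0/1 specialization), and size 9 drops to three chunks.

```python
import math

# Illustration only: the chunk sizes x.chunk(4) produces for a 1D tensor of
# length n, assuming each chunk gets ceil(n / 4) elements and the last chunk
# absorbs the remainder.
def chunk_sizes(n, chunks=4):
    each = math.ceil(n / chunks)          # size of every chunk except possibly the last
    count = (n + each - 1) // each        # number of output tensors actually produced
    sizes = [each] * count
    sizes[-1] = n - (count - 1) * each    # last chunk takes whatever is left
    return sizes

print(chunk_sizes(12))  # [3, 3, 3, 3] -> 4 outputs, last chunk of size 3
print(chunk_sizes(11))  # [3, 3, 3, 2] -> 4 outputs, last chunk > 1 (no recompile)
print(chunk_sizes(10))  # [3, 3, 3, 1] -> last chunk of size 1 (0/1 specialization)
print(chunk_sizes(9))   # [3, 3, 3]    -> one fewer output tensor (recompiles)
```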

torch/_inductor/lowering.py

Lines changed: 18 additions & 6 deletions
```diff
@@ -1668,25 +1668,37 @@ def select(x, dim, idx):


 @register_lowering(aten.split, type_promotion_kind=None)
-def split(x, sizes, dim=0, clamp=True):
+def split(x, sizes, dim=0):
     dim = _validate_dim(x, dim, 0)
+    sizes_ = sizes
+
+    # If sizes is an integer (or a SymInt), we turn it into a list of sizes
+    # by computing what the actual size of each chunk should be.
     if not isinstance(sizes, (list, tuple)):
+        x_size = x.get_size()[dim]
         chunks = V.graph.sizevars.evaluate_static_shape(
-            FloorDiv(x.get_size()[dim] + sizes - 1, sizes)
+            FloorDiv(x_size + sizes - 1, sizes)
         )
-        sizes = [sizes] * chunks
+        sizes_ = [sizes] * chunks
+        # The last chunk might have a smaller size than the rest.
+        sizes_[-1] = x_size - (chunks - 1) * sizes
+
+    # From this point, we assume that the sum of the sizes of all chunks
+    # equals the size of the base tensor.
     result = []
     start = 0
-    for size in sizes:
+    for size in sizes_:
         end = start + size
-        result.append(slice_(x, dim, start, end, clamp=clamp))
+        # No need for clamping here, since we compute the exact
+        # start and end values.
+        result.append(slice_(x, dim, start, end, clamp=False))
         start = end
     return result


 @register_lowering(aten.split_with_sizes, type_promotion_kind=None)
 def split_with_sizes(x, sizes, dim=0):
-    return split(x, sizes, dim, clamp=False)
+    return split(x, sizes, dim)


 @register_lowering(aten.unbind, type_promotion_kind=None)
```
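
To see why clamping becomes unnecessary, compare the slice bounds the old and new size computations produce. The snippet below is a simplified, standalone model of the lowering above in plain Python (not inductor code): with the old `[sizes] * chunks` scheme, the last slice's end could overshoot the tensor, which is what `clamp=True` papered over (at the cost of extra bound guards under dynamic shapes); the new scheme gives the last chunk the exact remainder, so every slice lands inside the tensor and `clamp=False` is safe.

```python
# Simplified model (illustration only) of the old vs. new size computation
# in the split lowering, for a tensor of length `total` split into chunks
# of size `chunk`.
def slice_bounds(total, chunk):
    chunks = (total + chunk - 1) // chunk         # ceil division, as FloorDiv(x + s - 1, s)

    old_sizes = [chunk] * chunks                  # old: every chunk assumed full-sized
    new_sizes = [chunk] * chunks
    new_sizes[-1] = total - (chunks - 1) * chunk  # new: last chunk gets the exact remainder

    def bounds(sizes):
        out, start = [], 0
        for size in sizes:
            out.append((start, start + size))
            start += size
        return out

    return bounds(old_sizes), bounds(new_sizes)

old, new = slice_bounds(10, 3)
print(old)  # [(0, 3), (3, 6), (6, 9), (9, 12)] -- last end overshoots 10, needed clamping
print(new)  # [(0, 3), (3, 6), (6, 9), (9, 10)] -- ends exactly at 10, no clamping needed
```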
