Skip to content

Commit 86196de

Browse files
committed
Add support for ellipsis (...) indexing in Helion
stack-info: PR: #437, branch: yf225/stack/53
1 parent 4718678 commit 86196de

File tree

3 files changed

+79
-7
lines changed

3 files changed

+79
-7
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ def valid_block_size(
227227
for i, k in enumerate(subscript):
228228
if k is None:
229229
continue
230+
if k is Ellipsis:
231+
# Ellipsis is not supported in tensor descriptor mode
232+
return False
230233
size, stride = size_stride.popleft()
231234
if isinstance(k, slice):
232235
# Slices with steps are not supported in tensor descriptor mode
@@ -447,6 +450,14 @@ def codegen_store(
447450
)
448451

449452

453+
def _calculate_ellipsis_dims(
454+
index: list[object], current_index: int, total_dims: int
455+
) -> int:
456+
"""Calculate how many dimensions an ellipsis should expand to."""
457+
remaining_indices = len(index) - current_index - 1
458+
return total_dims - current_index - remaining_indices
459+
460+
450461
class SubscriptIndexing(NamedTuple):
451462
index_expr: ast.AST
452463
mask_expr: ast.AST
@@ -465,9 +476,18 @@ def compute_shape(
465476
input_size = collections.deque(tensor.size())
466477
output_size = []
467478
env = CompileEnvironment.current()
468-
for k in index:
479+
for i, k in enumerate(index):
469480
if k is None:
470481
output_size.append(1)
482+
elif k is Ellipsis:
483+
ellipsis_dims = _calculate_ellipsis_dims(index, i, len(tensor.size()))
484+
for _ in range(ellipsis_dims):
485+
size = input_size.popleft()
486+
if size != 1:
487+
rdim = env.allocate_reduction_dimension(size)
488+
output_size.append(rdim.var)
489+
else:
490+
output_size.append(1)
471491
elif isinstance(k, int):
472492
input_size.popleft()
473493
elif isinstance(k, torch.SymInt):
@@ -517,6 +537,21 @@ def create(
517537
for n, k in enumerate(index):
518538
if k is None:
519539
output_idx += 1
540+
elif k is Ellipsis:
541+
ellipsis_dims = _calculate_ellipsis_dims(index, n, fake_value.ndim)
542+
for _ in range(ellipsis_dims):
543+
expand = tile_strategy.expand_str(output_size, output_idx)
544+
size = fake_value.size(len(index_values))
545+
if size != 1:
546+
rdim = env.allocate_reduction_dimension(size)
547+
block_idx = rdim.block_id
548+
index_var = state.codegen.index_var(block_idx)
549+
index_values.append(f"({index_var}){expand}")
550+
if mask := state.codegen.mask_var(block_idx):
551+
mask_values.setdefault(f"({mask}){expand}")
552+
else:
553+
index_values.append(f"tl.zeros([1], {dtype}){expand}")
554+
output_idx += 1
520555
elif isinstance(k, int):
521556
index_values.append(repr(k))
522557
elif isinstance(k, torch.SymInt):
@@ -729,8 +764,16 @@ def is_supported(
729764
# TODO(jansel): support block_ptr with extra_mask
730765
return False
731766
input_sizes = collections.deque(fake_tensor.size())
732-
for k in index:
733-
input_size = 1 if k is None else input_sizes.popleft()
767+
for n, k in enumerate(index):
768+
if k is None:
769+
input_size = 1
770+
elif k is Ellipsis:
771+
ellipsis_dims = _calculate_ellipsis_dims(index, n, fake_tensor.ndim)
772+
for _ in range(ellipsis_dims):
773+
input_sizes.popleft()
774+
continue
775+
else:
776+
input_size = input_sizes.popleft()
734777
if isinstance(k, torch.SymInt):
735778
symbol = k._sympy_()
736779
origin = None
@@ -780,9 +823,21 @@ def create(
780823
fake_value,
781824
reshaped_size=SubscriptIndexing.compute_shape(fake_value, index),
782825
)
783-
for k in index:
826+
for n, k in enumerate(index):
784827
if k is None:
785828
pass # handled by reshaped_size
829+
elif k is Ellipsis:
830+
ellipsis_dims = _calculate_ellipsis_dims(index, n, fake_value.ndim)
831+
env = CompileEnvironment.current()
832+
for _ in range(ellipsis_dims):
833+
size = fake_value.size(len(res.offsets))
834+
if size != 1:
835+
rdim = env.allocate_reduction_dimension(size)
836+
res.offsets.append(state.codegen.offset_var(rdim.block_id))
837+
res.block_shape.append(rdim.var)
838+
else:
839+
res.offsets.append("0")
840+
res.block_shape.append(1)
786841
elif isinstance(k, int):
787842
res.offsets.append(repr(k))
788843
res.block_shape.append(1)

helion/_compiler/type_propagation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,26 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
433433
inputs_consumed += 1
434434
elif k.value is None:
435435
output_sizes.append(1)
436+
elif k.value is Ellipsis:
437+
# Count indices after ellipsis (excluding None)
438+
remaining_keys = sum(
439+
1
440+
for key in keys[keys.index(k) + 1 :]
441+
if not (isinstance(key, LiteralType) and key.value is None)
442+
)
443+
ellipsis_dims = (
444+
self.fake_value.ndim - inputs_consumed - remaining_keys
445+
)
446+
for _ in range(ellipsis_dims):
447+
size = self.fake_value.size(inputs_consumed)
448+
inputs_consumed += 1
449+
if self.origin.is_device():
450+
output_sizes.append(size)
451+
elif size != 1:
452+
rdim = env.allocate_reduction_dimension(size)
453+
output_sizes.append(rdim.var)
454+
else:
455+
output_sizes.append(1)
436456
else:
437457
raise exc.InvalidIndexingType(k)
438458
elif isinstance(k, SymIntType):

test/test_indexing.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -759,9 +759,6 @@ def kernel(
759759
torch.testing.assert_close(src_result, expected_src)
760760
torch.testing.assert_close(dst_result, expected_dst)
761761

762-
@skipIfNormalMode(
763-
"RankMismatch: Cannot assign a tensor of rank 2 to a buffer of rank 3"
764-
)
765762
def test_ellipsis_indexing(self):
766763
"""Test both setter from scalar and getter for [..., i]"""
767764

0 commit comments

Comments (0)