@@ -227,6 +227,9 @@ def valid_block_size(
227
227
for i , k in enumerate (subscript ):
228
228
if k is None :
229
229
continue
230
+ if k is Ellipsis :
231
+ # Ellipsis is not supported in tensor descriptor mode
232
+ return False
230
233
size , stride = size_stride .popleft ()
231
234
if isinstance (k , slice ):
232
235
# Slices with steps are not supported in tensor descriptor mode
@@ -447,6 +450,14 @@ def codegen_store(
447
450
)
448
451
449
452
453
+ def _calculate_ellipsis_dims (
454
+ index : list [object ], current_index : int , total_dims : int
455
+ ) -> int :
456
+ """Calculate how many dimensions an ellipsis should expand to."""
457
+ remaining_indices = len (index ) - current_index - 1
458
+ return total_dims - current_index - remaining_indices
459
+
460
+
450
461
class SubscriptIndexing (NamedTuple ):
451
462
index_expr : ast .AST
452
463
mask_expr : ast .AST
@@ -465,9 +476,18 @@ def compute_shape(
465
476
input_size = collections .deque (tensor .size ())
466
477
output_size = []
467
478
env = CompileEnvironment .current ()
468
- for k in index :
479
+ for i , k in enumerate ( index ) :
469
480
if k is None :
470
481
output_size .append (1 )
482
+ elif k is Ellipsis :
483
+ ellipsis_dims = _calculate_ellipsis_dims (index , i , len (tensor .size ()))
484
+ for _ in range (ellipsis_dims ):
485
+ size = input_size .popleft ()
486
+ if size != 1 :
487
+ rdim = env .allocate_reduction_dimension (size )
488
+ output_size .append (rdim .var )
489
+ else :
490
+ output_size .append (1 )
471
491
elif isinstance (k , int ):
472
492
input_size .popleft ()
473
493
elif isinstance (k , torch .SymInt ):
@@ -517,6 +537,21 @@ def create(
517
537
for n , k in enumerate (index ):
518
538
if k is None :
519
539
output_idx += 1
540
+ elif k is Ellipsis :
541
+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_value .ndim )
542
+ for _ in range (ellipsis_dims ):
543
+ expand = tile_strategy .expand_str (output_size , output_idx )
544
+ size = fake_value .size (len (index_values ))
545
+ if size != 1 :
546
+ rdim = env .allocate_reduction_dimension (size )
547
+ block_idx = rdim .block_id
548
+ index_var = state .codegen .index_var (block_idx )
549
+ index_values .append (f"({ index_var } ){ expand } " )
550
+ if mask := state .codegen .mask_var (block_idx ):
551
+ mask_values .setdefault (f"({ mask } ){ expand } " )
552
+ else :
553
+ index_values .append (f"tl.zeros([1], { dtype } ){ expand } " )
554
+ output_idx += 1
520
555
elif isinstance (k , int ):
521
556
index_values .append (repr (k ))
522
557
elif isinstance (k , torch .SymInt ):
@@ -729,8 +764,16 @@ def is_supported(
729
764
# TODO(jansel): support block_ptr with extra_mask
730
765
return False
731
766
input_sizes = collections .deque (fake_tensor .size ())
732
- for k in index :
733
- input_size = 1 if k is None else input_sizes .popleft ()
767
+ for n , k in enumerate (index ):
768
+ if k is None :
769
+ input_size = 1
770
+ elif k is Ellipsis :
771
+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_tensor .ndim )
772
+ for _ in range (ellipsis_dims ):
773
+ input_sizes .popleft ()
774
+ continue
775
+ else :
776
+ input_size = input_sizes .popleft ()
734
777
if isinstance (k , torch .SymInt ):
735
778
symbol = k ._sympy_ ()
736
779
origin = None
@@ -780,9 +823,21 @@ def create(
780
823
fake_value ,
781
824
reshaped_size = SubscriptIndexing .compute_shape (fake_value , index ),
782
825
)
783
- for k in index :
826
+ for n , k in enumerate ( index ) :
784
827
if k is None :
785
828
pass # handled by reshaped_size
829
+ elif k is Ellipsis :
830
+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_value .ndim )
831
+ env = CompileEnvironment .current ()
832
+ for _ in range (ellipsis_dims ):
833
+ size = fake_value .size (len (res .offsets ))
834
+ if size != 1 :
835
+ rdim = env .allocate_reduction_dimension (size )
836
+ res .offsets .append (state .codegen .offset_var (rdim .block_id ))
837
+ res .block_shape .append (rdim .var )
838
+ else :
839
+ res .offsets .append ("0" )
840
+ res .block_shape .append (1 )
786
841
elif isinstance (k , int ):
787
842
res .offsets .append (repr (k ))
788
843
res .block_shape .append (1 )
0 commit comments