diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index 636e76bc..06d06155 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -31,6 +31,36 @@ ShapeLike = Sequence[SymIntLike] +def _normalize_negative_index( + k: int, + dim_idx: int, + fake_value: torch.Tensor, + state: CodegenState, +) -> str: + """Normalize negative indices to positive ones. + + Args: + k: The negative index value + dim_idx: The dimension index + fake_value: The fake tensor to get dimension size from + state: The codegen state + + Returns: + String representation of the normalized index + """ + assert k < 0, "This function should only be called for negative indices" + + dim_size = fake_value.size(dim_idx) + # Handle both concrete and symbolic dimension sizes + if isinstance(dim_size, int): + normalized_k = k + dim_size + return repr(normalized_k) + # For symbolic dimensions, we need to generate the proper expression + # The state.codegen is a GenerateAST instance which has device_function + sympy_expr = dim_size._sympy_() + k + return f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})" + + class IndexingStrategy: def codegen_load( self, @@ -553,7 +583,14 @@ def create( index_values.append(f"tl.zeros([1], {dtype}){expand}") output_idx += 1 elif isinstance(k, int): - index_values.append(repr(k)) + # Normalize negative indices + if k < 0: + dim_idx = len(index_values) + index_values.append( + _normalize_negative_index(k, dim_idx, fake_value, state) + ) + else: + index_values.append(repr(k)) elif isinstance(k, torch.SymInt): symbol = k._sympy_() origin = None @@ -839,7 +876,14 @@ def create( res.offsets.append("0") res.block_shape.append(1) elif isinstance(k, int): - res.offsets.append(repr(k)) + # Normalize negative indices + if k < 0: + dim_idx = len(res.offsets) + res.offsets.append( + _normalize_negative_index(k, dim_idx, fake_value, state) + ) + else: + res.offsets.append(repr(k)) res.block_shape.append(1) elif isinstance(k, torch.SymInt): symbol = k._sympy_() diff --git a/test/test_indexing.py b/test/test_indexing.py index 5cfc8a6d..3ca95fd6 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -733,7 +733,6 @@ def kernel( torch.testing.assert_close(src2_result, expected_src2) torch.testing.assert_close(dst2_result, expected_dst2) - @skipIfNormalMode("InternalError: Negative indexes") def test_negative_indexing(self): """Test both setter from scalar and getter for [-1]""" @@ -784,9 +783,6 @@ def kernel( torch.testing.assert_close(src_result, expected_src) torch.testing.assert_close(dst_result, expected_dst) - @skipIfNormalMode( - "RankMismatch: Cannot assign a tensor of rank 2 to a buffer of rank 3" - ) def test_multi_dim_slice(self): """Test both setter from scalar and getter for [:, :, i]"""