@@ -86,6 +86,70 @@ def _fn_make_precompiler(x, v):
8686 return make_precompiler(_fn_kernel)(x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), v, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""" ,
8787 )
8888
# Test: a Helion kernel that branches with `if` on a one-element tensor read inside an `hl.grid` loop.
# The kernel `fn` starts from `torch.zeros_like(x)` and, for each index, stores `x[idx] * 2` when
# `y[idx] != 0` and `x[idx]` when `y[idx] == 0`.
# The two cases are written as two separate `if` statements instead of `if`/`else`; the TODO on the
# second condition records why.
89+ def test_if_arg_one_element_tensor (self ):
90+ @helion .kernel
91+ def fn (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
92+ output = torch .zeros_like (x )
93+
94+ for idx in hl .grid (x .shape [0 ]):
95+ # Since `y[idx]` is a one-element tensor, comparing it against 0 will also create a one-element tensor.
96+ if y [idx ] != 0 :
97+ output [idx ] = x [idx ] * 2
98+ if (
99+ y [idx ] == 0
100+ ): # TODO(yf225): `else:` raises MLIR error in Triton, so we use a second if.
101+ output [idx ] = x [idx ]
102+
103+ return output
104+
# Inputs: `y` alternates 0 and 1, so both conditions are taken at least once.
# `expected` keeps the `x` value where `y` is 0 and doubles it where `y` is 1.
105+ x = torch .tensor ([1.0 , 2.0 , 3.0 , 4.0 ], device = DEVICE )
106+ y = torch .tensor ([0 , 1 , 0 , 1 ], device = DEVICE , dtype = torch .int32 )
107+ expected = torch .tensor ([1.0 , 4.0 , 3.0 , 8.0 ], device = DEVICE )
# `code_and_output` returns a pair, unpacked as `code` and `result`.
# NOTE(review): `code` is presumably the generated Triton source, judging by the expected text below - confirm against the helper.
108+ code , result = code_and_output (
109+ fn ,
110+ (x , y ),
111+ )
# Check the kernel's numerical result first, then the text held in `code`.
112+ torch .testing .assert_close (result , expected )
# In the expected text, each `if` on the one-element comparison appears as `if tl.sum(...)`,
# and `y` is loaded a second time before the second condition.
113+ self .assertExpectedInline (
114+ code ,
115+ """\
116+ from __future__ import annotations
117+
118+ import torch
119+ import triton
120+ import triton.language as tl
121+
122+ @triton.jit
123+ def _fn_kernel(x, y, output, output_stride_0, x_stride_0, y_stride_0):
124+ pid_0 = tl.program_id(0)
125+ offset_0 = pid_0
126+ indices_0 = offset_0 + tl.zeros([1], tl.int32)
127+ load = tl.load(y + indices_0 * y_stride_0, None)
128+ v_0 = tl.full([], 0, tl.int32)
129+ v_1 = load != v_0
130+ if tl.sum(v_1):
131+ load_1 = tl.load(x + indices_0 * x_stride_0, None)
132+ v_2 = 2.0
133+ v_3 = load_1 * v_2
134+ tl.store(output + indices_0 * output_stride_0, v_3, None)
135+ load_2 = tl.load(y + indices_0 * y_stride_0, None)
136+ v_4 = tl.full([], 0, tl.int32)
137+ v_5 = load_2 == v_4
138+ if tl.sum(v_5):
139+ load_3 = tl.load(x + indices_0 * x_stride_0, None)
140+ tl.store(output + indices_0 * output_stride_0, load_3, None)
141+
142+ def fn(x: torch.Tensor, y: torch.Tensor):
143+ output = torch.zeros_like(x)
144+ _fn_kernel[x.size(0),](x, y, output, output.stride(0), x.stride(0), y.stride(0), num_warps=4, num_stages=3)
145+ return output
146+
147+ def _fn_make_precompiler(x: torch.Tensor, y: torch.Tensor):
148+ output = torch.zeros_like(x)
149+ from helion.runtime.precompile_shim import make_precompiler
150+ return make_precompiler(_fn_kernel)(x, y, output, output.stride(0), x.stride(0), y.stride(0), num_warps=4, num_stages=3)""" ,
151+ )
152+
89153 def test_constant_true (self ):
90154 @helion .kernel (
91155 config = {
0 commit comments