
Commit 01f99c6

Add helion prefix to Triton kernel name

1 parent 8819331

27 files changed (+707, -707 lines)

helion/_compiler/generate_ast.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def __init__(self, func: HostFunction, config: Config) -> None:
         self.next_else_block: list[ast.AST] | None = None
 
         # Now create device function and initialize CodegenInterface
-        self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
+        self.device_function = DeviceFunction(f"_helion_{func.name}", config, self)
         CodegenInterface.__init__(self, self.device_function)
 
     def offset_var(self, block_idx: int) -> str:
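
This one-line change renames every generated Triton kernel from `_{func.name}_kernel` to `_helion_{func.name}`. Besides tagging each kernel as Helion-generated, it removes the doubled suffix the old scheme produced for Helion functions already named `*_kernel` (e.g. `_atomic_add_kernel_kernel` in the expected files below). A minimal sketch of the two schemes, using hypothetical helper names rather than Helion's actual code:

# Hypothetical sketch of the naming change; these helpers are illustrative
# and are not Helion's actual implementation.
def old_kernel_name(func_name: str) -> str:
    # Old scheme: appends "_kernel", doubling the suffix for functions
    # that are already named "*_kernel".
    return f"_{func_name}_kernel"

def new_kernel_name(func_name: str) -> str:
    # New scheme: a single, uniform "_helion_" prefix.
    return f"_helion_{func_name}"

assert old_kernel_name("atomic_add_kernel") == "_atomic_add_kernel_kernel"
assert new_kernel_name("atomic_add_kernel") == "_helion_atomic_add_kernel"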

test/test_associative_scan.expected

Lines changed: 84 additions & 84 deletions
Large diffs are not rendered by default.

test/test_atomic_add.expected

Lines changed: 10 additions & 10 deletions
@@ -10,7 +10,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _atomic_add_2d_kernel_kernel(y, x, y_size_0, y_size_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _helion_atomic_add_2d_kernel(y, x, y_size_0, y_size_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     num_blocks_0 = tl.cdiv(y_size_0, _BLOCK_SIZE_0)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0

@@ -27,7 +27,7 @@ def atomic_add_2d_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default
     """Test atomic_add with 2D indexing."""
     _BLOCK_SIZE_0 = 8
     _BLOCK_SIZE_1 = 8
-    _launcher(_atomic_add_2d_kernel_kernel, (triton.cdiv(y.size(0), _BLOCK_SIZE_0) * triton.cdiv(y.size(1), _BLOCK_SIZE_1),), y, x, y.size(0), y.size(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_atomic_add_2d_kernel, (triton.cdiv(y.size(0), _BLOCK_SIZE_0) * triton.cdiv(y.size(1), _BLOCK_SIZE_1),), y, x, y.size(0), y.size(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return x
 
 --- assertExpectedJournal(TestAtomicOperations.test_atomic_add_float)

@@ -39,7 +39,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _atomic_add_float_kernel_kernel(indices, x, indices_size_0, indices_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_atomic_add_float_kernel(indices, x, indices_size_0, indices_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)

@@ -50,7 +50,7 @@ def _atomic_add_float_kernel_kernel(indices, x, indices_size_0, indices_stride_0
 def atomic_add_float_kernel(x: torch.Tensor, indices: torch.Tensor, *, _launcher=_default_launcher):
     """Test atomic_add with a float constant value and reading from lookup"""
     _BLOCK_SIZE_0 = 32
-    _launcher(_atomic_add_float_kernel_kernel, (triton.cdiv(indices.size(0), _BLOCK_SIZE_0),), indices, x, indices.size(0), indices.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _launcher(_helion_atomic_add_float_kernel, (triton.cdiv(indices.size(0), _BLOCK_SIZE_0),), indices, x, indices.size(0), indices.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return x
 
 --- assertExpectedJournal(TestAtomicOperations.test_atomic_add_w_tile_attr)

@@ -62,7 +62,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _atomic_add_w_tile_attr_kernel(y, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_atomic_add_w_tile_attr(y, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     tl.atomic_add(y + offset_0 * y_stride_0, 1, mask=None, sem='relaxed')

@@ -71,7 +71,7 @@ def atomic_add_w_tile_attr(x: torch.Tensor, *, _launcher=_default_launcher):
     """Test atomic_add where the index is a symbolic int"""
     y = torch.zeros_like(x, device=x.device, dtype=torch.int32)
     _BLOCK_SIZE_0 = 2
-    _launcher(_atomic_add_w_tile_attr_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), y, y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _launcher(_helion_atomic_add_w_tile_attr, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), y, y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return y
 
 --- assertExpectedJournal(TestAtomicOperations.test_basic_atomic_add)

@@ -83,7 +83,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _atomic_add_kernel_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_atomic_add_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)

@@ -94,7 +94,7 @@ def _atomic_add_kernel_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
 def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     """Test basic atomic_add functionality."""
     _BLOCK_SIZE_0 = 32
-    _launcher(_atomic_add_kernel_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _launcher(_helion_atomic_add_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return x
 
 --- assertExpectedJournal(TestAtomicOperations.test_overlapping_atomic_add)

@@ -106,7 +106,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _atomic_add_overlap_kernel_kernel(indices, y, x, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_atomic_add_overlap_kernel(indices, y, x, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)

@@ -118,5 +118,5 @@ def _atomic_add_overlap_kernel_kernel(indices, y, x, _BLOCK_SIZE_0: tl.constexpr
 def atomic_add_overlap_kernel(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor, *, _launcher=_default_launcher):
     """Test atomic_add with overlapping indices."""
     _BLOCK_SIZE_0 = 32
-    _launcher(_atomic_add_overlap_kernel_kernel, (triton.cdiv(10, _BLOCK_SIZE_0),), indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _launcher(_helion_atomic_add_overlap_kernel, (triton.cdiv(10, _BLOCK_SIZE_0),), indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return x
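
The expected-journal updates above are mechanical: only the kernel name in the `def` and in the `_launcher(...)` call changes. One practical consequence is that every Helion-generated kernel now shares a common prefix, so it can be picked out of, say, a profiler trace by name. A hedged sketch; `trace_kernel_names` is invented example data, not Helion output:

# Invented example data: kernel names as they might appear in a trace.
trace_kernel_names = [
    "_helion_atomic_add_kernel",
    "triton_poi_fused_add_0",  # e.g. a kernel from another compiler
    "_helion_broadcast_fn",
]

# With the uniform prefix, Helion kernels are a simple startswith() filter.
helion_kernels = [n for n in trace_kernel_names if n.startswith("_helion_")]
print(helion_kernels)  # ['_helion_atomic_add_kernel', '_helion_broadcast_fn']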

test/test_broadcasting.expected

Lines changed: 14 additions & 14 deletions
@@ -10,7 +10,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _broadcast_fn_kernel(a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _helion_broadcast_fn(a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0

@@ -34,7 +34,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
     out1 = torch.empty_like(a)
     _BLOCK_SIZE_0 = 16
     _BLOCK_SIZE_1 = 8
-    _launcher(_broadcast_fn_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_broadcast_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return (out0, out1)
 
 --- assertExpectedJournal(TestBroadcasting.test_broadcast2)

@@ -46,7 +46,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _broadcast_fn_kernel(a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_broadcast_fn(a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     num_blocks_0 = tl.cdiv(a_size_1, _BLOCK_SIZE_1)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0

@@ -70,7 +70,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
     out1 = torch.empty_like(a)
     _BLOCK_SIZE_1 = 8
     _BLOCK_SIZE_0 = 16
-    _launcher(_broadcast_fn_kernel, (triton.cdiv(a.size(1), _BLOCK_SIZE_1) * triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _launcher(_helion_broadcast_fn, (triton.cdiv(a.size(1), _BLOCK_SIZE_1) * triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return (out0, out1)
 
 --- assertExpectedJournal(TestBroadcasting.test_broadcast3)

@@ -82,7 +82,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _broadcast_fn_kernel(a, b, out0, out1, a_size_0, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_broadcast_fn(a, b, out0, out1, a_size_0, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr):
     num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0

@@ -104,7 +104,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
     out0 = torch.empty_like(a)
     out1 = torch.empty_like(a)
     _BLOCK_SIZE_0 = 64
-    _launcher(_broadcast_fn_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * a.size(1),), a, b, out0, out1, a.size(0), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _launcher(_helion_broadcast_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * a.size(1),), a, b, out0, out1, a.size(0), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return (out0, out1)
 
 --- assertExpectedJournal(TestBroadcasting.test_broadcast4)

@@ -116,7 +116,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _broadcast_fn_kernel(a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_1: tl.constexpr):
+def _helion_broadcast_fn(a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_1: tl.constexpr):
     num_blocks_0 = a_size_0
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0

@@ -138,7 +138,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
     out0 = torch.empty_like(a)
     out1 = torch.empty_like(a)
     _BLOCK_SIZE_1 = 64
-    _launcher(_broadcast_fn_kernel, (a.size(0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_broadcast_fn, (a.size(0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return (out0, out1)
 
 --- assertExpectedJournal(TestBroadcasting.test_broadcast5)

@@ -150,7 +150,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _broadcast_fn_kernel(a, b, out0, out1, a_size_0, a_size_1, b_size_0, out0_size_0, out0_size_1, out1_size_0, out1_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _helion_broadcast_fn(a, b, out0, out1, a_size_0, a_size_1, b_size_0, out0_size_0, out0_size_1, out1_size_0, out1_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0

@@ -170,7 +170,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
     out1 = torch.empty_like(a)
     _BLOCK_SIZE_0 = 32
     _BLOCK_SIZE_1 = 32
-    _launcher(_broadcast_fn_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), b.size(0), out0.size(0), out0.size(1), out1.size(0), out1.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_broadcast_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), b.size(0), out0.size(0), out0.size(1), out1.size(0), out1.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return (out0, out1)
 
 --- assertExpectedJournal(TestBroadcasting.test_constexpr_index)

@@ -182,7 +182,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _fn_kernel(a, out0, out1, out2, a_size_0, a_size_1, a_stride_0, a_stride_1, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, out2_stride_0, out2_stride_1, idx1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _helion_fn(a, out0, out1, out2, a_size_0, a_size_1, a_stride_0, a_stride_1, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, out2_stride_0, out2_stride_1, idx1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0

@@ -212,7 +212,7 @@ def fn(a, idx1, *, _launcher=_default_launcher):
     out2 = torch.empty_like(a)
     _BLOCK_SIZE_0 = 16
     _BLOCK_SIZE_1 = 16
-    _launcher(_fn_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, out0, out1, out2, a.size(0), a.size(1), a.stride(0), a.stride(1), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), out2.stride(0), out2.stride(1), idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, out0, out1, out2, a.size(0), a.size(1), a.stride(0), a.stride(1), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), out2.stride(0), out2.stride(1), idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return (out0, out1, out2)
 
 --- assertExpectedJournal(TestBroadcasting.test_implicit_broadcast)

@@ -224,7 +224,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _fn_kernel(a, b, out, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+def _helion_fn(a, b, out, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0

@@ -244,5 +244,5 @@ def fn(a, b, *, _launcher=_default_launcher):
     out = torch.empty_like(a)
     _BLOCK_SIZE_0 = 16
     _BLOCK_SIZE_1 = 16
-    _launcher(_fn_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
