@@ -10,7 +10,7 @@ import triton.language as tl
10
10
from helion.runtime import default_launcher as _default_launcher
11
11
12
12
@triton.jit
13
- def _broadcast_fn_kernel (a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
13
+ def _helion_broadcast_fn (a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
14
14
num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
15
15
pid_0 = tl.program_id(0) % num_blocks_0
16
16
pid_1 = tl.program_id(0) // num_blocks_0
@@ -34,7 +34,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
34
34
out1 = torch.empty_like(a)
35
35
_BLOCK_SIZE_0 = 16
36
36
_BLOCK_SIZE_1 = 8
37
- _launcher(_broadcast_fn_kernel , (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
37
+ _launcher(_helion_broadcast_fn , (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
38
38
return (out0, out1)
39
39
40
40
--- assertExpectedJournal(TestBroadcasting.test_broadcast2)
@@ -46,7 +46,7 @@ import triton.language as tl
46
46
from helion.runtime import default_launcher as _default_launcher
47
47
48
48
@triton.jit
49
- def _broadcast_fn_kernel (a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
49
+ def _helion_broadcast_fn (a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
50
50
num_blocks_0 = tl.cdiv(a_size_1, _BLOCK_SIZE_1)
51
51
pid_0 = tl.program_id(0) % num_blocks_0
52
52
pid_1 = tl.program_id(0) // num_blocks_0
@@ -70,7 +70,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
70
70
out1 = torch.empty_like(a)
71
71
_BLOCK_SIZE_1 = 8
72
72
_BLOCK_SIZE_0 = 16
73
- _launcher(_broadcast_fn_kernel , (triton.cdiv(a.size(1), _BLOCK_SIZE_1) * triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
73
+ _launcher(_helion_broadcast_fn , (triton.cdiv(a.size(1), _BLOCK_SIZE_1) * triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
74
74
return (out0, out1)
75
75
76
76
--- assertExpectedJournal(TestBroadcasting.test_broadcast3)
@@ -82,7 +82,7 @@ import triton.language as tl
82
82
from helion.runtime import default_launcher as _default_launcher
83
83
84
84
@triton.jit
85
- def _broadcast_fn_kernel (a, b, out0, out1, a_size_0, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr):
85
+ def _helion_broadcast_fn (a, b, out0, out1, a_size_0, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr):
86
86
num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
87
87
pid_0 = tl.program_id(0) % num_blocks_0
88
88
pid_1 = tl.program_id(0) // num_blocks_0
@@ -104,7 +104,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
104
104
out0 = torch.empty_like(a)
105
105
out1 = torch.empty_like(a)
106
106
_BLOCK_SIZE_0 = 64
107
- _launcher(_broadcast_fn_kernel , (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * a.size(1),), a, b, out0, out1, a.size(0), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
107
+ _launcher(_helion_broadcast_fn , (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * a.size(1),), a, b, out0, out1, a.size(0), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
108
108
return (out0, out1)
109
109
110
110
--- assertExpectedJournal(TestBroadcasting.test_broadcast4)
@@ -116,7 +116,7 @@ import triton.language as tl
116
116
from helion.runtime import default_launcher as _default_launcher
117
117
118
118
@triton.jit
119
- def _broadcast_fn_kernel (a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_1: tl.constexpr):
119
+ def _helion_broadcast_fn (a, b, out0, out1, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_1: tl.constexpr):
120
120
num_blocks_0 = a_size_0
121
121
pid_0 = tl.program_id(0) % num_blocks_0
122
122
pid_1 = tl.program_id(0) // num_blocks_0
@@ -138,7 +138,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
138
138
out0 = torch.empty_like(a)
139
139
out1 = torch.empty_like(a)
140
140
_BLOCK_SIZE_1 = 64
141
- _launcher(_broadcast_fn_kernel , (a.size(0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, num_warps=4, num_stages=3)
141
+ _launcher(_helion_broadcast_fn , (a.size(0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, num_warps=4, num_stages=3)
142
142
return (out0, out1)
143
143
144
144
--- assertExpectedJournal(TestBroadcasting.test_broadcast5)
@@ -150,7 +150,7 @@ import triton.language as tl
150
150
from helion.runtime import default_launcher as _default_launcher
151
151
152
152
@triton.jit
153
- def _broadcast_fn_kernel (a, b, out0, out1, a_size_0, a_size_1, b_size_0, out0_size_0, out0_size_1, out1_size_0, out1_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
153
+ def _helion_broadcast_fn (a, b, out0, out1, a_size_0, a_size_1, b_size_0, out0_size_0, out0_size_1, out1_size_0, out1_size_1, a_stride_0, a_stride_1, b_stride_0, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
154
154
num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
155
155
pid_0 = tl.program_id(0) % num_blocks_0
156
156
pid_1 = tl.program_id(0) // num_blocks_0
@@ -170,7 +170,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
170
170
out1 = torch.empty_like(a)
171
171
_BLOCK_SIZE_0 = 32
172
172
_BLOCK_SIZE_1 = 32
173
- _launcher(_broadcast_fn_kernel , (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), b.size(0), out0.size(0), out0.size(1), out1.size(0), out1.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
173
+ _launcher(_helion_broadcast_fn , (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), b.size(0), out0.size(0), out0.size(1), out1.size(0), out1.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
174
174
return (out0, out1)
175
175
176
176
--- assertExpectedJournal(TestBroadcasting.test_constexpr_index)
@@ -182,7 +182,7 @@ import triton.language as tl
182
182
from helion.runtime import default_launcher as _default_launcher
183
183
184
184
@triton.jit
185
- def _fn_kernel (a, out0, out1, out2, a_size_0, a_size_1, a_stride_0, a_stride_1, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, out2_stride_0, out2_stride_1, idx1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
185
+ def _helion_fn (a, out0, out1, out2, a_size_0, a_size_1, a_stride_0, a_stride_1, out0_stride_0, out0_stride_1, out1_stride_0, out1_stride_1, out2_stride_0, out2_stride_1, idx1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
186
186
num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
187
187
pid_0 = tl.program_id(0) % num_blocks_0
188
188
pid_1 = tl.program_id(0) // num_blocks_0
@@ -212,7 +212,7 @@ def fn(a, idx1, *, _launcher=_default_launcher):
212
212
out2 = torch.empty_like(a)
213
213
_BLOCK_SIZE_0 = 16
214
214
_BLOCK_SIZE_1 = 16
215
- _launcher(_fn_kernel , (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, out0, out1, out2, a.size(0), a.size(1), a.stride(0), a.stride(1), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), out2.stride(0), out2.stride(1), idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
215
+ _launcher(_helion_fn , (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, out0, out1, out2, a.size(0), a.size(1), a.stride(0), a.stride(1), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), out2.stride(0), out2.stride(1), idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
216
216
return (out0, out1, out2)
217
217
218
218
--- assertExpectedJournal(TestBroadcasting.test_implicit_broadcast)
@@ -224,7 +224,7 @@ import triton.language as tl
224
224
from helion.runtime import default_launcher as _default_launcher
225
225
226
226
@triton.jit
227
- def _fn_kernel (a, b, out, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
227
+ def _helion_fn (a, b, out, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
228
228
num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
229
229
pid_0 = tl.program_id(0) % num_blocks_0
230
230
pid_1 = tl.program_id(0) // num_blocks_0
@@ -244,5 +244,5 @@ def fn(a, b, *, _launcher=_default_launcher):
244
244
out = torch.empty_like(a)
245
245
_BLOCK_SIZE_0 = 16
246
246
_BLOCK_SIZE_1 = 16
247
- _launcher(_fn_kernel , (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
247
+ _launcher(_helion_fn , (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
248
248
return out
0 commit comments