|
1 | 1 | This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py.
|
2 | 2 | Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
|
3 | 3 |
|
| 4 | +--- assertExpectedJournal(TestWait.test_global_sync) |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +import torch |
| 8 | +import helion |
| 9 | +import triton |
| 10 | +import triton.language as tl |
| 11 | + |
| 12 | +@triton.jit |
| 13 | +def _gmem_multi_bar_sync_kernel_kernel(signal_pad, signal_pad_stride_0, signal_pad_stride_1, N, _BLOCK_SIZE_1: tl.constexpr): |
| 14 | + pid_0 = tl.program_id(0) |
| 15 | + offset_0 = pid_0 |
| 16 | + for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1): |
| 17 | + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) |
| 18 | + helion.runtime.triton_send_signal(addr=signal_pad + (indices_1 * signal_pad_stride_0 + offset_0 * signal_pad_stride_1), update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=True) |
| 19 | + helion.runtime.triton_wait_multiple_signal(addr=signal_pad + (offset_0 * signal_pad_stride_0 + indices_1 * signal_pad_stride_1), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False) |
| 20 | + |
| 21 | +def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor): |
| 22 | + M, N = signal_pad.shape |
| 23 | + assert M == N |
| 24 | + _BLOCK_SIZE_1 = N |
| 25 | + _gmem_multi_bar_sync_kernel_kernel[N,](signal_pad, signal_pad.stride(0), signal_pad.stride(1), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3) |
| 26 | + return signal_pad |
| 27 | + |
| 28 | +def _gmem_multi_bar_sync_kernel_make_precompiler(signal_pad: torch.Tensor): |
| 29 | + M, N = signal_pad.shape |
| 30 | + assert M == N |
| 31 | + _BLOCK_SIZE_1 = N |
| 32 | + from helion.runtime.precompile_shim import make_precompiler |
| 33 | + return make_precompiler(_gmem_multi_bar_sync_kernel_kernel)(signal_pad, signal_pad.stride(0), signal_pad.stride(1), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3) |
| 34 | + |
4 | 35 | --- assertExpectedJournal(TestWait.test_signal_basic)
|
5 | 36 | from __future__ import annotations
|
6 | 37 |
|
@@ -120,3 +151,38 @@ def _gmem_wait_kernel_make_precompiler(signal_pad: torch.Tensor):
|
120 | 151 | from helion.runtime.precompile_shim import make_precompiler
|
121 | 152 | return make_precompiler(_gmem_wait_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), num_warps=4, num_stages=3)
|
122 | 153 |
|
| 154 | +--- assertExpectedJournal(TestWait.test_wait_multi_bar) |
| 155 | +from __future__ import annotations |
| 156 | + |
| 157 | +import torch |
| 158 | +import helion |
| 159 | +import triton |
| 160 | +import triton.language as tl |
| 161 | + |
| 162 | +import __main__ as _source_module |
| 163 | + |
| 164 | +@triton.jit |
| 165 | +def _gmem_wait_multi_bar_kernel_kernel(signal_pad, out, out_stride_0, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr): |
| 166 | + pid_0 = tl.program_id(0) |
| 167 | + offset_0 = pid_0 * _BLOCK_SIZE_0 |
| 168 | + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) |
| 169 | + helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False) |
| 170 | + tile_id = offset_0 // _BLOCK_SIZE_0 |
| 171 | + tl.store(out + tile_id * out_stride_0, tile_id, None) |
| 172 | + |
| 173 | +def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor): |
| 174 | + N, = signal_pad.shape |
| 175 | + n = 4 |
| 176 | + out = torch.empty(n, dtype=torch.int32, device=_source_module.DEVICE) |
| 177 | + _BLOCK_SIZE_0 = 4 |
| 178 | + _gmem_wait_multi_bar_kernel_kernel[triton.cdiv(N, _BLOCK_SIZE_0),](signal_pad, out, out.stride(0), signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) |
| 179 | + return out |
| 180 | + |
| 181 | +def _gmem_wait_multi_bar_kernel_make_precompiler(signal_pad: torch.Tensor): |
| 182 | + N, = signal_pad.shape |
| 183 | + n = 4 |
| 184 | + out = torch.empty(n, dtype=torch.int32, device=_source_module.DEVICE) |
| 185 | + _BLOCK_SIZE_0 = 4 |
| 186 | + from helion.runtime.precompile_shim import make_precompiler |
| 187 | + return make_precompiler(_gmem_wait_multi_bar_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) |
| 188 | + |
0 commit comments