Skip to content

Commit 3ccf510

Browse files
authored
Add tile.count (#955)
1 parent 843d962 commit 3ccf510

File tree

5 files changed

+142
-3
lines changed

5 files changed

+142
-3
lines changed

helion/language/tile_interface.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ def block_size(self) -> int:
4545

4646
return tile_block_size(self)
4747

@property
def count(self) -> int:
    """
    The number of tiles along this dimension.

    Alias for :func:`~helion.language.tile_count`.
    """
    from .tile_ops import tile_count

    return tile_count(self)
4857
@property
4958
def id(self) -> int:
5059
"""

helion/language/tile_ops.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,53 @@ def _(tile: RefTile) -> int:
177177
return tile._block_size
178178

179179

@_decorators.api(tiles_as_sizes=True)
def tile_count(tile: TileInterface) -> int:
    """
    Return how many tiles cover the given tile dimension.

    For an iteration space ``[begin, end)`` this equals
    ``cdiv(end - begin, block_size)``; when iterating from 0 it reduces to
    ``cdiv(tile_end, tile.block_size)``.

    This can also be written as: `tile.count`.
    """
    # Host-side placeholder: the real value is produced during kernel tracing.
    raise exc.NotInsideKernel
191+
192+
193+
@_decorators.register_fake(tile_count)
def _(tile: torch.SymInt) -> torch.SymInt:
    # Create (or reuse, via the cache key) an unbacked symint standing in
    # for the tile count, and record which block index it originates from.
    block_index = _disable_flatten_get_tile(tile)
    count_sym = CompileEnvironment.current().cached_create_unbacked_symint(
        ("tile_count", tile)
    )
    _register_tile_symbol_origin(count_sym, block_index)
    return count_sym
201+
202+
203+
@_decorators.codegen(tile_count)
def _(state: CodegenState) -> ast.AST:
    block_id = _disable_flatten_get_tile(state.proxy_arg(0))
    # The innermost active device loop for this block carries the loop-end
    # variable name used in the generated code.
    loop_info = state.codegen.active_device_loops[block_id][-1].block_id_to_info[
        block_id
    ]
    size_name = state.device_function.block_size_var(block_id)
    if size_name is None:
        # Fall back to a literal 1 when no block-size variable exists.
        size_name = "1"
    return expr_from_string(f"tl.cdiv({loop_info.end_var_name}, {size_name})")
216+
217+
218+
@_decorators.ref(tile_count)
def _(tile: RefTile) -> int:
    # Ceiling division: how many block_size-sized tiles cover [begin, end).
    span = tile._slice.stop - tile._slice.start
    block = tile._block_size
    return (span + block - 1) // block
225+
226+
180227
@_decorators.api(tiles_as_sizes=True)
181228
def tile_id(tile: TileInterface) -> int:
182229
"""

helion/language/tile_proxy.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,14 @@ class Tile(TileInterface, torch.Tensor):
3434
3535
Tiles can be used as indices to tensors, e.g. `tensor[tile]`. Tiles
3636
can also be used as sizes for allocations, e.g. `torch.empty([tile])`.
37-
There are also properties such as :meth:`tile.index <index>`, :meth:`tile.begin <begin>`,
38-
:meth:`tile.end <end>`, :meth:`tile.id <id>` and :meth:`tile.block_size <block_size>` that can be used to retrieve various
39-
information about the tile.
37+
There are also properties such as
38+
* :meth:`tile.index <index>`
39+
* :meth:`tile.begin <begin>`
40+
* :meth:`tile.end <end>`
41+
* :meth:`tile.id <id>`
42+
* :meth:`tile.block_size <block_size>`
43+
* :meth:`tile.count <count>`
44+
that can be used to retrieve various information about the tile.
4045
4146
Masking is implicit for tiles, so if the final tile is smaller than
4247
the block size loading that tile will only load the valid elements

test/test_indexing.expected

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,48 @@ def arange_block_size_mul(x: torch.Tensor, *, _launcher=_default_launcher):
434434
_launcher(_helion_arange_block_size_mul, (triton.cdiv(64, _BLOCK_SIZE_0),), ones, out, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=2)
435435
return out
436436

437+
--- assertExpectedJournal(TestIndexing.test_tile_count_top_level)
438+
from __future__ import annotations
439+
440+
import torch
441+
import triton
442+
import triton.language as tl
443+
from helion.runtime import default_launcher as _default_launcher
444+
445+
@triton.jit
446+
def _helion_fn(out, n, _BLOCK_SIZE_0: tl.constexpr):
447+
pid_0 = tl.program_id(0)
448+
offset_0 = pid_0 * _BLOCK_SIZE_0
449+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
450+
mask_0 = indices_0 < n
451+
tile_count = tl.cdiv(n, _BLOCK_SIZE_0)
452+
tl.store(out + indices_0 * 1, tile_count, mask_0)
453+
454+
def fn(n: int, device: torch.device, *, _launcher=_default_launcher):
455+
out = torch.zeros([n], dtype=torch.int32, device=device)
456+
_BLOCK_SIZE_0 = 64
457+
_launcher(_helion_fn, (triton.cdiv(n, _BLOCK_SIZE_0),), out, n, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
458+
return out
459+
460+
--- assertExpectedJournal(TestIndexing.test_tile_count_with_begin_end)
461+
from __future__ import annotations
462+
463+
import torch
464+
import triton
465+
import triton.language as tl
466+
from helion.runtime import default_launcher as _default_launcher
467+
468+
@triton.jit
469+
def _helion_fn(out, begin, end, _BLOCK_SIZE_0: tl.constexpr):
470+
tile_count = tl.cdiv(end + -1 * begin, _BLOCK_SIZE_0)
471+
tl.store(out + tl.zeros([], tl.int32), tile_count, None)
472+
473+
def fn(begin: int, end: int, device: torch.device, *, _launcher=_default_launcher):
474+
out = torch.zeros([1], dtype=torch.int32, device=device)
475+
_BLOCK_SIZE_0 = 32
476+
_launcher(_helion_fn, (triton.cdiv(end + -1 * begin, _BLOCK_SIZE_0),), out, begin, end, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
477+
return out
478+
437479
--- assertExpectedJournal(TestIndexing.test_tile_with_offset_block_ptr)
438480
from __future__ import annotations
439481

test/test_indexing.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,42 @@ def reduction_sum(x: torch.Tensor) -> torch.Tensor:
4646

4747

4848
class TestIndexing(RefEagerTestBase, TestCase):
49+
@skipIfRefEager(
    "Test is block size dependent which is not supported in ref eager mode"
)
def test_tile_count_top_level(self):
    # tile.count over a 0-based iteration space: every output element
    # should equal ceil(length / block_size).
    @helion.kernel
    def fn(n: int, device: torch.device) -> torch.Tensor:
        out = torch.zeros([n], dtype=torch.int32, device=device)
        for tile in hl.tile(n, block_size=64):
            out[tile] = tile.count
        return out

    length = 100
    code, result = code_and_output(fn, (length, DEVICE))
    expected = torch.full(
        [length], -(-length // 64), dtype=torch.int32, device=DEVICE
    )
    torch.testing.assert_close(result, expected)
    self.assertExpectedJournal(code)
65+
66+
@skipIfRefEager(
    "Test is block size dependent which is not supported in ref eager mode"
)
def test_tile_count_with_begin_end(self):
    # tile.count over an explicit [begin, end) range: result should equal
    # ceil((end - begin) / block_size).
    @helion.kernel
    def fn(begin: int, end: int, device: torch.device) -> torch.Tensor:
        out = torch.zeros([1], dtype=torch.int32, device=device)
        for tile in hl.tile(begin, end, block_size=32):
            out[0] = tile.count
        return out

    lo, hi = 10, 97
    code, result = code_and_output(fn, (lo, hi, DEVICE))
    expected = torch.tensor(
        [-(-(hi - lo) // 32)], dtype=torch.int32, device=DEVICE
    )
    torch.testing.assert_close(result, expected)
    self.assertExpectedJournal(code)
84+
4985
def test_arange(self):
5086
@helion.kernel
5187
def arange(length: int, device: torch.device) -> torch.Tensor:

0 commit comments

Comments
 (0)