Skip to content

Commit bd4447b

Browse files
authored
[DEV][PROFILER] Case 2: Add for-loop unroll tracking to Profiler (#197)
1 parent 2418ee9 commit bd4447b

File tree

5 files changed

+218
-19
lines changed

5 files changed

+218
-19
lines changed

tests/test_profiler.py

Lines changed: 118 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,115 @@
88
from triton_viz.clients import Profiler
99

1010

11+
# ======== Case 2: Check if for loop can be unrolled ========
12+
@triton_viz.trace(
13+
clients=(
14+
loop_profiler := Profiler(
15+
disable_buffer_load_check=True, disable_load_mask_percentage_check=True
16+
)
17+
)
18+
)
19+
@triton.jit
20+
def for_loop_test_kernel(
21+
in_ptr,
22+
out_ptr,
23+
N: tl.constexpr,
24+
BLOCK_SIZE: tl.constexpr,
25+
):
26+
"""
27+
Test kernel with for-loops to verify loop statistics tracking.
28+
29+
This kernel contains 4 for-loops:
30+
- First loop: Python range, iterates 10 times
31+
- Second loop: Python range, iterates 5 times
32+
- Third loop: tl.range, iterates 8 times
33+
- Fourth loop: tl.static_range, iterates 6 times
34+
"""
35+
pid = tl.program_id(axis=0)
36+
block_start = pid * BLOCK_SIZE
37+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
38+
mask = offsets < N
39+
40+
# Load input
41+
x = tl.load(in_ptr + offsets, mask=mask, other=0.0)
42+
43+
# First for-loop: Python range, 10 iterations
44+
result = x
45+
for i in range(10):
46+
result = result + 1.0
47+
48+
# Second for-loop: Python range, 5 iterations
49+
for j in range(5):
50+
result = result * 2.0
51+
52+
# Third for-loop: tl.range, 8 iterations
53+
for k in tl.range(2, 10): # 2, 3, 4, 5, 6, 7, 8, 9
54+
result = result + 0.5
55+
56+
# Fourth for-loop: tl.static_range, 6 iterations
57+
for m in tl.static_range(0, 12, 2): # 0, 2, 4, 6, 8, 10
58+
result = result - 0.25
59+
60+
# Store result
61+
tl.store(out_ptr + offsets, result, mask=mask)
62+
63+
64+
def test_for_loop_statistics():
65+
"""
66+
Test that the profiler correctly tracks for-loop statistics.
67+
"""
68+
N = 100
69+
BLOCK_SIZE = 32
70+
num_blocks = triton.cdiv(N, BLOCK_SIZE)
71+
72+
# Create input/output tensors
73+
x = torch.randn(N, dtype=torch.float32)
74+
y = torch.empty(N, dtype=torch.float32)
75+
76+
# Run the kernel
77+
grid = (num_blocks,)
78+
for_loop_test_kernel[grid](x, y, N, BLOCK_SIZE)
79+
80+
# Expected loop statistics:
81+
# The kernel has 4 for-loops, but they are executed once per grid
82+
# Each loop should be recorded only once (when first encountered)
83+
# Loop 1: range(10) -> 10 steps, type: python_range
84+
# Loop 2: range(5) -> 5 steps, type: python_range
85+
# Loop 3: tl.range(2, 10) -> 8 steps, type: tl_range
86+
# Loop 4: tl.static_range(0, 12, 2) -> 6 steps, type: tl_static_range
87+
88+
expected_num_loops = 4
89+
expected_loop_steps = [10, 5, 8, 6]
90+
expected_range_types = [
91+
"python_range",
92+
"python_range",
93+
"tl_range",
94+
"tl_static_range",
95+
]
96+
97+
# Verify the loop statistics from profiler
98+
assert (
99+
len(loop_profiler.loop_info) == expected_num_loops
100+
), f"Expected {expected_num_loops} loops, got {len(loop_profiler.loop_info)}"
101+
102+
for idx, (lineno, loop_info) in enumerate(loop_profiler.loop_info.items()):
103+
expected_steps = expected_loop_steps[idx]
104+
expected_type = expected_range_types[idx]
105+
106+
assert (
107+
loop_info.length == expected_steps
108+
), f"Loop #{idx+1}: Expected {expected_steps} steps, got {loop_info.length}"
109+
assert (
110+
loop_info.range_type == expected_type
111+
), f"Loop #{idx+1}: Expected type {expected_type}, got {loop_info.range_type}"
112+
assert isinstance(
113+
lineno, int
114+
), f"Loop #{idx+1}: lineno should be int, got {type(lineno)}"
115+
assert lineno > 0, f"Loop #{idx+1}: lineno should be positive, got {lineno}"
116+
117+
11118
# ======== Case 3: Check masked element percentage for tuning BLOCK_SIZE ========
12-
@triton_viz.trace(clients=(profiler := Profiler(disable_buffer_load_check=True)))
119+
@triton_viz.trace(clients=(mask_profiler := Profiler(disable_buffer_load_check=True)))
13120
@triton.jit
14121
def mask_percentage_test_kernel(
15122
in_ptr,
@@ -98,27 +205,27 @@ def test_mask_percentage():
98205

99206
# Verify the statistics from profiler
100207
assert (
101-
profiler.load_mask_total_count == expected_load_mask_total
102-
), f"Expected {expected_load_mask_total} total load mask elements, got {profiler.load_mask_total_count}"
208+
mask_profiler.load_mask_total_count == expected_load_mask_total
209+
), f"Expected {expected_load_mask_total} total load mask elements, got {mask_profiler.load_mask_total_count}"
103210
assert (
104-
profiler.load_mask_false_count == expected_load_mask_false
105-
), f"Expected {expected_load_mask_false} false load mask elements, got {profiler.load_mask_false_count}"
211+
mask_profiler.load_mask_false_count == expected_load_mask_false
212+
), f"Expected {expected_load_mask_false} false load mask elements, got {mask_profiler.load_mask_false_count}"
106213
assert (
107-
profiler.store_mask_total_count == expected_store_mask_total
108-
), f"Expected {expected_store_mask_total} total store mask elements, got {profiler.store_mask_total_count}"
214+
mask_profiler.store_mask_total_count == expected_store_mask_total
215+
), f"Expected {expected_store_mask_total} total store mask elements, got {mask_profiler.store_mask_total_count}"
109216
assert (
110-
profiler.store_mask_false_count == expected_store_mask_false
111-
), f"Expected {expected_store_mask_false} false store mask elements, got {profiler.store_mask_false_count}"
217+
mask_profiler.store_mask_false_count == expected_store_mask_false
218+
), f"Expected {expected_store_mask_false} false store mask elements, got {mask_profiler.store_mask_false_count}"
112219

113220
# Verify the masked percentage calculation
114221
expected_load_masked_percentage = (56 / 384) * 100 # ~14.58%
115222
expected_store_masked_percentage = (56 / 384) * 100 # ~14.58%
116223

117224
actual_load_masked_percentage = (
118-
profiler.load_mask_false_count / profiler.load_mask_total_count
225+
mask_profiler.load_mask_false_count / mask_profiler.load_mask_total_count
119226
) * 100
120227
actual_store_masked_percentage = (
121-
profiler.store_mask_false_count / profiler.store_mask_total_count
228+
mask_profiler.store_mask_false_count / mask_profiler.store_mask_total_count
122229
) * 100
123230

124231
assert (

triton_viz/clients/profiler/profiler.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,24 @@
44
from .data import LoadStoreBytes
55
from triton.runtime.interpreter import _get_np_dtype, TensorHandle
66
import numpy as np
7+
from dataclasses import dataclass, replace
78
from typing import Callable, Optional
89

910

11+
@dataclass(frozen=False)
12+
class LoopInfo:
13+
length: Optional[int] = None
14+
range_type: str = "unknown"
15+
16+
1017
class Profiler(Client):
1118
NAME = "profiler"
1219

1320
def __init__(
1421
self,
1522
callpath: bool = True,
1623
disable_buffer_load_check: bool = False,
24+
disable_for_loop_unroll_check: bool = False,
1725
disable_load_mask_percentage_check: bool = False,
1826
disable_load_store_skipping: bool = False,
1927
block_sampling: bool = False,
@@ -26,6 +34,10 @@ def __init__(
2634
self.store_bytes = LoadStoreBytes("store", 0, 0)
2735
self.has_buffer_load = False
2836
self.disable_buffer_load_check = disable_buffer_load_check
37+
self.disable_for_loop_unroll_check = disable_for_loop_unroll_check
38+
39+
# For-loop statistics
40+
self.loop_info: dict[int, LoopInfo] = {}
2941
self.disable_load_mask_percentage_check = disable_load_mask_percentage_check
3042
self.disable_load_store_skipping = disable_load_store_skipping
3143
self.block_sampling = block_sampling
@@ -221,9 +233,54 @@ def pre_addptr_callback(ptr, offset):
221233
return OpCallbacks()
222234

223235
def register_for_loop_callback(self):
224-
return ForLoopCallbacks()
236+
def loop_hook_range_type(lineno: int, range_type: str) -> None:
237+
cur = self.loop_info.get(lineno, LoopInfo())
238+
self.loop_info[lineno] = replace(cur, range_type=range_type)
239+
240+
def loop_hook_before(lineno, iterable):
241+
if self.disable_for_loop_unroll_check:
242+
return
243+
244+
if not isinstance(iterable, range):
245+
return
246+
247+
# Only record each unique loop (by line number) once
248+
# Different blocks will execute the same loop, so we deduplicate by lineno
249+
if self.loop_info[lineno].length is not None:
250+
return
251+
252+
# Record loop information: line number and total steps
253+
length = len(iterable)
254+
# Update length in LoopInfo
255+
cur = self.loop_info.get(lineno, LoopInfo())
256+
self.loop_info[lineno] = replace(cur, length=length)
257+
258+
def loop_hook_after(lineno: int) -> None:
259+
# No action needed after loop for profiler
260+
pass
261+
262+
return ForLoopCallbacks(
263+
range_type_callback=loop_hook_range_type,
264+
before_loop_callback=loop_hook_before,
265+
after_loop_callback=loop_hook_after,
266+
)
225267

226268
def finalize(self) -> list:
269+
# Print for-loop statistics if enabled
270+
if not self.disable_for_loop_unroll_check and self.loop_info:
271+
print("\n" + "=" * 60)
272+
print("Profiler: For-Loop Statistics")
273+
print("=" * 60)
274+
print(f"\nTotal for-loops detected: {len(self.loop_info)}\n")
275+
276+
for idx, (lineno, loop_info) in enumerate(self.loop_info.items(), 1):
277+
print(f"Loop #{idx}:")
278+
print(f" Line number: {lineno}")
279+
print(f" Range type: {loop_info.range_type}")
280+
print(f" Total steps: {loop_info.length}")
281+
282+
print("=" * 60)
283+
227284
# Calculate and print mask statistics only if load mask percentage check is enabled
228285
if not self.disable_load_mask_percentage_check:
229286
print("\n" + "=" * 60)

triton_viz/core/callbacks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class OpCallbacks:
1212

1313
@dataclass
1414
class ForLoopCallbacks:
15+
range_type_callback: Optional[Callable] = None
1516
before_loop_callback: Optional[Callable] = None
1617
loop_iter_overrider: Optional[Callable] = None
1718
loop_iter_listener: Optional[Callable] = None

triton_viz/core/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def grid_idx_callback(self, grid_idx: tuple[int, ...]):
5454
...
5555

5656
@abstractmethod
57-
def register_op_callback(self, op: type[Op]) -> OpCallbacks:
57+
def register_op_callback(self, op_type: type[Op]) -> OpCallbacks:
5858
...
5959

6060
@abstractmethod

triton_viz/core/patch.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,12 @@ def unpatch_op(op_type: type[Op]):
261261

262262

263263
class _LoopIter:
264-
def __init__(self, iterable, lineno, hooks):
264+
def __init__(self, hooks, iterable, lineno, range_type):
265265
self._it = iter(iterable)
266266
self._lineno = lineno
267267
self._hooks = hooks
268+
# triggering range_type
269+
self._hooks.range_type(self._lineno, range_type)
268270
# triggering before_loop
269271
if self._hooks.before_loop:
270272
self._hooks.before_loop(self._lineno, iterable)
@@ -293,12 +295,16 @@ class _CombinedLoopHooks:
293295
"""
294296

295297
def __init__(self):
298+
self._range_type: list[Callable] = []
296299
self._before: list[Callable] = []
297300
self._iter_listeners: list[Callable] = []
298301
self._iter_overrider: Optional[Callable] = None
299302
self._after: list[Callable] = []
300303

301304
# Register hooks
305+
def add_range_type_callback(self, hook: Callable) -> None:
306+
self._range_type.append(hook)
307+
302308
def add_before(self, hook: Callable) -> None:
303309
self._before.append(hook)
304310

@@ -314,6 +320,10 @@ def add_after(self, hook: Callable) -> None:
314320
self._after.append(hook)
315321

316322
# Call combined hooks
323+
def range_type(self, lineno: int, range_type: str) -> None:
324+
for hook in self._range_type:
325+
hook(lineno, range_type)
326+
317327
def before_loop(self, lineno: int, iterable: Any) -> None:
318328
for hook in self._before:
319329
hook(lineno, iterable)
@@ -335,8 +345,10 @@ def after_loop(self, lineno: int) -> None:
335345
for hook in self._after:
336346
hook(lineno)
337347

338-
def loop_iter_wrapper(self, iterable: Any, lineno: int) -> "_LoopIter":
339-
return _LoopIter(iterable, lineno, self)
348+
def loop_iter_wrapper(
349+
self, iterable: Any, lineno: int, range_type: str
350+
) -> "_LoopIter":
351+
return _LoopIter(self, iterable, lineno, range_type)
340352

341353
def clear(self) -> None:
342354
self._before.clear()
@@ -380,13 +392,29 @@ def _visit_For(self, node: ast.For): # type: ignore[override]
380392
for i in R:
381393
...
382394
==>
383-
for i in _triton_viz_loop_patcher.hooks.loop_iter_wrapper(R, lineno):
395+
for i in _triton_viz_loop_patcher.hooks.loop_iter_wrapper(R, lineno, range_type):
384396
...
385397
where _triton_viz_loop_patcher.hooks.loop_iter_wrapper returns a _LoopIter object.
386398
"""
387399
self.generic_visit(node)
388400

389-
# _triton_viz_loop_patcher.hooks.loop_iter(range(...), lineno)
401+
# Detect range type
402+
range_type = "unknown"
403+
if isinstance(node.iter, ast.Call):
404+
func = node.iter.func
405+
if isinstance(func, ast.Name) and func.id == "range":
406+
range_type = "python_range"
407+
elif (
408+
isinstance(func, ast.Attribute)
409+
and isinstance(func.value, ast.Name)
410+
and func.value.id == "tl"
411+
):
412+
if func.attr == "range":
413+
range_type = "tl_range"
414+
elif func.attr == "static_range":
415+
range_type = "tl_static_range"
416+
417+
# _triton_viz_loop_patcher.hooks.loop_iter(range(...), lineno, range_type)
390418
new_iter = ast.Call(
391419
func=ast.Attribute(
392420
value=ast.Attribute(
@@ -397,7 +425,11 @@ def _visit_For(self, node: ast.For): # type: ignore[override]
397425
attr="loop_iter_wrapper",
398426
ctx=ast.Load(),
399427
),
400-
args=[node.iter, ast.Constant(value=node.lineno)],
428+
args=[
429+
node.iter,
430+
ast.Constant(value=node.lineno),
431+
ast.Constant(value=range_type),
432+
],
401433
keywords=[],
402434
)
403435

@@ -415,6 +447,8 @@ def patch_for_loop(loop_callbacks: ForLoopCallbacks):
415447
_loop_patcher.patch()
416448

417449
# Registering hooks
450+
if loop_callbacks.range_type_callback is not None:
451+
_loop_patcher.hooks.add_range_type_callback(loop_callbacks.range_type_callback)
418452
if loop_callbacks.before_loop_callback is not None:
419453
_loop_patcher.hooks.add_before(loop_callbacks.before_loop_callback)
420454
if loop_callbacks.loop_iter_overrider is not None:

0 commit comments

Comments
 (0)