diff --git a/tests/test_profiler.py b/tests/test_profiler.py index fb8cf1ce..3ad8d271 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -8,8 +8,115 @@ from triton_viz.clients import Profiler +# ======== Case 2: Check if for loop can be unrolled ======== +@triton_viz.trace( + clients=( + loop_profiler := Profiler( + disable_buffer_load_check=True, disable_load_mask_percentage_check=True + ) + ) +) +@triton.jit +def for_loop_test_kernel( + in_ptr, + out_ptr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + Test kernel with for-loops to verify loop statistics tracking. + + This kernel contains 4 for-loops: + - First loop: Python range, iterates 10 times + - Second loop: Python range, iterates 5 times + - Third loop: tl.range, iterates 8 times + - Fourth loop: tl.static_range, iterates 6 times + """ + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + # Load input + x = tl.load(in_ptr + offsets, mask=mask, other=0.0) + + # First for-loop: Python range, 10 iterations + result = x + for i in range(10): + result = result + 1.0 + + # Second for-loop: Python range, 5 iterations + for j in range(5): + result = result * 2.0 + + # Third for-loop: tl.range, 8 iterations + for k in tl.range(2, 10): # 2, 3, 4, 5, 6, 7, 8, 9 + result = result + 0.5 + + # Fourth for-loop: tl.static_range, 6 iterations + for m in tl.static_range(0, 12, 2): # 0, 2, 4, 6, 8, 10 + result = result - 0.25 + + # Store result + tl.store(out_ptr + offsets, result, mask=mask) + + +def test_for_loop_statistics(): + """ + Test that the profiler correctly tracks for-loop statistics. + """ + N = 100 + BLOCK_SIZE = 32 + num_blocks = triton.cdiv(N, BLOCK_SIZE) + + # Create input/output tensors + x = torch.randn(N, dtype=torch.float32) + y = torch.empty(N, dtype=torch.float32) + + # Run the kernel + grid = (num_blocks,) + for_loop_test_kernel[grid](x, y, N, BLOCK_SIZE) + + # Expected loop statistics: + # The kernel has 4 for-loops, but they are executed once per grid + # Each loop should be recorded only once (when first encountered) + # Loop 1: range(10) -> 10 steps, type: python_range + # Loop 2: range(5) -> 5 steps, type: python_range + # Loop 3: tl.range(2, 10) -> 8 steps, type: tl_range + # Loop 4: tl.static_range(0, 12, 2) -> 6 steps, type: tl_static_range + + expected_num_loops = 4 + expected_loop_steps = [10, 5, 8, 6] + expected_range_types = [ + "python_range", + "python_range", + "tl_range", + "tl_static_range", + ] + + # Verify the loop statistics from profiler + assert ( + len(loop_profiler.loop_info) == expected_num_loops + ), f"Expected {expected_num_loops} loops, got {len(loop_profiler.loop_info)}" + + for idx, (lineno, loop_info) in enumerate(loop_profiler.loop_info.items()): + expected_steps = expected_loop_steps[idx] + expected_type = expected_range_types[idx] + + assert ( + loop_info.length == expected_steps + ), f"Loop #{idx+1}: Expected {expected_steps} steps, got {loop_info.length}" + assert ( + loop_info.range_type == expected_type + ), f"Loop #{idx+1}: Expected type {expected_type}, got {loop_info.range_type}" + assert isinstance( + lineno, int + ), f"Loop #{idx+1}: lineno should be int, got {type(lineno)}" + assert lineno > 0, f"Loop #{idx+1}: lineno should be positive, got {lineno}" + + # ======== Case 3: Check masked element percentage for tuning BLOCK_SIZE ======== -@triton_viz.trace(clients=(profiler := Profiler(disable_buffer_load_check=True))) +@triton_viz.trace(clients=(mask_profiler := Profiler(disable_buffer_load_check=True))) @triton.jit def mask_percentage_test_kernel( in_ptr, @@ -98,27 +205,27 @@ def test_mask_percentage(): # Verify the statistics from profiler assert ( - profiler.load_mask_total_count == expected_load_mask_total - ), f"Expected {expected_load_mask_total} total load mask elements, got {profiler.load_mask_total_count}" + mask_profiler.load_mask_total_count == expected_load_mask_total + ), f"Expected {expected_load_mask_total} total load mask elements, got {mask_profiler.load_mask_total_count}" assert ( - profiler.load_mask_false_count == expected_load_mask_false - ), f"Expected {expected_load_mask_false} false load mask elements, got {profiler.load_mask_false_count}" + mask_profiler.load_mask_false_count == expected_load_mask_false + ), f"Expected {expected_load_mask_false} false load mask elements, got {mask_profiler.load_mask_false_count}" assert ( - profiler.store_mask_total_count == expected_store_mask_total - ), f"Expected {expected_store_mask_total} total store mask elements, got {profiler.store_mask_total_count}" + mask_profiler.store_mask_total_count == expected_store_mask_total + ), f"Expected {expected_store_mask_total} total store mask elements, got {mask_profiler.store_mask_total_count}" assert ( - profiler.store_mask_false_count == expected_store_mask_false - ), f"Expected {expected_store_mask_false} false store mask elements, got {profiler.store_mask_false_count}" + mask_profiler.store_mask_false_count == expected_store_mask_false + ), f"Expected {expected_store_mask_false} false store mask elements, got {mask_profiler.store_mask_false_count}" # Verify the masked percentage calculation expected_load_masked_percentage = (56 / 384) * 100 # ~14.58% expected_store_masked_percentage = (56 / 384) * 100 # ~14.58% actual_load_masked_percentage = ( - profiler.load_mask_false_count / profiler.load_mask_total_count + mask_profiler.load_mask_false_count / mask_profiler.load_mask_total_count ) * 100 actual_store_masked_percentage = ( - profiler.store_mask_false_count / profiler.store_mask_total_count + mask_profiler.store_mask_false_count / mask_profiler.store_mask_total_count ) * 100 assert ( diff --git a/triton_viz/clients/profiler/profiler.py b/triton_viz/clients/profiler/profiler.py index 8794f58a..9b06c099 100644 --- a/triton_viz/clients/profiler/profiler.py +++ b/triton_viz/clients/profiler/profiler.py @@ -4,9 +4,16 @@ from .data import LoadStoreBytes from triton.runtime.interpreter import _get_np_dtype, TensorHandle import numpy as np +from dataclasses import dataclass, replace from typing import Callable, Optional +@dataclass(frozen=False) +class LoopInfo: + length: Optional[int] = None + range_type: str = "unknown" + + class Profiler(Client): NAME = "profiler" @@ -14,6 +21,7 @@ def __init__( self, callpath: bool = True, disable_buffer_load_check: bool = False, + disable_for_loop_unroll_check: bool = False, disable_load_mask_percentage_check: bool = False, disable_load_store_skipping: bool = False, block_sampling: bool = False, @@ -26,6 +34,10 @@ def __init__( self.store_bytes = LoadStoreBytes("store", 0, 0) self.has_buffer_load = False self.disable_buffer_load_check = disable_buffer_load_check + self.disable_for_loop_unroll_check = disable_for_loop_unroll_check + + # For-loop statistics + self.loop_info: dict[int, LoopInfo] = {} self.disable_load_mask_percentage_check = disable_load_mask_percentage_check self.disable_load_store_skipping = disable_load_store_skipping self.block_sampling = block_sampling @@ -221,9 +233,54 @@ def pre_addptr_callback(ptr, offset): return OpCallbacks() def register_for_loop_callback(self): - return ForLoopCallbacks() + def loop_hook_range_type(lineno: int, range_type: str) -> None: + cur = self.loop_info.get(lineno, LoopInfo()) + self.loop_info[lineno] = replace(cur, range_type=range_type) + + def loop_hook_before(lineno, iterable): + if self.disable_for_loop_unroll_check: + return + + if not isinstance(iterable, range): + return + + # Only record each unique loop (by line number) once + # Different blocks will execute the same loop, so we deduplicate by lineno + if self.loop_info[lineno].length is not None: + return + + # Record loop information: line number and total steps + length = len(iterable) + # Update length in LoopInfo + cur = self.loop_info.get(lineno, LoopInfo()) + self.loop_info[lineno] = replace(cur, length=length) + + def loop_hook_after(lineno: int) -> None: + # No action needed after loop for profiler + pass + + return ForLoopCallbacks( + range_type_callback=loop_hook_range_type, + before_loop_callback=loop_hook_before, + after_loop_callback=loop_hook_after, + ) def finalize(self) -> list: + # Print for-loop statistics if enabled + if not self.disable_for_loop_unroll_check and self.loop_info: + print("\n" + "=" * 60) + print("Profiler: For-Loop Statistics") + print("=" * 60) + print(f"\nTotal for-loops detected: {len(self.loop_info)}\n") + + for idx, (lineno, loop_info) in enumerate(self.loop_info.items(), 1): + print(f"Loop #{idx}:") + print(f" Line number: {lineno}") + print(f" Range type: {loop_info.range_type}") + print(f" Total steps: {loop_info.length}") + + print("=" * 60) + # Calculate and print mask statistics only if load mask percentage check is enabled if not self.disable_load_mask_percentage_check: print("\n" + "=" * 60) diff --git a/triton_viz/core/callbacks.py b/triton_viz/core/callbacks.py index bdac8a50..2db57ac1 100644 --- a/triton_viz/core/callbacks.py +++ b/triton_viz/core/callbacks.py @@ -12,6 +12,7 @@ class OpCallbacks: @dataclass class ForLoopCallbacks: + range_type_callback: Optional[Callable] = None before_loop_callback: Optional[Callable] = None loop_iter_overrider: Optional[Callable] = None loop_iter_listener: Optional[Callable] = None diff --git a/triton_viz/core/client.py b/triton_viz/core/client.py index e1391f0b..8fb00feb 100644 --- a/triton_viz/core/client.py +++ b/triton_viz/core/client.py @@ -54,7 +54,7 @@ def grid_idx_callback(self, grid_idx: tuple[int, ...]): ... @abstractmethod - def register_op_callback(self, op: type[Op]) -> OpCallbacks: + def register_op_callback(self, op_type: type[Op]) -> OpCallbacks: ... @abstractmethod diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index ae5544da..e08d3015 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -261,10 +261,12 @@ def unpatch_op(op_type: type[Op]): class _LoopIter: - def __init__(self, iterable, lineno, hooks): + def __init__(self, hooks, iterable, lineno, range_type): self._it = iter(iterable) self._lineno = lineno self._hooks = hooks + # triggering range_type + self._hooks.range_type(self._lineno, range_type) # triggering before_loop if self._hooks.before_loop: self._hooks.before_loop(self._lineno, iterable) @@ -293,12 +295,16 @@ class _CombinedLoopHooks: """ def __init__(self): + self._range_type: list[Callable] = [] self._before: list[Callable] = [] self._iter_listeners: list[Callable] = [] self._iter_overrider: Optional[Callable] = None self._after: list[Callable] = [] # Register hooks + def add_range_type_callback(self, hook: Callable) -> None: + self._range_type.append(hook) + def add_before(self, hook: Callable) -> None: self._before.append(hook) @@ -314,6 +320,10 @@ def add_after(self, hook: Callable) -> None: self._after.append(hook) # Call combined hooks + def range_type(self, lineno: int, range_type: str) -> None: + for hook in self._range_type: + hook(lineno, range_type) + def before_loop(self, lineno: int, iterable: Any) -> None: for hook in self._before: hook(lineno, iterable) @@ -335,8 +345,10 @@ def after_loop(self, lineno: int) -> None: for hook in self._after: hook(lineno) - def loop_iter_wrapper(self, iterable: Any, lineno: int) -> "_LoopIter": - return _LoopIter(iterable, lineno, self) + def loop_iter_wrapper( + self, iterable: Any, lineno: int, range_type: str + ) -> "_LoopIter": + return _LoopIter(self, iterable, lineno, range_type) def clear(self) -> None: self._before.clear() @@ -380,13 +392,29 @@ def _visit_For(self, node: ast.For): # type: ignore[override] for i in R: ... ==> - for i in _triton_viz_loop_patcher.hooks.loop_iter_wrapper(R, lineno): + for i in _triton_viz_loop_patcher.hooks.loop_iter_wrapper(R, lineno, range_type): ... where _triton_viz_loop_patcher.hooks.loop_iter_wrapper returns a _LoopIter object. """ self.generic_visit(node) - # _triton_viz_loop_patcher.hooks.loop_iter(range(...), lineno) + # Detect range type + range_type = "unknown" + if isinstance(node.iter, ast.Call): + func = node.iter.func + if isinstance(func, ast.Name) and func.id == "range": + range_type = "python_range" + elif ( + isinstance(func, ast.Attribute) + and isinstance(func.value, ast.Name) + and func.value.id == "tl" + ): + if func.attr == "range": + range_type = "tl_range" + elif func.attr == "static_range": + range_type = "tl_static_range" + + # _triton_viz_loop_patcher.hooks.loop_iter(range(...), lineno, range_type) new_iter = ast.Call( func=ast.Attribute( value=ast.Attribute( @@ -397,7 +425,11 @@ def _visit_For(self, node: ast.For): # type: ignore[override] attr="loop_iter_wrapper", ctx=ast.Load(), ), - args=[node.iter, ast.Constant(value=node.lineno)], + args=[ + node.iter, + ast.Constant(value=node.lineno), + ast.Constant(value=range_type), + ], keywords=[], ) @@ -415,6 +447,8 @@ def patch_for_loop(loop_callbacks: ForLoopCallbacks): _loop_patcher.patch() # Registering hooks + if loop_callbacks.range_type_callback is not None: + _loop_patcher.hooks.add_range_type_callback(loop_callbacks.range_type_callback) if loop_callbacks.before_loop_callback is not None: _loop_patcher.hooks.add_before(loop_callbacks.before_loop_callback) if loop_callbacks.loop_iter_overrider is not None: