From c644c514499c4a62198ce48ce6f55fd7bef4bc88 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Sat, 18 Oct 2025 00:21:16 -0400 Subject: [PATCH 1/4] [DEV][PROFILER] Case 2: Add for-loop unroll tracking to Profiler Add functionality to track for-loop iteration counts in the Profiler client to help identify loops that can potentially be unrolled. Changes: - Add disable_for_loop_unroll_check parameter to Profiler.__init__ - Implement register_for_loop_callback to track loop line numbers and step counts - Deduplicate loop tracking by line number (same loop executed in different blocks) - Display loop statistics in finalize() method - Add test_for_loop_statistics test to verify tracking functionality --- tests/test_profiler.py | 85 +++++++++++++++++++++++++ triton_viz/clients/profiler/profiler.py | 56 +++++++++++++++- 2 files changed, 139 insertions(+), 2 deletions(-) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index e69de29b..90b41b65 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -0,0 +1,85 @@ +import torch + +import triton +import triton.language as tl + +import triton_viz +from triton_viz.clients import Profiler + + +# ======== Case 2: Check if for loop can be unrolled ======== +@triton_viz.trace(clients=(profiler := Profiler(disable_buffer_load_check=True))) +@triton.jit +def for_loop_test_kernel( + in_ptr, + out_ptr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + Test kernel with for-loops to verify loop statistics tracking. + + This kernel contains 2 for-loops: + - First loop: iterates 10 times + - Second loop: iterates 5 times + """ + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + # Load input + x = tl.load(in_ptr + offsets, mask=mask, other=0.0) + + # First for-loop: 10 iterations + result = x + for i in range(10): + result = result + 1.0 + + # Second for-loop: 5 iterations + for j in range(5): + result = result * 2.0 + + # Store result + tl.store(out_ptr + offsets, result, mask=mask) + + +def test_for_loop_statistics(): + """ + Test that the profiler correctly tracks for-loop statistics. + """ + N = 100 + BLOCK_SIZE = 32 + num_blocks = triton.cdiv(N, BLOCK_SIZE) + + # Create input/output tensors + x = torch.randn(N, dtype=torch.float32) + y = torch.empty(N, dtype=torch.float32) + + # Run the kernel + grid = (num_blocks,) + for_loop_test_kernel[grid](x, y, N, BLOCK_SIZE) + + # Expected loop statistics: + # The kernel has 2 for-loops, but they are executed once per grid + # Each loop should be recorded only once (when first encountered) + # Loop 1: range(10) -> 10 steps + # Loop 2: range(5) -> 5 steps + + expected_num_loops = 2 + expected_loop_steps = [10, 5] + + # Verify the loop statistics from profiler + assert ( + len(profiler.loop_info) == expected_num_loops + ), f"Expected {expected_num_loops} loops, got {len(profiler.loop_info)}" + + for idx, (lineno, total_steps) in enumerate(profiler.loop_info): + expected_steps = expected_loop_steps[idx] + assert ( + total_steps == expected_steps + ), f"Loop #{idx+1}: Expected {expected_steps} steps, got {total_steps}" + assert isinstance( + lineno, int + ), f"Loop #{idx+1}: lineno should be int, got {type(lineno)}" + assert lineno > 0, f"Loop #{idx+1}: lineno should be positive, got {lineno}" diff --git a/triton_viz/clients/profiler/profiler.py b/triton_viz/clients/profiler/profiler.py index 2668b6ee..09fae6bf 100644 --- a/triton_viz/clients/profiler/profiler.py +++ b/triton_viz/clients/profiler/profiler.py @@ -10,7 +10,12 @@ class Profiler(Client): NAME = "profiler" - def __init__(self, callpath: bool = True, disable_buffer_load_check: bool = False): + def __init__( + self, + callpath: bool = True, + disable_buffer_load_check: bool = False, + disable_for_loop_unroll_check: bool = False, + ): super().__init__() # Initialize parent class # Enable ASM collection for the profiler self.callpath = callpath @@ -18,6 +23,15 @@ def __init__(self, callpath: bool = True, disable_buffer_load_check: bool = Fals self.store_bytes = LoadStoreBytes("store", 0, 0) self.has_buffer_load = False self.disable_buffer_load_check = disable_buffer_load_check + self.disable_for_loop_unroll_check = disable_for_loop_unroll_check + + # For-loop statistics + self.loop_info: list[ + tuple[int, int] + ] = [] # List of (lineno, total_steps) tuples + self.loop_linenos_seen: set[ + int + ] = set() # Set to track already seen line numbers def pre_run_callback(self, fn: Callable) -> bool: return True @@ -125,7 +139,45 @@ def pre_addptr_callback(ptr, offset): return OpCallbacks() def register_for_loop_callback(self): - return ForLoopCallbacks() + def loop_hook_before(lineno, iterable): + if self.disable_for_loop_unroll_check: + return + + if not isinstance(iterable, range): + return + + # Only record each unique loop (by line number) once + # Different blocks will execute the same loop, so we deduplicate by lineno + if lineno in self.loop_linenos_seen: + return + + # Record loop information: line number and total steps + length = len(iterable) + self.loop_info.append((lineno, length)) + self.loop_linenos_seen.add(lineno) + + def loop_hook_after(lineno: int) -> None: + # No action needed after loop for profiler + pass + + return ForLoopCallbacks( + before_loop_callback=loop_hook_before, + after_loop_callback=loop_hook_after, + ) def finalize(self) -> list: + # Print for-loop statistics if enabled + if not self.disable_for_loop_unroll_check and self.loop_info: + print("\n" + "=" * 60) + print("Profiler: For-Loop Statistics") + print("=" * 60) + print(f"\nTotal for-loops detected: {len(self.loop_info)}\n") + + for idx, (lineno, total_steps) in enumerate(self.loop_info, 1): + print(f"Loop #{idx}:") + print(f" Line number: {lineno}") + print(f" Total steps: {total_steps}") + + print("=" * 60 + "\n") + return [self.load_bytes, self.store_bytes] From 6ee5a5be1fa37ae21f4d3c1defda528c25d6b3b2 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Oct 2025 12:35:33 -0400 Subject: [PATCH 2/4] fix merge issues and lint --- tests/test_profiler.py | 7 +++++++ triton_viz/clients/profiler/profiler.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index b4344bc9..d94926d7 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -84,10 +84,17 @@ def test_for_loop_statistics(): ), f"Loop #{idx+1}: lineno should be int, got {type(lineno)}" assert lineno > 0, f"Loop #{idx+1}: lineno should be positive, got {lineno}" + # ======== Case 3: Check masked element percentage for tuning BLOCK_SIZE ======== @triton_viz.trace(clients=(profiler := Profiler(disable_buffer_load_check=True))) @triton.jit def mask_percentage_test_kernel( + in_ptr, + out_ptr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ Test kernel with a known mix of masked and unmasked load/store operations. Expected operations per block: diff --git a/triton_viz/clients/profiler/profiler.py b/triton_viz/clients/profiler/profiler.py index 9e37cf27..a596108b 100644 --- a/triton_viz/clients/profiler/profiler.py +++ b/triton_viz/clients/profiler/profiler.py @@ -214,7 +214,7 @@ def finalize(self) -> list: print(f"Loop #{idx}:") print(f" Line number: {lineno}") print(f" Total steps: {total_steps}") - + # Calculate and print mask statistics only if load mask percentage check is enabled if not self.disable_load_mask_percentage_check: print("\n" + "=" * 60) From 90c3a0f700f381bcfe29b865e8f0bd957ffd73f0 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Mon, 3 Nov 2025 14:22:55 -0500 Subject: [PATCH 3/4] resolve name conflicts --- tests/test_profiler.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 90162e23..74ce9267 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -9,7 +9,7 @@ # ======== Case 2: Check if for loop can be unrolled ======== -@triton_viz.trace(clients=(profiler := Profiler(disable_buffer_load_check=True))) +@triton_viz.trace(clients=(loop_profiler := Profiler(disable_buffer_load_check=True))) @triton.jit def for_loop_test_kernel( in_ptr, @@ -72,10 +72,10 @@ def test_for_loop_statistics(): # Verify the loop statistics from profiler assert ( - len(profiler.loop_info) == expected_num_loops - ), f"Expected {expected_num_loops} loops, got {len(profiler.loop_info)}" + len(loop_profiler.loop_info) == expected_num_loops + ), f"Expected {expected_num_loops} loops, got {len(loop_profiler.loop_info)}" - for idx, (lineno, total_steps) in enumerate(profiler.loop_info): + for idx, (lineno, total_steps) in enumerate(loop_profiler.loop_info): expected_steps = expected_loop_steps[idx] assert ( total_steps == expected_steps @@ -87,7 +87,7 @@ def test_for_loop_statistics(): # ======== Case 3: Check masked element percentage for tuning BLOCK_SIZE ======== -@triton_viz.trace(clients=(profiler := Profiler(disable_buffer_load_check=True))) +@triton_viz.trace(clients=(mask_profiler := Profiler(disable_buffer_load_check=True))) @triton.jit def mask_percentage_test_kernel( in_ptr, @@ -176,27 +176,27 @@ def test_mask_percentage(): # Verify the statistics from profiler assert ( - profiler.load_mask_total_count == expected_load_mask_total - ), f"Expected {expected_load_mask_total} total load mask elements, got {profiler.load_mask_total_count}" + mask_profiler.load_mask_total_count == expected_load_mask_total + ), f"Expected {expected_load_mask_total} total load mask elements, got {mask_profiler.load_mask_total_count}" assert ( - profiler.load_mask_false_count == expected_load_mask_false - ), f"Expected {expected_load_mask_false} false load mask elements, got {profiler.load_mask_false_count}" + mask_profiler.load_mask_false_count == expected_load_mask_false + ), f"Expected {expected_load_mask_false} false load mask elements, got {mask_profiler.load_mask_false_count}" assert ( - profiler.store_mask_total_count == expected_store_mask_total - ), f"Expected {expected_store_mask_total} total store mask elements, got {profiler.store_mask_total_count}" + mask_profiler.store_mask_total_count == expected_store_mask_total + ), f"Expected {expected_store_mask_total} total store mask elements, got {mask_profiler.store_mask_total_count}" assert ( - profiler.store_mask_false_count == expected_store_mask_false - ), f"Expected {expected_store_mask_false} false store mask elements, got {profiler.store_mask_false_count}" + mask_profiler.store_mask_false_count == expected_store_mask_false + ), f"Expected {expected_store_mask_false} false store mask elements, got {mask_profiler.store_mask_false_count}" # Verify the masked percentage calculation expected_load_masked_percentage = (56 / 384) * 100 # ~14.58% expected_store_masked_percentage = (56 / 384) * 100 # ~14.58% actual_load_masked_percentage = ( - profiler.load_mask_false_count / profiler.load_mask_total_count + mask_profiler.load_mask_false_count / mask_profiler.load_mask_total_count ) * 100 actual_store_masked_percentage = ( - profiler.store_mask_false_count / profiler.store_mask_total_count + mask_profiler.store_mask_false_count / mask_profiler.store_mask_total_count ) * 100 assert ( From 7d8032707b8e0ac6ab513d8d2c559207b1811f03 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Mon, 3 Nov 2025 17:55:31 -0500 Subject: [PATCH 4/4] [TEST][PROFILER] Add tests for tl.range and tl.static_range loop types - Extended for_loop_test_kernel to include tl.range and tl.static_range loops - Updated test_for_loop_statistics to verify different range types (python_range, tl_range, tl_static_range) - Modified test assertions to match new loop_info structure with range_type and length attributes - Added validation for 4 different loop types with their expected iteration counts --- tests/test_profiler.py | 59 ++++++++++++++++++------- triton_viz/clients/profiler/profiler.py | 33 +++++++++----- triton_viz/core/callbacks.py | 1 + triton_viz/core/client.py | 2 +- triton_viz/core/patch.py | 46 ++++++++++++++++--- 5 files changed, 108 insertions(+), 33 deletions(-) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 74ce9267..1d97c0fe 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -9,7 +9,13 @@ # ======== Case 2: Check if for loop can be unrolled ======== -@triton_viz.trace(clients=(loop_profiler := Profiler(disable_buffer_load_check=True))) +@triton_viz.trace( + clients=( + loop_profiler := Profiler( + disable_buffer_load_check=True, disable_load_mask_percentage_check=True + ) + ) +) @triton.jit def for_loop_test_kernel( in_ptr, @@ -20,9 +26,11 @@ def for_loop_test_kernel( """ Test kernel with for-loops to verify loop statistics tracking. - This kernel contains 2 for-loops: - - First loop: iterates 10 times - - Second loop: iterates 5 times + This kernel contains 4 for-loops: + - First loop: Python range, iterates 10 times + - Second loop: Python range, iterates 5 times + - Third loop: tl.range, iterates 8 times + - Fourth loop: tl.static_range, iterates 6 times """ pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE @@ -32,15 +40,23 @@ def for_loop_test_kernel( # Load input x = tl.load(in_ptr + offsets, mask=mask, other=0.0) - # First for-loop: 10 iterations + # First for-loop: Python range, 10 iterations result = x for i in range(10): result = result + 1.0 - # Second for-loop: 5 iterations + # Second for-loop: Python range, 5 iterations for j in range(5): result = result * 2.0 + # Third for-loop: tl.range, 8 iterations + for k in tl.range(2, 10): # 2, 3, 4, 5, 6, 7, 8, 9 + result = result + 0.5 + + # Fourth for-loop: tl.static_range, 6 iterations + for m in tl.static_range(0, 12, 2): # 0, 2, 4, 6, 8, 10 + result = result - 0.25 + # Store result tl.store(out_ptr + offsets, result, mask=mask) @@ -62,24 +78,37 @@ def test_for_loop_statistics(): for_loop_test_kernel[grid](x, y, N, BLOCK_SIZE) # Expected loop statistics: - # The kernel has 2 for-loops, but they are executed once per grid + # The kernel has 4 for-loops, but they are executed once per grid # Each loop should be recorded only once (when first encountered) - # Loop 1: range(10) -> 10 steps - # Loop 2: range(5) -> 5 steps - - expected_num_loops = 2 - expected_loop_steps = [10, 5] + # Loop 1: range(10) -> 10 steps, type: python_range + # Loop 2: range(5) -> 5 steps, type: python_range + # Loop 3: tl.range(2, 10) -> 8 steps, type: tl_range + # Loop 4: tl.static_range(0, 12, 2) -> 6 steps, type: tl_static_range + + expected_num_loops = 4 + expected_loop_steps = [10, 5, 8, 6] + expected_range_types = [ + "python_range", + "python_range", + "tl_range", + "tl_static_range", + ] # Verify the loop statistics from profiler assert ( len(loop_profiler.loop_info) == expected_num_loops ), f"Expected {expected_num_loops} loops, got {len(loop_profiler.loop_info)}" - for idx, (lineno, total_steps) in enumerate(loop_profiler.loop_info): + for idx, (lineno, loop_info) in enumerate(loop_profiler.loop_info.items()): expected_steps = expected_loop_steps[idx] + expected_type = expected_range_types[idx] + + assert ( + loop_info.length == expected_steps + ), f"Loop #{idx+1}: Expected {expected_steps} steps, got {loop_info.length}" assert ( - total_steps == expected_steps - ), f"Loop #{idx+1}: Expected {expected_steps} steps, got {total_steps}" + loop_info.range_type == expected_type + ), f"Loop #{idx+1}: Expected type {expected_type}, got {loop_info.range_type}" assert isinstance( lineno, int ), f"Loop #{idx+1}: lineno should be int, got {type(lineno)}" diff --git a/triton_viz/clients/profiler/profiler.py b/triton_viz/clients/profiler/profiler.py index 7c83d8fb..297901e4 100644 --- a/triton_viz/clients/profiler/profiler.py +++ b/triton_viz/clients/profiler/profiler.py @@ -4,9 +4,16 @@ from .data import LoadStoreBytes from triton.runtime.interpreter import _get_np_dtype, TensorHandle import numpy as np +from dataclasses import dataclass, replace from typing import Callable, Optional +@dataclass(frozen=False) +class LoopInfo: + length: Optional[int] = None + range_type: str = "unknown" + + class Profiler(Client): NAME = "profiler" @@ -29,12 +36,7 @@ def __init__( self.disable_for_loop_unroll_check = disable_for_loop_unroll_check # For-loop statistics - self.loop_info: list[ - tuple[int, int] - ] = [] # List of (lineno, total_steps) tuples - self.loop_linenos_seen: set[ - int - ] = set() # Set to track already seen line numbers + self.loop_info: dict[int, LoopInfo] = {} self.disable_load_mask_percentage_check = disable_load_mask_percentage_check self.block_sampling = block_sampling self.k = k @@ -207,6 +209,10 @@ def pre_addptr_callback(ptr, offset): return OpCallbacks() def register_for_loop_callback(self): + def loop_hook_range_type(lineno: int, range_type: str) -> None: + cur = self.loop_info.get(lineno, LoopInfo()) + self.loop_info[lineno] = replace(cur, range_type=range_type) + def loop_hook_before(lineno, iterable): if self.disable_for_loop_unroll_check: return @@ -216,19 +222,21 @@ def loop_hook_before(lineno, iterable): # Only record each unique loop (by line number) once # Different blocks will execute the same loop, so we deduplicate by lineno - if lineno in self.loop_linenos_seen: + if self.loop_info[lineno].length is not None: return # Record loop information: line number and total steps length = len(iterable) - self.loop_info.append((lineno, length)) - self.loop_linenos_seen.add(lineno) + # Update length in LoopInfo + cur = self.loop_info.get(lineno, LoopInfo()) + self.loop_info[lineno] = replace(cur, length=length) def loop_hook_after(lineno: int) -> None: # No action needed after loop for profiler pass return ForLoopCallbacks( + range_type_callback=loop_hook_range_type, before_loop_callback=loop_hook_before, after_loop_callback=loop_hook_after, ) @@ -241,10 +249,13 @@ def finalize(self) -> list: print("=" * 60) print(f"\nTotal for-loops detected: {len(self.loop_info)}\n") - for idx, (lineno, total_steps) in enumerate(self.loop_info, 1): + for idx, (lineno, loop_info) in enumerate(self.loop_info.items(), 1): print(f"Loop #{idx}:") print(f" Line number: {lineno}") - print(f" Total steps: {total_steps}") + print(f" Range type: {loop_info.range_type}") + print(f" Total steps: {loop_info.length}") + + print("=" * 60) # Calculate and print mask statistics only if load mask percentage check is enabled if not self.disable_load_mask_percentage_check: diff --git a/triton_viz/core/callbacks.py b/triton_viz/core/callbacks.py index bdac8a50..2db57ac1 100644 --- a/triton_viz/core/callbacks.py +++ b/triton_viz/core/callbacks.py @@ -12,6 +12,7 @@ class OpCallbacks: @dataclass class ForLoopCallbacks: + range_type_callback: Optional[Callable] = None before_loop_callback: Optional[Callable] = None loop_iter_overrider: Optional[Callable] = None loop_iter_listener: Optional[Callable] = None diff --git a/triton_viz/core/client.py b/triton_viz/core/client.py index e1391f0b..8fb00feb 100644 --- a/triton_viz/core/client.py +++ b/triton_viz/core/client.py @@ -54,7 +54,7 @@ def grid_idx_callback(self, grid_idx: tuple[int, ...]): ... @abstractmethod - def register_op_callback(self, op: type[Op]) -> OpCallbacks: + def register_op_callback(self, op_type: type[Op]) -> OpCallbacks: ... @abstractmethod diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index ae5544da..e08d3015 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -261,10 +261,12 @@ def unpatch_op(op_type: type[Op]): class _LoopIter: - def __init__(self, iterable, lineno, hooks): + def __init__(self, hooks, iterable, lineno, range_type): self._it = iter(iterable) self._lineno = lineno self._hooks = hooks + # triggering range_type + self._hooks.range_type(self._lineno, range_type) # triggering before_loop if self._hooks.before_loop: self._hooks.before_loop(self._lineno, iterable) @@ -293,12 +295,16 @@ class _CombinedLoopHooks: """ def __init__(self): + self._range_type: list[Callable] = [] self._before: list[Callable] = [] self._iter_listeners: list[Callable] = [] self._iter_overrider: Optional[Callable] = None self._after: list[Callable] = [] # Register hooks + def add_range_type_callback(self, hook: Callable) -> None: + self._range_type.append(hook) + def add_before(self, hook: Callable) -> None: self._before.append(hook) @@ -314,6 +320,10 @@ def add_after(self, hook: Callable) -> None: self._after.append(hook) # Call combined hooks + def range_type(self, lineno: int, range_type: str) -> None: + for hook in self._range_type: + hook(lineno, range_type) + def before_loop(self, lineno: int, iterable: Any) -> None: for hook in self._before: hook(lineno, iterable) @@ -335,8 +345,10 @@ def after_loop(self, lineno: int) -> None: for hook in self._after: hook(lineno) - def loop_iter_wrapper(self, iterable: Any, lineno: int) -> "_LoopIter": - return _LoopIter(iterable, lineno, self) + def loop_iter_wrapper( + self, iterable: Any, lineno: int, range_type: str + ) -> "_LoopIter": + return _LoopIter(self, iterable, lineno, range_type) def clear(self) -> None: self._before.clear() @@ -380,13 +392,29 @@ def _visit_For(self, node: ast.For): # type: ignore[override] for i in R: ... ==> - for i in _triton_viz_loop_patcher.hooks.loop_iter_wrapper(R, lineno): + for i in _triton_viz_loop_patcher.hooks.loop_iter_wrapper(R, lineno, range_type): ... where _triton_viz_loop_patcher.hooks.loop_iter_wrapper returns a _LoopIter object. """ self.generic_visit(node) - # _triton_viz_loop_patcher.hooks.loop_iter(range(...), lineno) + # Detect range type + range_type = "unknown" + if isinstance(node.iter, ast.Call): + func = node.iter.func + if isinstance(func, ast.Name) and func.id == "range": + range_type = "python_range" + elif ( + isinstance(func, ast.Attribute) + and isinstance(func.value, ast.Name) + and func.value.id == "tl" + ): + if func.attr == "range": + range_type = "tl_range" + elif func.attr == "static_range": + range_type = "tl_static_range" + + # _triton_viz_loop_patcher.hooks.loop_iter(range(...), lineno, range_type) new_iter = ast.Call( func=ast.Attribute( value=ast.Attribute( @@ -397,7 +425,11 @@ def _visit_For(self, node: ast.For): # type: ignore[override] attr="loop_iter_wrapper", ctx=ast.Load(), ), - args=[node.iter, ast.Constant(value=node.lineno)], + args=[ + node.iter, + ast.Constant(value=node.lineno), + ast.Constant(value=range_type), + ], keywords=[], ) @@ -415,6 +447,8 @@ def patch_for_loop(loop_callbacks: ForLoopCallbacks): _loop_patcher.patch() # Registering hooks + if loop_callbacks.range_type_callback is not None: + _loop_patcher.hooks.add_range_type_callback(loop_callbacks.range_type_callback) if loop_callbacks.before_loop_callback is not None: _loop_patcher.hooks.add_before(loop_callbacks.before_loop_callback) if loop_callbacks.loop_iter_overrider is not None: