Merged
129 changes: 118 additions & 11 deletions tests/test_profiler.py
@@ -8,8 +8,115 @@
from triton_viz.clients import Profiler


# ======== Case 2: Check if for loop can be unrolled ========
@triton_viz.trace(
    clients=(
        loop_profiler := Profiler(
            disable_buffer_load_check=True, disable_load_mask_percentage_check=True
        )
    )
)
@triton.jit
def for_loop_test_kernel(
    in_ptr,
    out_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Test kernel with for-loops to verify loop statistics tracking.

    This kernel contains 4 for-loops:
    - First loop: Python range, iterates 10 times
    - Second loop: Python range, iterates 5 times
    - Third loop: tl.range, iterates 8 times
    - Fourth loop: tl.static_range, iterates 6 times
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    # Load input
    x = tl.load(in_ptr + offsets, mask=mask, other=0.0)

    # First for-loop: Python range, 10 iterations
    result = x
    for i in range(10):
        result = result + 1.0

    # Second for-loop: Python range, 5 iterations
    for j in range(5):
Member:
How about tl.range and tl.static_range? If the range is already static, we don't need to unroll the loop.

Collaborator Author:
Done. Now the loop callback returns the range type.

Collaborator Author:
For some reason the current for-loop callbacks are not managed by ClientManager but by the _CombinedLoopHooks class. I will move that code back to ClientManager in a future refactoring.
        result = result * 2.0

    # Third for-loop: tl.range, 8 iterations
    for k in tl.range(2, 10):  # 2, 3, 4, 5, 6, 7, 8, 9
        result = result + 0.5

    # Fourth for-loop: tl.static_range, 6 iterations
    for m in tl.static_range(0, 12, 2):  # 0, 2, 4, 6, 8, 10
        result = result - 0.25

    # Store result
    tl.store(out_ptr + offsets, result, mask=mask)


def test_for_loop_statistics():
    """
    Test that the profiler correctly tracks for-loop statistics.
    """
    N = 100
    BLOCK_SIZE = 32
    num_blocks = triton.cdiv(N, BLOCK_SIZE)

    # Create input/output tensors
    x = torch.randn(N, dtype=torch.float32)
    y = torch.empty(N, dtype=torch.float32)

    # Run the kernel
    grid = (num_blocks,)
    for_loop_test_kernel[grid](x, y, N, BLOCK_SIZE)

    # Expected loop statistics:
    # The kernel has 4 for-loops; every program in the grid executes them,
    # but each loop should be recorded only once (when first encountered).
    # Loop 1: range(10) -> 10 steps, type: python_range
    # Loop 2: range(5) -> 5 steps, type: python_range
    # Loop 3: tl.range(2, 10) -> 8 steps, type: tl_range
    # Loop 4: tl.static_range(0, 12, 2) -> 6 steps, type: tl_static_range

    expected_num_loops = 4
    expected_loop_steps = [10, 5, 8, 6]
    expected_range_types = [
        "python_range",
        "python_range",
        "tl_range",
        "tl_static_range",
    ]

    # Verify the loop statistics from the profiler
    assert (
        len(loop_profiler.loop_info) == expected_num_loops
    ), f"Expected {expected_num_loops} loops, got {len(loop_profiler.loop_info)}"

    for idx, (lineno, loop_info) in enumerate(loop_profiler.loop_info.items()):
        expected_steps = expected_loop_steps[idx]
        expected_type = expected_range_types[idx]

        assert (
            loop_info.length == expected_steps
        ), f"Loop #{idx+1}: Expected {expected_steps} steps, got {loop_info.length}"
        assert (
            loop_info.range_type == expected_type
        ), f"Loop #{idx+1}: Expected type {expected_type}, got {loop_info.range_type}"
        assert isinstance(
            lineno, int
        ), f"Loop #{idx+1}: lineno should be int, got {type(lineno)}"
        assert lineno > 0, f"Loop #{idx+1}: lineno should be positive, got {lineno}"


# ======== Case 3: Check masked element percentage for tuning BLOCK_SIZE ========
@triton_viz.trace(clients=(profiler := Profiler(disable_buffer_load_check=True)))
@triton_viz.trace(clients=(mask_profiler := Profiler(disable_buffer_load_check=True)))
@triton.jit
def mask_percentage_test_kernel(
    in_ptr,
@@ -98,27 +205,27 @@ def test_mask_percentage():

    # Verify the statistics from profiler
    assert (
        profiler.load_mask_total_count == expected_load_mask_total
    ), f"Expected {expected_load_mask_total} total load mask elements, got {profiler.load_mask_total_count}"
        mask_profiler.load_mask_total_count == expected_load_mask_total
    ), f"Expected {expected_load_mask_total} total load mask elements, got {mask_profiler.load_mask_total_count}"
    assert (
        profiler.load_mask_false_count == expected_load_mask_false
    ), f"Expected {expected_load_mask_false} false load mask elements, got {profiler.load_mask_false_count}"
        mask_profiler.load_mask_false_count == expected_load_mask_false
    ), f"Expected {expected_load_mask_false} false load mask elements, got {mask_profiler.load_mask_false_count}"
    assert (
        profiler.store_mask_total_count == expected_store_mask_total
    ), f"Expected {expected_store_mask_total} total store mask elements, got {profiler.store_mask_total_count}"
        mask_profiler.store_mask_total_count == expected_store_mask_total
    ), f"Expected {expected_store_mask_total} total store mask elements, got {mask_profiler.store_mask_total_count}"
    assert (
        profiler.store_mask_false_count == expected_store_mask_false
    ), f"Expected {expected_store_mask_false} false store mask elements, got {profiler.store_mask_false_count}"
        mask_profiler.store_mask_false_count == expected_store_mask_false
    ), f"Expected {expected_store_mask_false} false store mask elements, got {mask_profiler.store_mask_false_count}"

    # Verify the masked percentage calculation
    expected_load_masked_percentage = (56 / 384) * 100  # ~14.58%
    expected_store_masked_percentage = (56 / 384) * 100  # ~14.58%

    actual_load_masked_percentage = (
        profiler.load_mask_false_count / profiler.load_mask_total_count
        mask_profiler.load_mask_false_count / mask_profiler.load_mask_total_count
    ) * 100
    actual_store_masked_percentage = (
        profiler.store_mask_false_count / profiler.store_mask_total_count
        mask_profiler.store_mask_false_count / mask_profiler.store_mask_total_count
    ) * 100

    assert (
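For orientation, here is roughly what loop_profiler.loop_info holds after the test kernel above has run. This is an illustrative sketch only: the dictionary keys are the source line numbers of the four for statements, so the exact keys depend on where the kernel sits in the file (the numbers below are hypothetical); only the lengths and range types are asserted by the test.

# Hypothetical contents of loop_profiler.loop_info after one launch:
{
    62: LoopInfo(length=10, range_type="python_range"),    # for i in range(10)
    66: LoopInfo(length=5, range_type="python_range"),     # for j in range(5)
    71: LoopInfo(length=8, range_type="tl_range"),         # for k in tl.range(2, 10)
    75: LoopInfo(length=6, range_type="tl_static_range"),  # for m in tl.static_range(0, 12, 2)
}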
59 changes: 58 additions & 1 deletion triton_viz/clients/profiler/profiler.py
@@ -4,16 +4,24 @@
from .data import LoadStoreBytes
from triton.runtime.interpreter import _get_np_dtype, TensorHandle
import numpy as np
from dataclasses import dataclass, replace
from typing import Callable, Optional


@dataclass(frozen=False)
class LoopInfo:
    length: Optional[int] = None
    range_type: str = "unknown"


class Profiler(Client):
    NAME = "profiler"

    def __init__(
        self,
        callpath: bool = True,
        disable_buffer_load_check: bool = False,
        disable_for_loop_unroll_check: bool = False,
        disable_load_mask_percentage_check: bool = False,
        disable_load_store_skipping: bool = False,
        block_sampling: bool = False,
@@ -26,6 +34,10 @@ def __init__(self):
        self.store_bytes = LoadStoreBytes("store", 0, 0)
        self.has_buffer_load = False
        self.disable_buffer_load_check = disable_buffer_load_check
        self.disable_for_loop_unroll_check = disable_for_loop_unroll_check

        # For-loop statistics
        self.loop_info: dict[int, LoopInfo] = {}
        self.disable_load_mask_percentage_check = disable_load_mask_percentage_check
        self.disable_load_store_skipping = disable_load_store_skipping
        self.block_sampling = block_sampling
@@ -221,9 +233,54 @@ def pre_addptr_callback(ptr, offset):
        return OpCallbacks()

    def register_for_loop_callback(self):
        return ForLoopCallbacks()
        def loop_hook_range_type(lineno: int, range_type: str) -> None:
            cur = self.loop_info.get(lineno, LoopInfo())
            self.loop_info[lineno] = replace(cur, range_type=range_type)

        def loop_hook_before(lineno, iterable):
            if self.disable_for_loop_unroll_check:
                return

            if not isinstance(iterable, range):
                return

            # Only record each unique loop (by line number) once
            # Different blocks will execute the same loop, so we deduplicate by lineno
            if self.loop_info[lineno].length is not None:
                return

            # Record loop information: line number and total steps
            length = len(iterable)
            # Update length in LoopInfo
            cur = self.loop_info.get(lineno, LoopInfo())
            self.loop_info[lineno] = replace(cur, length=length)

        def loop_hook_after(lineno: int) -> None:
            # No action needed after loop for profiler
            pass

        return ForLoopCallbacks(
            range_type_callback=loop_hook_range_type,
            before_loop_callback=loop_hook_before,
            after_loop_callback=loop_hook_after,
        )

    def finalize(self) -> list:
        # Print for-loop statistics if enabled
        if not self.disable_for_loop_unroll_check and self.loop_info:
            print("\n" + "=" * 60)
            print("Profiler: For-Loop Statistics")
            print("=" * 60)
            print(f"\nTotal for-loops detected: {len(self.loop_info)}\n")

            for idx, (lineno, loop_info) in enumerate(self.loop_info.items(), 1):
                print(f"Loop #{idx}:")
                print(f" Line number: {lineno}")
                print(f" Range type: {loop_info.range_type}")
                print(f" Total steps: {loop_info.length}")

            print("=" * 60)

        # Calculate and print mask statistics only if load mask percentage check is enabled
        if not self.disable_load_mask_percentage_check:
            print("\n" + "=" * 60)
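To see how these pieces fit together, here is a minimal sketch (not part of the PR) that drives the new hooks by hand, in the same order the patched for-loop does at runtime: the range-type hook fires first, then the before-loop hook records the length. The line number 97 and the plain Python range standing in for the loop's iterable are illustrative assumptions.

from triton_viz.clients import Profiler

p = Profiler(disable_buffer_load_check=True, disable_load_mask_percentage_check=True)
callbacks = p.register_for_loop_callback()

# Simulate entering `for k in tl.range(2, 10)` at a hypothetical line 97.
callbacks.range_type_callback(97, "tl_range")     # records the range type
callbacks.before_loop_callback(97, range(2, 10))  # records the length (8 steps)

print(p.loop_info[97])  # LoopInfo(length=8, range_type='tl_range')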
1 change: 1 addition & 0 deletions triton_viz/core/callbacks.py
@@ -12,6 +12,7 @@ class OpCallbacks:

@dataclass
class ForLoopCallbacks:
    range_type_callback: Optional[Callable] = None
    before_loop_callback: Optional[Callable] = None
    loop_iter_overrider: Optional[Callable] = None
    loop_iter_listener: Optional[Callable] = None
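As a quick illustration of the new field: any client can now be told the statically detected range kind of each loop it encounters. The callback below is a hypothetical example, not code from the repository.

from triton_viz.core.callbacks import ForLoopCallbacks

def log_range_type(lineno: int, range_type: str) -> None:
    # Receives "python_range", "tl_range", "tl_static_range", or "unknown".
    print(f"for-loop at line {lineno}: {range_type}")

callbacks = ForLoopCallbacks(range_type_callback=log_range_type)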
2 changes: 1 addition & 1 deletion triton_viz/core/client.py
@@ -54,7 +54,7 @@ def grid_idx_callback(self, grid_idx: tuple[int, ...]):
        ...

    @abstractmethod
    def register_op_callback(self, op: type[Op]) -> OpCallbacks:
    def register_op_callback(self, op_type: type[Op]) -> OpCallbacks:
        ...

    @abstractmethod
46 changes: 40 additions & 6 deletions triton_viz/core/patch.py
Expand Up @@ -261,10 +261,12 @@ def unpatch_op(op_type: type[Op]):


class _LoopIter:
    def __init__(self, iterable, lineno, hooks):
    def __init__(self, hooks, iterable, lineno, range_type):
        self._it = iter(iterable)
        self._lineno = lineno
        self._hooks = hooks
        # triggering range_type
        self._hooks.range_type(self._lineno, range_type)
        # triggering before_loop
        if self._hooks.before_loop:
            self._hooks.before_loop(self._lineno, iterable)
@@ -293,12 +295,16 @@ class _CombinedLoopHooks:
    """

    def __init__(self):
        self._range_type: list[Callable] = []
        self._before: list[Callable] = []
        self._iter_listeners: list[Callable] = []
        self._iter_overrider: Optional[Callable] = None
        self._after: list[Callable] = []

    # Register hooks
    def add_range_type_callback(self, hook: Callable) -> None:
        self._range_type.append(hook)

    def add_before(self, hook: Callable) -> None:
        self._before.append(hook)

@@ -314,6 +320,10 @@ def add_after(self, hook: Callable) -> None:
        self._after.append(hook)

    # Call combined hooks
    def range_type(self, lineno: int, range_type: str) -> None:
        for hook in self._range_type:
            hook(lineno, range_type)

    def before_loop(self, lineno: int, iterable: Any) -> None:
        for hook in self._before:
            hook(lineno, iterable)
@@ -335,8 +345,10 @@ def after_loop(self, lineno: int) -> None:
        for hook in self._after:
            hook(lineno)

    def loop_iter_wrapper(self, iterable: Any, lineno: int) -> "_LoopIter":
        return _LoopIter(iterable, lineno, self)
    def loop_iter_wrapper(
        self, iterable: Any, lineno: int, range_type: str
    ) -> "_LoopIter":
        return _LoopIter(self, iterable, lineno, range_type)

    def clear(self) -> None:
        self._before.clear()
@@ -380,13 +392,29 @@ def _visit_For(self, node: ast.For):  # type: ignore[override]
        for i in R:
            ...
        ==>
        for i in _triton_viz_loop_patcher.hooks.loop_iter_wrapper(R, lineno):
        for i in _triton_viz_loop_patcher.hooks.loop_iter_wrapper(R, lineno, range_type):
            ...
        where _triton_viz_loop_patcher.hooks.loop_iter_wrapper returns a _LoopIter object.
        """
        self.generic_visit(node)

        # _triton_viz_loop_patcher.hooks.loop_iter(range(...), lineno)
        # Detect range type
        range_type = "unknown"
        if isinstance(node.iter, ast.Call):
            func = node.iter.func
            if isinstance(func, ast.Name) and func.id == "range":
                range_type = "python_range"
            elif (
                isinstance(func, ast.Attribute)
                and isinstance(func.value, ast.Name)
                and func.value.id == "tl"
            ):
                if func.attr == "range":
                    range_type = "tl_range"
                elif func.attr == "static_range":
                    range_type = "tl_static_range"

        # _triton_viz_loop_patcher.hooks.loop_iter(range(...), lineno, range_type)
        new_iter = ast.Call(
            func=ast.Attribute(
                value=ast.Attribute(
@@ -397,7 +425,11 @@ def _visit_For(self, node: ast.For):  # type: ignore[override]
                attr="loop_iter_wrapper",
                ctx=ast.Load(),
            ),
            args=[node.iter, ast.Constant(value=node.lineno)],
            args=[
                node.iter,
                ast.Constant(value=node.lineno),
                ast.Constant(value=range_type),
            ],
            keywords=[],
        )

@@ -415,6 +447,8 @@ def patch_for_loop(loop_callbacks: ForLoopCallbacks):
    _loop_patcher.patch()

    # Registering hooks
    if loop_callbacks.range_type_callback is not None:
        _loop_patcher.hooks.add_range_type_callback(loop_callbacks.range_type_callback)
    if loop_callbacks.before_loop_callback is not None:
        _loop_patcher.hooks.add_before(loop_callbacks.before_loop_callback)
    if loop_callbacks.loop_iter_overrider is not None:
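Finally, a self-contained sketch of the range-kind detection that _visit_For performs above. The helper name detect_range_type is made up for illustration; it simply mirrors the isinstance checks from the PR so you can see which string each loop header maps to.

import ast

def detect_range_type(for_node: ast.For) -> str:
    # Mirrors the detection logic in _visit_For: inspect the call that
    # produces the loop's iterable and classify it by name.
    range_type = "unknown"
    if isinstance(for_node.iter, ast.Call):
        func = for_node.iter.func
        if isinstance(func, ast.Name) and func.id == "range":
            range_type = "python_range"
        elif (
            isinstance(func, ast.Attribute)
            and isinstance(func.value, ast.Name)
            and func.value.id == "tl"
        ):
            if func.attr == "range":
                range_type = "tl_range"
            elif func.attr == "static_range":
                range_type = "tl_static_range"
    return range_type

for src in (
    "for i in range(10): pass",
    "for k in tl.range(2, 10): pass",
    "for m in tl.static_range(0, 12, 2): pass",
):
    loop = ast.parse(src).body[0]
    print(detect_range_type(loop))  # python_range, tl_range, tl_static_range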