|
8 | 8 | from triton_viz.clients import Profiler |
9 | 9 |
|
10 | 10 |
|
| 11 | +# ======== Case 2: Check if for loop can be unrolled ======== |
| 12 | +@triton_viz.trace( |
| 13 | + clients=( |
| 14 | + loop_profiler := Profiler( |
| 15 | + disable_buffer_load_check=True, disable_load_mask_percentage_check=True |
| 16 | + ) |
| 17 | + ) |
| 18 | +) |
| 19 | +@triton.jit |
| 20 | +def for_loop_test_kernel( |
| 21 | + in_ptr, |
| 22 | + out_ptr, |
| 23 | + N: tl.constexpr, |
| 24 | + BLOCK_SIZE: tl.constexpr, |
| 25 | +): |
| 26 | + """ |
| 27 | + Test kernel with for-loops to verify loop statistics tracking. |
| 28 | +
|
| 29 | + This kernel contains 4 for-loops: |
| 30 | + - First loop: Python range, iterates 10 times |
| 31 | + - Second loop: Python range, iterates 5 times |
| 32 | + - Third loop: tl.range, iterates 8 times |
| 33 | + - Fourth loop: tl.static_range, iterates 6 times |
| 34 | + """ |
| 35 | + pid = tl.program_id(axis=0) |
| 36 | + block_start = pid * BLOCK_SIZE |
| 37 | + offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| 38 | + mask = offsets < N |
| 39 | + |
| 40 | + # Load input |
| 41 | + x = tl.load(in_ptr + offsets, mask=mask, other=0.0) |
| 42 | + |
| 43 | + # First for-loop: Python range, 10 iterations |
| 44 | + result = x |
| 45 | + for i in range(10): |
| 46 | + result = result + 1.0 |
| 47 | + |
| 48 | + # Second for-loop: Python range, 5 iterations |
| 49 | + for j in range(5): |
| 50 | + result = result * 2.0 |
| 51 | + |
| 52 | + # Third for-loop: tl.range, 8 iterations |
| 53 | + for k in tl.range(2, 10): # 2, 3, 4, 5, 6, 7, 8, 9 |
| 54 | + result = result + 0.5 |
| 55 | + |
| 56 | + # Fourth for-loop: tl.static_range, 6 iterations |
| 57 | + for m in tl.static_range(0, 12, 2): # 0, 2, 4, 6, 8, 10 |
| 58 | + result = result - 0.25 |
| 59 | + |
| 60 | + # Store result |
| 61 | + tl.store(out_ptr + offsets, result, mask=mask) |
| 62 | + |
| 63 | + |
| 64 | +def test_for_loop_statistics(): |
| 65 | + """ |
| 66 | + Test that the profiler correctly tracks for-loop statistics. |
| 67 | + """ |
| 68 | + N = 100 |
| 69 | + BLOCK_SIZE = 32 |
| 70 | + num_blocks = triton.cdiv(N, BLOCK_SIZE) |
| 71 | + |
| 72 | + # Create input/output tensors |
| 73 | + x = torch.randn(N, dtype=torch.float32) |
| 74 | + y = torch.empty(N, dtype=torch.float32) |
| 75 | + |
| 76 | + # Run the kernel |
| 77 | + grid = (num_blocks,) |
| 78 | + for_loop_test_kernel[grid](x, y, N, BLOCK_SIZE) |
| 79 | + |
| 80 | + # Expected loop statistics: |
| 81 | + # The kernel has 4 for-loops, but they are executed once per grid |
| 82 | + # Each loop should be recorded only once (when first encountered) |
| 83 | + # Loop 1: range(10) -> 10 steps, type: python_range |
| 84 | + # Loop 2: range(5) -> 5 steps, type: python_range |
| 85 | + # Loop 3: tl.range(2, 10) -> 8 steps, type: tl_range |
| 86 | + # Loop 4: tl.static_range(0, 12, 2) -> 6 steps, type: tl_static_range |
| 87 | + |
| 88 | + expected_num_loops = 4 |
| 89 | + expected_loop_steps = [10, 5, 8, 6] |
| 90 | + expected_range_types = [ |
| 91 | + "python_range", |
| 92 | + "python_range", |
| 93 | + "tl_range", |
| 94 | + "tl_static_range", |
| 95 | + ] |
| 96 | + |
| 97 | + # Verify the loop statistics from profiler |
| 98 | + assert ( |
| 99 | + len(loop_profiler.loop_info) == expected_num_loops |
| 100 | + ), f"Expected {expected_num_loops} loops, got {len(loop_profiler.loop_info)}" |
| 101 | + |
| 102 | + for idx, (lineno, loop_info) in enumerate(loop_profiler.loop_info.items()): |
| 103 | + expected_steps = expected_loop_steps[idx] |
| 104 | + expected_type = expected_range_types[idx] |
| 105 | + |
| 106 | + assert ( |
| 107 | + loop_info.length == expected_steps |
| 108 | + ), f"Loop #{idx+1}: Expected {expected_steps} steps, got {loop_info.length}" |
| 109 | + assert ( |
| 110 | + loop_info.range_type == expected_type |
| 111 | + ), f"Loop #{idx+1}: Expected type {expected_type}, got {loop_info.range_type}" |
| 112 | + assert isinstance( |
| 113 | + lineno, int |
| 114 | + ), f"Loop #{idx+1}: lineno should be int, got {type(lineno)}" |
| 115 | + assert lineno > 0, f"Loop #{idx+1}: lineno should be positive, got {lineno}" |
| 116 | + |
| 117 | + |
11 | 118 | # ======== Case 3: Check masked element percentage for tuning BLOCK_SIZE ======== |
12 | | -@triton_viz.trace(clients=(profiler := Profiler(disable_buffer_load_check=True))) |
| 119 | +@triton_viz.trace(clients=(mask_profiler := Profiler(disable_buffer_load_check=True))) |
13 | 120 | @triton.jit |
14 | 121 | def mask_percentage_test_kernel( |
15 | 122 | in_ptr, |
@@ -98,27 +205,27 @@ def test_mask_percentage(): |
98 | 205 |
|
99 | 206 | # Verify the statistics from profiler |
100 | 207 | assert ( |
101 | | - profiler.load_mask_total_count == expected_load_mask_total |
102 | | - ), f"Expected {expected_load_mask_total} total load mask elements, got {profiler.load_mask_total_count}" |
| 208 | + mask_profiler.load_mask_total_count == expected_load_mask_total |
| 209 | + ), f"Expected {expected_load_mask_total} total load mask elements, got {mask_profiler.load_mask_total_count}" |
103 | 210 | assert ( |
104 | | - profiler.load_mask_false_count == expected_load_mask_false |
105 | | - ), f"Expected {expected_load_mask_false} false load mask elements, got {profiler.load_mask_false_count}" |
| 211 | + mask_profiler.load_mask_false_count == expected_load_mask_false |
| 212 | + ), f"Expected {expected_load_mask_false} false load mask elements, got {mask_profiler.load_mask_false_count}" |
106 | 213 | assert ( |
107 | | - profiler.store_mask_total_count == expected_store_mask_total |
108 | | - ), f"Expected {expected_store_mask_total} total store mask elements, got {profiler.store_mask_total_count}" |
| 214 | + mask_profiler.store_mask_total_count == expected_store_mask_total |
| 215 | + ), f"Expected {expected_store_mask_total} total store mask elements, got {mask_profiler.store_mask_total_count}" |
109 | 216 | assert ( |
110 | | - profiler.store_mask_false_count == expected_store_mask_false |
111 | | - ), f"Expected {expected_store_mask_false} false store mask elements, got {profiler.store_mask_false_count}" |
| 217 | + mask_profiler.store_mask_false_count == expected_store_mask_false |
| 218 | + ), f"Expected {expected_store_mask_false} false store mask elements, got {mask_profiler.store_mask_false_count}" |
112 | 219 |
|
113 | 220 | # Verify the masked percentage calculation |
114 | 221 | expected_load_masked_percentage = (56 / 384) * 100 # ~14.58% |
115 | 222 | expected_store_masked_percentage = (56 / 384) * 100 # ~14.58% |
116 | 223 |
|
117 | 224 | actual_load_masked_percentage = ( |
118 | | - profiler.load_mask_false_count / profiler.load_mask_total_count |
| 225 | + mask_profiler.load_mask_false_count / mask_profiler.load_mask_total_count |
119 | 226 | ) * 100 |
120 | 227 | actual_store_masked_percentage = ( |
121 | | - profiler.store_mask_false_count / profiler.store_mask_total_count |
| 228 | + mask_profiler.store_mask_false_count / mask_profiler.store_mask_total_count |
122 | 229 | ) * 100 |
123 | 230 |
|
124 | 231 | assert ( |
|
0 commit comments