Skip to content

Commit 7b89875

Browse files
CopilotJokeren
andauthored
Add support for atomic RMW operations (tl.atomic_add, etc.) (#200)
Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: Jokeren <[email protected]> Co-authored-by: Jokeren <[email protected]>
1 parent ca09f14 commit 7b89875

File tree

4 files changed

+63
-1
lines changed

4 files changed

+63
-1
lines changed

tests/test_sanitizer.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,43 @@ def test_copy_kernel():
351351
N,
352352
TILE_N=128,
353353
)
354+
355+
356+
# ======== Atomic Operations Tests =========
357+
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
358+
@triton.jit
359+
def atomic_add_kernel(
360+
output_ptr,
361+
value: tl.constexpr,
362+
):
363+
# Simple atomic add operation
364+
tl.atomic_add(output_ptr, value)
365+
366+
367+
def test_atomic_add():
368+
"""Test that atomic_add operations work with the sanitizer."""
369+
y = torch.zeros(1, dtype=torch.float32)
370+
grid = (1,)
371+
atomic_add_kernel[grid](y, value=5.0)
372+
# Note: The sanitizer analyzes symbolically, so the actual value may not be updated
373+
# This test verifies that the operation doesn't crash
374+
375+
376+
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
377+
@triton.jit
378+
def atomic_cas_kernel(
379+
output_ptr,
380+
cmp_value: tl.constexpr,
381+
new_value: tl.constexpr,
382+
):
383+
# Simple atomic compare-and-swap operation
384+
tl.atomic_cas(output_ptr, cmp_value, new_value)
385+
386+
387+
def test_atomic_cas():
388+
"""Test that atomic_cas operations work with the sanitizer."""
389+
y = torch.zeros(1, dtype=torch.float32)
390+
grid = (1,)
391+
atomic_cas_kernel[grid](y, cmp_value=0.0, new_value=5.0)
392+
# Note: The sanitizer analyzes symbolically, so the actual value may not be updated
393+
# This test verifies that the operation doesn't crash

triton_viz/clients/sanitizer/sanitizer.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
CumSum,
6666
Bitcast,
6767
AtomicCas,
68+
AtomicRMW,
6869
)
6970
from ..utils import (
7071
check_out_of_bounds_access,
@@ -723,7 +724,7 @@ class SymbolicExpr:
723724
POINTER_OPS = ("make_block_ptr", "addptr", "advance")
724725
BROADCAST_OPS = ("splat", "expand_dims", "broadcast", "reshape", "join")
725726
CAST_OPS = ("cast_impl", "bitcast")
726-
ATOMIC_OPS = ("atomic_cas",)
727+
ATOMIC_OPS = ("atomic_cas", "atomic_rmw")
727728
SUPPORTED_OPS = (
728729
BASIC_OPS
729730
+ INDIRECT_OPS
@@ -814,6 +815,7 @@ class SymbolicExpr:
814815
"bitcast": Spec(req=("src", "dst_type"), post=_cast_impl_post),
815816
# Atomic operations
816817
"atomic_cas": Spec(req=("ptr", "cmp", "val")),
818+
"atomic_rmw": Spec(req=("ptr", "val", "mask")),
817819
# Misc
818820
"advance": Spec(req=("ptr", "offsets")),
819821
"umulhi": Spec(req=("lhs", "rhs")),
@@ -1975,6 +1977,16 @@ def op_atomic_cas_overrider(ptr, cmp, val, sem, scope):
19751977
result.sem = sem # Store sem as an attribute
19761978
return result
19771979

1980+
def op_atomic_rmw_overrider(rmwOp, ptr, val, mask, sem, scope):
1981+
ptr_sym = SymbolicExpr.from_value(ptr)
1982+
val_sym = SymbolicExpr.from_value(val)
1983+
mask_sym = SymbolicExpr.from_value(mask)
1984+
# rmwOp and sem are enums, not regular values, so we pass them directly
1985+
result = SymbolicExpr("atomic_rmw", ptr_sym, val_sym, mask_sym)
1986+
result.rmwOp = rmwOp # Store rmwOp as an attribute
1987+
result.sem = sem # Store sem as an attribute
1988+
return result
1989+
19781990
OP_TYPE_TO_OVERRIDER: dict[type[Op], Callable] = {
19791991
ProgramId: op_program_id_overrider,
19801992
RawLoad: op_raw_load_overrider,
@@ -2010,6 +2022,7 @@ def op_atomic_cas_overrider(ptr, cmp, val, sem, scope):
20102022
CumSum: op_cumsum_overrider,
20112023
Bitcast: op_bitcast_overrider,
20122024
AtomicCas: op_atomic_cas_overrider,
2025+
AtomicRMW: op_atomic_rmw_overrider,
20132026
}
20142027

20152028
if op_type in OP_TYPE_TO_OVERRIDER:

triton_viz/core/data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,11 @@ class AtomicCas(Op):
238238
name: ClassVar[str] = "atomic_cas"
239239

240240

241+
@dataclass
242+
class AtomicRMW(Op):
243+
name: ClassVar[str] = "atomic_rmw"
244+
245+
241246
@dataclass
242247
class Tensor:
243248
ptr: int

triton_viz/core/patch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
CumSum,
4343
Bitcast,
4444
AtomicCas,
45+
AtomicRMW,
4546
)
4647
import inspect
4748
import ast
@@ -89,6 +90,7 @@
8990
CumSum,
9091
Bitcast,
9192
AtomicCas,
93+
AtomicRMW,
9294
]
9395

9496
# Hardcoded operation attribute names to avoid issues with lambda functions
@@ -123,6 +125,7 @@
123125
Trans: "create_trans",
124126
Bitcast: "create_bitcast",
125127
AtomicCas: "create_atomic_cas",
128+
AtomicRMW: "create_atomic_rmw",
126129
}
127130

128131
original_ops = {
@@ -156,6 +159,7 @@
156159
Trans: interpreter_builder.create_trans,
157160
Bitcast: interpreter_builder.create_bitcast,
158161
AtomicCas: interpreter_builder.create_atomic_cas,
162+
AtomicRMW: interpreter_builder.create_atomic_rmw,
159163
}
160164
reduce_map: dict[type[Op], Callable] = {
161165
ReduceMax: tl.max,

0 commit comments

Comments
 (0)