Skip to content
53 changes: 53 additions & 0 deletions tests/test_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,59 @@ def test_atomic_cas():
# This test verifies that the operation doesn't crash


# ======== Reduce Operations (max, min) =========
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the tests are not as good as the previous tests. Should be refactored

def test_reduce_max_expr_eval():
"""Test that max reduce operation evaluates correctly."""
# Test with array of values - create const SymbolicExpr with numpy array
input_arr = SymbolicExpr("const", np.array([1, 5, 3, 2]), tl.int32)
max_expr = SymbolicExpr("max", input_arr, None, False)
result, _ = max_expr.eval()
assert result == 5


def test_reduce_min_expr_eval():
"""Test that min reduce operation evaluates correctly."""
# Test with array of values - create const SymbolicExpr with numpy array
input_arr = SymbolicExpr("const", np.array([1, 5, 3, 2]), tl.int32)
min_expr = SymbolicExpr("min", input_arr, None, False)
result, _ = min_expr.eval()
assert result == 1


def test_reduce_max_single_element():
"""Test that max reduce operation works with single element."""
# Test with single element - create const SymbolicExpr with numpy array
input_arr = SymbolicExpr("const", np.array([42]), tl.int32)
max_expr = SymbolicExpr("max", input_arr, None, False)
result, _ = max_expr.eval()
assert result == 42


def test_reduce_min_single_element():
"""Test that min reduce operation works with single element."""
# Test with single element - create const SymbolicExpr with numpy array
input_arr = SymbolicExpr("const", np.array([42]), tl.int32)
min_expr = SymbolicExpr("min", input_arr, None, False)
result, _ = min_expr.eval()
assert result == 42


def test_reduce_max_empty_array():
"""Test that max reduce operation raises ValueError for empty array."""
input_arr = SymbolicExpr("const", np.array([]), tl.int32)
max_expr = SymbolicExpr("max", input_arr, None, False)
with pytest.raises(ValueError, match="Cannot compute max of empty array"):
max_expr.eval()


def test_reduce_min_empty_array():
"""Test that min reduce operation raises ValueError for empty array."""
input_arr = SymbolicExpr("const", np.array([]), tl.int32)
min_expr = SymbolicExpr("min", input_arr, None, False)
with pytest.raises(ValueError, match="Cannot compute min of empty array"):
min_expr.eval()


# ======== Cache Ablation Tests =========

# ---- Symbol Cache Tests ----
Expand Down
12 changes: 9 additions & 3 deletions triton_viz/clients/sanitizer/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import namedtuple
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import cached_property
from functools import cached_property, reduce
from typing import Any, Optional, Union
import re

Expand Down Expand Up @@ -1194,10 +1194,16 @@ def _to_z3(self) -> tuple[ArithRef, list]:
self._z3 = Sum(arr)

if self.op == "max":
raise NotImplementedError("_to_z3 of max is not implemented yet")
arr, self._constraints = self.input._to_z3()
if not arr:
raise ValueError("Cannot compute max of empty array")
self._z3 = reduce(lambda a, b: If(a >= b, a, b), arr)

if self.op == "min":
raise NotImplementedError("_to_z3 of min is not implemented yet")
arr, self._constraints = self.input._to_z3()
if not arr:
raise ValueError("Cannot compute min of empty array")
self._z3 = reduce(lambda a, b: If(a <= b, a, b), arr)
Comment on lines 1196 to +1206

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Handle scalar inputs in max/min reduction

The new implementation assumes self.input._to_z3() returns an iterable, but many expressions (e.g., a reduction over a scalar tensor or the result of a symbolic load) yield a single ArithRef. In those cases if not arr: raises TypeError: Symbolic expressions cannot be cast to bool and the subsequent reduce(...) call raises TypeError: 'ArithRef' object is not iterable, so max/min still crash. sum works for both lists and scalars because Sum handles either case, but the max/min branches now fail for valid scalar inputs. Consider wrapping non-iterables in a list or skipping the emptiness check unless the result is a sequence.

Useful? React with 👍 / 👎.


if self.op == "load" or self.op == "store":
# Load and store operations
Expand Down