diff --git a/tests/test_sanitizer.py b/tests/test_sanitizer.py index 3908977..2e7eb3a 100644 --- a/tests/test_sanitizer.py +++ b/tests/test_sanitizer.py @@ -394,6 +394,59 @@ def test_atomic_cas(): # This test verifies that the operation doesn't crash +# ======== Reduce Operations (max, min) ========= +def test_reduce_max_expr_eval(): + """Test that max reduce operation evaluates correctly.""" + # Test with array of values - create const SymbolicExpr with numpy array + input_arr = SymbolicExpr("const", np.array([1, 5, 3, 2]), tl.int32) + max_expr = SymbolicExpr("max", input_arr, None, False) + result, _ = max_expr.eval() + assert result == 5 + + +def test_reduce_min_expr_eval(): + """Test that min reduce operation evaluates correctly.""" + # Test with array of values - create const SymbolicExpr with numpy array + input_arr = SymbolicExpr("const", np.array([1, 5, 3, 2]), tl.int32) + min_expr = SymbolicExpr("min", input_arr, None, False) + result, _ = min_expr.eval() + assert result == 1 + + +def test_reduce_max_single_element(): + """Test that max reduce operation works with single element.""" + # Test with single element - create const SymbolicExpr with numpy array + input_arr = SymbolicExpr("const", np.array([42]), tl.int32) + max_expr = SymbolicExpr("max", input_arr, None, False) + result, _ = max_expr.eval() + assert result == 42 + + +def test_reduce_min_single_element(): + """Test that min reduce operation works with single element.""" + # Test with single element - create const SymbolicExpr with numpy array + input_arr = SymbolicExpr("const", np.array([42]), tl.int32) + min_expr = SymbolicExpr("min", input_arr, None, False) + result, _ = min_expr.eval() + assert result == 42 + + +def test_reduce_max_empty_array(): + """Test that max reduce operation raises ValueError for empty array.""" + input_arr = SymbolicExpr("const", np.array([]), tl.int32) + max_expr = SymbolicExpr("max", input_arr, None, False) + with pytest.raises(ValueError, match="Cannot compute max of empty array"): + max_expr.eval() + + +def test_reduce_min_empty_array(): + """Test that min reduce operation raises ValueError for empty array.""" + input_arr = SymbolicExpr("const", np.array([]), tl.int32) + min_expr = SymbolicExpr("min", input_arr, None, False) + with pytest.raises(ValueError, match="Cannot compute min of empty array"): + min_expr.eval() + + # ======== Cache Ablation Tests ========= # ---- Symbol Cache Tests ---- diff --git a/triton_viz/clients/sanitizer/sanitizer.py b/triton_viz/clients/sanitizer/sanitizer.py index 118debc..ff62c5f 100644 --- a/triton_viz/clients/sanitizer/sanitizer.py +++ b/triton_viz/clients/sanitizer/sanitizer.py @@ -2,7 +2,7 @@ from collections import namedtuple from collections.abc import Callable from dataclasses import dataclass, field -from functools import cached_property +from functools import cached_property, reduce from typing import Any, Optional, Union import re @@ -1194,10 +1194,16 @@ def _to_z3(self) -> tuple[ArithRef, list]: self._z3 = Sum(arr) if self.op == "max": - raise NotImplementedError("_to_z3 of max is not implemented yet") + arr, self._constraints = self.input._to_z3() + if not arr: + raise ValueError("Cannot compute max of empty array") + self._z3 = reduce(lambda a, b: If(a >= b, a, b), arr) if self.op == "min": - raise NotImplementedError("_to_z3 of min is not implemented yet") + arr, self._constraints = self.input._to_z3() + if not arr: + raise ValueError("Cannot compute min of empty array") + self._z3 = reduce(lambda a, b: If(a <= b, a, b), arr) if self.op == "load" or self.op == "store": # Load and store operations