diff --git a/triton_viz/clients/sanitizer/sanitizer.py b/triton_viz/clients/sanitizer/sanitizer.py index 881986ad..f5be63bb 100644 --- a/triton_viz/clients/sanitizer/sanitizer.py +++ b/triton_viz/clients/sanitizer/sanitizer.py @@ -2103,9 +2103,34 @@ def loop_hook_after(lineno: int) -> None: # add constraints for loop_i iterator_constraints: list[BoolRef] = [] if ctx.values: - iterator_constraints.append( - Or(*[ctx.idx_z3 == v for v in set(ctx.values)]) - ) + + def check_if_affine(seq): + """Check whether seq forms an arithmetic sequence; + if it does, return (True, LowerBound, step, UpperBound);, + otherwise return (False, None, None, None). + """ + n = len(seq) + if n == 0 or n == 1: + return False, None, None, None + step = seq[1] - seq[0] + for i in range(2, n): + if seq[i] - seq[i - 1] != step: + return False, None, None, None + return True, seq[0], step, seq[-1] + + is_affine, start, step, end = check_if_affine(ctx.values) + if is_affine: + iterator_constraints.append( + And( + ctx.idx_z3 >= start, + ctx.idx_z3 <= end, + (ctx.idx_z3 - start) % step == 0, + ) + ) + else: + iterator_constraints.append( + Or(*[ctx.idx_z3 == v for v in set(ctx.values)]) + ) # execute pending checks for check_tuple in ctx.pending_checks: