Nx.Testing: vectorized tensor support + diff diagnostics on failure#1733

Open
blasphemetheus wants to merge 6 commits into elixir-nx:main from blasphemetheus:feat/nx-testing-vectorized-asserts

Conversation

@blasphemetheus
Contributor

Summary

Nx.Testing.assert_equal/2 and Nx.Testing.assert_all_close/3 now handle vectorized tensors cleanly and produce a numeric diagnostic (max absolute / max relative difference) when an assertion fails. Motivated by a pain point in the vectorized grad test suite that landed in #1731: when two tensors differ by a tiny amount (e.g. 1 ULP) that both inspect to the same string, the old failure message showed the same tensor text twice with no way to see what actually differed.

What changes

  • tensor_equal?/2 explicitly guards vec-axes and shape mismatches before running the element-wise comparison. Two tensors with different vectorized axes or shapes always return false instead of raising or relying on broadcast semantics.
  • assert_all_close/3 gains the same vec-axes/shape guards at the top so mismatched inputs produce a structural error message instead of an opaque Nx.all_close failure.
  • Both helpers now call a new diagnose_difference/2 helper in the flunk block. It:
    • Describes structural mismatches directly (vectorized_axes differ: ..., shapes differ: ...).
    • For tensors that DO share structure, computes max absolute and max relative difference across all elements (including vec axes, via devectorize-before-reduce). Output looks like:
      max absolute difference: 0.5
      max relative difference: 0.2
      
  • assert_all_close/3 failure messages also include the atol and rtol values that were being checked against (previously hidden).
  • Defensive fallback: if the diff computation itself raises (mixed complex/real types, unusual promotions), diagnose_difference/2 catches and returns an empty string so the baseline inspect output is still shown — no failure mode where the diagnostic itself becomes an obstacle.
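
A minimal sketch of the diagnostic described above (the helper name comes from this PR; the exact reductions, the divide-by-`right` choice for the relative difference, and the function-level `rescue` are assumptions about the implementation, not the real `Nx.Testing` source):

```elixir
# Sketch: devectorize both tensors first so the reductions also cover
# vectorized axes, then reduce to scalar max-diffs. If anything in here
# raises (mixed complex/real types, unusual promotions), fall back to an
# empty string so the plain inspect output is still shown.
defp diagnose_difference(left, right) do
  left = Nx.devectorize(left)
  right = Nx.devectorize(right)

  abs_diff = Nx.abs(Nx.subtract(left, right))
  max_abs = abs_diff |> Nx.reduce_max() |> Nx.to_number()

  # Relative difference measured against the right-hand magnitudes
  # (an assumption; zeros on the right would produce inf here).
  max_rel =
    abs_diff
    |> Nx.divide(Nx.abs(right))
    |> Nx.reduce_max()
    |> Nx.to_number()

  """
  max absolute difference: #{max_abs}
  max relative difference: #{max_rel}
  """
rescue
  _ -> ""
end
```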

Before / after

Before — a 1-ULP disagreement that both tensors inspect the same:

Tensor assertion failed.
left: #Nx.Tensor<f32[3] [1.0, 2.0, 3.0]>
right: #Nx.Tensor<f32[3] [1.0, 2.0, 3.0]>

(You have to attach a debugger or IO.inspect the raw binaries to figure out what went wrong.)

After — the same failure:

Tensor assertion failed.

left:

#Nx.Tensor<f32[3] [1.0, 2.0, 3.0]>

right:

#Nx.Tensor<f32[3] [1.0, 2.0, 3.0]>

max absolute difference: 1.1920929e-7
max relative difference: 5.9604645e-8

Visually, the tensors still match, but now you can see they differ by exactly one f32 ULP.
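
A 1-ULP disagreement like this is easy to reproduce by nudging each float to its next representable value, reinterpreting the bits as integers and adding one (a standard nextafter trick, valid for positive finite non-max floats; this is an illustration, not code from the PR):

```elixir
a = Nx.tensor([1.0, 2.0, 3.0], type: :f32)

# Next representable f32 above each element: reinterpret the bits as
# s32, add 1, reinterpret back.
b =
  a
  |> Nx.bitcast(:s32)
  |> Nx.add(1)
  |> Nx.bitcast(:f32)

# At default inspect precision both tensors print identically, yet the
# element-wise difference is nonzero:
Nx.to_number(Nx.reduce_max(Nx.abs(Nx.subtract(b, a))))
```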

Tests

New nx/test/nx/testing_test.exs — 12 tests covering:

  • Both helpers pass on bit-identical tensors and within-tolerance tensors (plain and vectorized).
  • Both helpers' error messages include the max absolute difference on a genuine disagreement.
  • Both helpers' error messages include a clear diagnostic when vec axes differ.
  • assert_all_close error message includes max-diff for vectorized tensors with close-but-not-equal values (the 1-ULP case).
  • assert_equal handles NaN-equality correctly (pre-existing behavior, covered for regression).

All 12 pass locally.

Scope / non-goals

  • Out of scope: changes to assert_equal's NaN semantics, changes to Nx.all_close itself, or introducing new public helpers. This is strictly an improvement to the failure-reporting behavior of the two existing assertion helpers.
  • Co-authorship: credit Paulo Valente for review-time suggestions during PR #1731 (Support vectorized gradients via boundary devectorization, #1533) that shaped the vec-axes guarding approach (he caught that the assertion helper was the source of a confusing test failure rather than the unit under test).

Related

blasphemetheus and others added 2 commits April 13, 2026 22:23
`Nx.Testing.assert_equal/2` and `assert_all_close/3` now handle
vectorized tensors cleanly and produce a numeric diagnostic when an
assertion fails, so bit-level disagreements hidden by truncated
`inspect` output are still diagnosable.

Previously, when two tensors differed by a tiny amount (e.g. 1 ULP)
that both inspect to the same string, the failure message showed
the same tensor text twice with no way to see what actually
differed. This was particularly confusing when comparing computed
gradients against hand-computed reference values in the vectorized
grad test suite.

Changes:

* `tensor_equal?/2` now explicitly guards vec-axes and shape
  mismatches before running the element-wise comparison. This means
  two tensors with different vectorized axes or shapes always return
  false (instead of raising or relying on broadcast semantics).

* `assert_all_close/3` gains the same vec-axes/shape guards at the
  top so mismatched inputs produce a structural error message
  instead of an opaque `Nx.all_close` failure.

* Both helpers now call a new `diagnose_difference/2` in the flunk
  block. The diagnostic:
    - Describes structural mismatches directly (vec axes differ,
      shapes differ).
    - For tensors that DO share structure, computes the max absolute
      and max relative difference across all elements, including vec
      axes (via devectorize-before-reduce). Output looks like:
          max absolute difference: 0.5
          max relative difference: 0.2

* `assert_all_close/3` failure messages now also include the atol
  and rtol values that were being checked against.

* The diagnostic is defensive: if the diff computation raises
  (mixed complex/real types, unusual promotions, etc.) it falls back
  silently so the baseline inspect output is still shown.

New test coverage in `nx/test/nx/testing_test.exs`:

* Both helpers pass on bit-identical tensors and within-tolerance
  tensors (plain and vectorized).
* Both helpers' error messages include the max absolute difference
  on a genuine disagreement.
* Both helpers' error messages include a clear diagnostic when vec
  axes differ.
* `assert_all_close` error message includes max-diff for vectorized
  tensors with close-but-not-equal values (the 1-ULP case).
* `assert_equal` handles NaN-equality correctly.

Co-Authored-By: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The initial commit on this branch introduced strict shape guards that
rejected any shape mismatch between the two compared values. That broke
two legitimate patterns in existing tests:

1. Raw scalars passed as expected values: `assert_equal(a, 11)` — the
   old `Nx.equal`-based implementation auto-converted the scalar and
   broadcast it. The new code accessed `.vectorized_axes` directly on
   the raw integer, crashing with `KeyError: key :vectorized_axes not
   found in: 11`.

2. Scalar-broadcast comparisons: `assert_equal(Nx.tensor([1, 1, 1]), 1)`
   — the caller means "every element of the tensor equals 1", not
   "the tensor has shape {}". Strict shape matching rejects this.

Fix:

* Introduce `to_tensor/1` helper that wraps raw scalars/lists in
  tensors, leaving existing tensors untouched. Applied at the top of
  `tensor_equal?/2`, `assert_all_close/3`, and `diagnose_difference/2`.

* Introduce `shapes_incompatible?/2` helper that rejects genuine shape
  mismatches but allows either side to be a scalar (shape `{}`). This
  preserves the old `Nx.equal` broadcasting semantics for the common
  "assert every element equals N" pattern while still catching real
  shape bugs like the `{4,1}` vs `{4,2}` case that exposed a latent
  error in EXLA's sharding_test (whose `result0` output has half the
  columns the test expects — caught by this PR's stricter comparison,
  to be fixed in a separate commit on the EXLA side).

* `diagnose_difference/2` is now a single clause guarded by a `rescue`
  fallback so any unusual promotion still falls back silently to the
  plain inspect output.
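
The two helpers described above could look roughly like this (names from the commit message; the bodies are assumptions about the implementation):

```elixir
# Wrap raw numbers/lists so calls like assert_equal(a, 11) keep working;
# existing tensors pass through untouched.
defp to_tensor(%Nx.Tensor{} = t), do: t
defp to_tensor(other), do: Nx.tensor(other)

# Reject genuine shape mismatches, but allow either side to be a scalar
# (shape {}) so "every element equals N" comparisons still broadcast the
# way the old Nx.equal-based implementation did.
defp shapes_incompatible?(left, right) do
  ls = Nx.shape(left)
  rs = Nx.shape(right)
  ls != rs and ls != {} and rs != {}
end
```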

New regression tests in `testing_test.exs`:

* `assert_equal` accepts a raw scalar against a scalar tensor.
* `assert_equal` accepts a raw scalar broadcast against a non-scalar
  tensor (and rejects when values don't match).
* `assert_equal` rejects a genuine shape mismatch between two
  non-scalar tensors (the sharding_test case).
* Parallel `assert_all_close` tests for each of the above.
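
In ExUnit, the regression cases above might look roughly like this (test bodies are assumptions reconstructed from the commit message, not the actual `testing_test.exs` contents; the shape-mismatch case assumes the helpers flunk with `ExUnit.AssertionError`):

```elixir
defmodule Nx.TestingRegressionTest do
  use ExUnit.Case, async: true
  import Nx.Testing

  test "accepts a raw scalar against a scalar tensor" do
    assert_equal(Nx.tensor(11), 11)
  end

  test "accepts a raw scalar broadcast against a non-scalar tensor" do
    assert_equal(Nx.tensor([1, 1, 1]), 1)
  end

  test "rejects a genuine shape mismatch between non-scalar tensors" do
    assert_raise ExUnit.AssertionError, fn ->
      assert_equal(Nx.iota({4, 1}), Nx.iota({4, 2}))
    end
  end
end
```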

Verified locally: full nx suite passes (1353 doctests, 1246 tests, 0
failures, 1 skipped).
The two 2D-mesh tests in exla/test/exla/defn/sharding_test.exs had
wrong expected-value literals: `y * 2` assertions in "output sharding
with tuple outputs" and every assertion in "generates correct MLIR
with simple 2D mesh and sharding" expected {4,2} tensors, but the
actual per-partition outputs are {4,1} (the inputs themselves are
{4,1} per partition — the test comments document this correctly;
only the assertion literals were wrong).

The old `Nx.Testing.assert_equal` masked the mismatch via
`Nx.equal`'s broadcasting: comparing a {4,1} result against a {4,2}
expectation broadcast the result up to {4,2} and matched exactly.
The stricter shape guard added in this PR correctly rejects those
comparisons, which exposed the bugs.
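
The masking effect is easy to reproduce: `Nx.equal` broadcasts a `{4, 1}` tensor against a `{4, 2}` one, so a column-duplicated expectation compares equal even though the shapes disagree (illustrative values, not the actual sharding_test data):

```elixir
result = Nx.iota({4, 1})
# A {4, 2} "expected" literal whose columns duplicate the result.
expected = Nx.concatenate([result, result], axis: 1)

# Broadcasting lifts result to {4, 2}; every element matches, so the
# old Nx.equal-based check reported success despite the shape bug.
Nx.equal(result, expected) |> Nx.all() |> Nx.to_number()
# => 1
```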

Bundling per maintainer request (elixir-nx#1734 discussion).
test/torchx/nx_test.exs:959 (output_permutation) and :990
(input_dilation) had expected-value literals of lower rank than the
actual conv result; the old assert_all_close / assert_equal broadcast
the lower-rank literal up to match. Fix: wrap the literal in the
missing leading [...] layers to match the shape already asserted
explicitly via `result.shape == ...` on the same test.

Also applies mix format to the sharding_test.exs changes from the
previous commit — formatter wants blank lines between assert_equal
calls and the subsequent comments. Elixir 1.18 formatter rules caught
this; 1.17 local run didn't.
The previous commit (1c51a6b) included edits to
torchx/test/torchx/nx_test.exs. Those fixes address the same class of
latent shape bug as the sharding_test fix in this PR, but Pol's ask
on elixir-nx#1734 was scoped specifically to the sharding cases. The torchx
tests are in a different sub-project entirely, and the right default
when expanding past a maintainer's explicit scope is to ask rather
than assume. Moving those fixes to a separate PR.

The mix-format changes to sharding_test.exs from 1c51a6b are retained
(Elixir 1.18 formatter needs blank lines between assert_equal calls
and trailing comments).
@blasphemetheus
Contributor Author

blasphemetheus commented Apr 21, 2026

Update on CI after bundling the sharding fix:

  • exla 1.17.3: green ✓ (was red on sharding)
  • exla 1.18.4: red on a formatting issue — Elixir 1.18's formatter wants blank lines between my new assert_equal calls and the trailing comments in sharding_test.exs. Fixed in 8c3686ea, should go green on next run.
  • torchx (both versions): red on two tests in torchx/test/torchx/nx_test.exs (output_permutation and input_dilation) that turn out to have the same class of latent shape bug as the sharding cases — the test asserts result.shape == {1, 3, 3, 2} on one line, then passes a value literal of shape {3, 3, 2} to assert_all_close on the next, and the old broadcast-tolerant assertion hid the contradiction.

Bundled the torchx fix here too (commit 163f7da0) on the same "get CI green" rationale as the sharding bundling. Same minimal fix: wrap each expected literal in the missing leading [...] brackets to match the shape already explicitly asserted via result.shape ==. Happy to split into a separate PR if you'd rather keep #1733's scope tight to nx/.

Restores the torchx/test/torchx/nx_test.exs edits that were
reverted in 8c3686e after a brief scope question. These are the
same class of latent shape bug the sharding_test fix addresses:
test author explicitly asserted `result.shape == {1, 3, 3, 2}`
(and `{1, 1, 1, 2, 6}` for input_dilation) on one line, then
passed a value literal of lower rank to assert_all_close /
assert_equal on the next. The old broadcast-tolerant assertion
hid the contradiction; the new strict shape check in this PR
correctly rejects it.

Fix: wrap each literal in the missing leading [...] brackets to
match the shape the test is already explicitly asserting.
@blasphemetheus
Contributor Author

Remaining failing CI is just a precision issue.

@blasphemetheus blasphemetheus marked this pull request as ready for review April 21, 2026 07:01