Nx.Testing: vectorized tensor support + diff diagnostics on failure #1733
Open
blasphemetheus wants to merge 6 commits into elixir-nx:main
`Nx.Testing.assert_equal/2` and `assert_all_close/3` now handle
vectorized tensors cleanly and produce a numeric diagnostic when an
assertion fails, so bit-level disagreements hidden by truncated
`inspect` output are still diagnosable.
Previously, when two tensors differed by a tiny amount (e.g. 1 ULP)
but both inspected to the same string, the failure message showed
the same tensor text twice with no way to see what actually
differed. This was particularly confusing when comparing computed
gradients against hand-computed reference values in the vectorized
grad test suite.
Changes:
* `tensor_equal?/2` now explicitly guards vec-axes and shape
mismatches before running the element-wise comparison. This means
two tensors with different vectorized axes or shapes always return
false (instead of raising or relying on broadcast semantics).
* `assert_all_close/3` gains the same vec-axes/shape guards at the
top so mismatched inputs produce a structural error message
instead of an opaque `Nx.all_close` failure.
* Both helpers now call a new `diagnose_difference/2` in the flunk
block. The diagnostic:
- Describes structural mismatches directly (vec axes differ,
shapes differ).
- For tensors that DO share structure, computes the max absolute
and max relative difference across all elements, including vec
axes (via devectorize-before-reduce). Output looks like:
max absolute difference: 0.5
max relative difference: 0.2
* `assert_all_close/3` failure messages now also include the atol
and rtol values that were being checked against.
* The diagnostic is defensive: if the diff computation raises
(mixed complex/real types, unusual promotions, etc.) it falls back
silently so the baseline inspect output is still shown.
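The max-difference computation described in the list above can be sketched roughly as follows. This is a minimal illustration, not the PR's code: the module name `DiffSketch`, the function names `max_diffs/2` and `diagnose/2`, and the choice to measure the relative difference against the absolute value of the right-hand tensor are all assumptions.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule DiffSketch do
  # Hypothetical helper: returns {max_abs, max_rel} as plain numbers.
  def max_diffs(left, right) do
    # Devectorize first so the reductions also cover vectorized axes.
    l = left |> Nx.to_tensor() |> Nx.devectorize()
    r = right |> Nx.to_tensor() |> Nx.devectorize()

    abs_diff = Nx.abs(Nx.subtract(l, r))
    max_abs = abs_diff |> Nx.reduce_max() |> Nx.to_number()

    # Assumption: the relative difference is measured against |right|.
    max_rel =
      abs_diff
      |> Nx.divide(Nx.abs(r))
      |> Nx.reduce_max()
      |> Nx.to_number()

    {max_abs, max_rel}
  end

  # Hypothetical stand-in for diagnose_difference/2: it never raises and
  # returns an empty string when the diff cannot be computed.
  def diagnose(left, right) do
    {max_abs, max_rel} = max_diffs(left, right)
    "max absolute difference: #{max_abs}\nmax relative difference: #{max_rel}"
  rescue
    _ -> ""
  end
end

left = Nx.vectorize(Nx.tensor([[1.0, 2.0], [3.0, 4.0]]), :rows)
right = Nx.vectorize(Nx.tensor([[1.0, 2.5], [3.0, 4.0]]), :rows)

IO.puts(DiffSketch.diagnose(left, right))
```

Wrapping the whole body in `rescue` keeps the diagnostic optional: a failure inside it only costs the extra lines, never the original assertion message.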
New test coverage in `nx/test/nx/testing_test.exs`:
* Both helpers pass on bit-identical tensors and within-tolerance
tensors (plain and vectorized).
* Both helpers' error messages include the max absolute difference
on a genuine disagreement.
* Both helpers' error messages include a clear diagnostic when vec
axes differ.
* `assert_all_close` error message includes max-diff for vectorized
tensors with close-but-not-equal values (the 1-ULP case).
* `assert_equal` handles NaN-equality correctly.
Co-Authored-By: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The initial commit on this branch introduced strict shape guards that
rejected any shape mismatch between the two compared values. That broke
two legitimate patterns in existing tests:
1. Raw scalars passed as expected values: `assert_equal(a, 11)` — the
old `Nx.equal`-based implementation auto-converted the scalar and
broadcast it. The new code accessed `.vectorized_axes` directly on
the raw integer, crashing with `KeyError: key :vectorized_axes not
found in: 11`.
2. Scalar-broadcast comparisons: `assert_equal(Nx.tensor([1, 1, 1]), 1)`
— the caller means "every element of the tensor equals 1", not
"the tensor has shape {}". Strict shape matching rejects this.
Fix:
* Introduce `to_tensor/1` helper that wraps raw scalars/lists in
tensors, leaving existing tensors untouched. Applied at the top of
`tensor_equal?/2`, `assert_all_close/3`, and `diagnose_difference/2`.
* Introduce `shapes_incompatible?/2` helper that rejects genuine shape
mismatches but allows either side to be a scalar (shape `{}`). This
preserves the old `Nx.equal` broadcasting semantics for the common
"assert every element equals N" pattern while still catching real
shape bugs like the `{4,1}` vs `{4,2}` case that exposed a latent
error in EXLA's sharding_test (whose `result0` output has half the
columns the test expects — caught by this PR's stricter comparison,
to be fixed in a separate commit on the EXLA side).
* `diagnose_difference/2` is now a single clause guarded by a `rescue`
fallback so any unusual promotion still falls back silently to the
plain inspect output.
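The scalar exemption described in the list above can be sketched on plain shape tuples. This is only an illustration under assumptions: the module name `ShapeSketch` is hypothetical, and the PR's actual `shapes_incompatible?/2` may take tensors rather than shapes and may differ in detail.

```elixir
defmodule ShapeSketch do
  # A scalar (shape {}) broadcasts against anything, so it is never
  # treated as a mismatch on either side.
  def shapes_incompatible?({}, _other), do: false
  def shapes_incompatible?(_other, {}), do: false

  # Two non-scalar shapes must match exactly; no implicit broadcasting.
  def shapes_incompatible?(left, right), do: left != right
end

# "Every element equals N" pattern: allowed.
IO.inspect(ShapeSketch.shapes_incompatible?({3}, {}))

# Genuine mismatch such as the sharding_test case: rejected.
IO.inspect(ShapeSketch.shapes_incompatible?({4, 1}, {4, 2}))
```

The design choice here is to allow only the scalar case of broadcasting, since that is the one the existing tests relied on intentionally.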
New regression tests in `testing_test.exs`:
* `assert_equal` accepts a raw scalar against a scalar tensor.
* `assert_equal` accepts a raw scalar broadcast against a non-scalar
tensor (and rejects when values don't match).
* `assert_equal` rejects a genuine shape mismatch between two
non-scalar tensors (the sharding_test case).
* Parallel `assert_all_close` tests for each of the above.
Verified locally: full nx suite passes (1353 doctests, 1246 tests, 0
failures, 1 skipped).
The two 2D-mesh tests in exla/test/exla/defn/sharding_test.exs had
wrong expected-value literals: `y * 2` assertions in "output sharding
with tuple outputs" and every assertion in "generates correct MLIR
with simple 2D mesh and sharding" expected {4,2} tensors, but the
actual per-partition outputs are {4,1} (the inputs themselves are
{4,1} per partition — the test comments document this correctly;
only the assertion literals were wrong).
The old `Nx.Testing.assert_equal` masked the mismatch via
`Nx.equal`'s broadcasting: comparing a {4,1} result against a {4,2}
expectation broadcast the result up to {4,2} and matched exactly.
The stricter shape guard added in this PR correctly rejects those
comparisons, which exposed the bugs.
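The masking effect can be reproduced in isolation. The snippet below is a hedged sketch with made-up values (not the ones in sharding_test.exs): it compares a `{4, 1}` tensor with a `{4, 2}` tensor whose columns are identical, once through broadcasting `Nx.equal` and once with an explicit shape check added.

```elixir
Mix.install([{:nx, "~> 0.9"}])

# Illustrative values only; shapes mirror the {4, 1} vs {4, 2} case.
result = Nx.tensor([[1], [2], [3], [4]])
expected = Nx.tensor([[1, 1], [2, 2], [3, 3], [4, 4]])

# Nx.equal broadcasts {4, 1} up to {4, 2}, so every element matches.
broadcast_equal = Nx.equal(result, expected) |> Nx.all() |> Nx.to_number()

# Adding a shape comparison rejects the same pair.
strict_equal = Nx.shape(result) == Nx.shape(expected) and broadcast_equal == 1

IO.inspect(broadcast_equal, label: "broadcast comparison")
IO.inspect(strict_equal, label: "with shape check")
```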
Bundling per maintainer request (elixir-nx#1734 discussion).
test/torchx/nx_test.exs:959 (output_permutation) and :990 (input_dilation) had expected-value literals of lower rank than the actual conv result; the old assert_all_close / assert_equal broadcast the lower-rank literal up to match. Fix: wrap the literal in the missing leading [...] layers to match the shape already asserted explicitly via `result.shape == ...` on the same test. Also applies mix format to the sharding_test.exs changes from the previous commit: the formatter wants blank lines between assert_equal calls and the subsequent comments. Elixir 1.18 formatter rules caught this; the local 1.17 run didn't.
The previous commit (1c51a6b) included edits to torchx/test/torchx/nx_test.exs. Those fixes address the same class of latent shape bug as the sharding_test fix in this PR, but Pol's ask on elixir-nx#1734 was scoped specifically to the sharding cases. The torchx tests are in a different sub-project entirely, and the right default when expanding past a maintainer's explicit scope is to ask rather than assume. Moving those fixes to a separate PR. The mix-format changes to sharding_test.exs from 1c51a6b are retained (Elixir 1.18 formatter needs blank lines between assert_equal calls and trailing comments).
Update on CI after bundling the sharding fix:
Bundled the torchx fix here too (commit 163f7da0) on the same "get CI green" rationale as the sharding bundling. Same minimal shape: wrap each expected literal in the missing leading `[...]` brackets.
Restores the torchx/test/torchx/nx_test.exs edits that were reverted in 8c3686e after a brief scope question. These are the same class of latent shape bug the sharding_test fix addresses: the test author explicitly asserted `result.shape == {1, 3, 3, 2}` (and `{1, 1, 1, 2, 6}` for input_dilation) on one line, then passed a value literal of lower rank to assert_all_close / assert_equal on the next. The old broadcast-tolerant assertion hid the contradiction; the new strict shape check in this PR correctly rejects it. Fix: wrap each literal in the missing leading [...] brackets to match the shape the test is already explicitly asserting.
The remaining failing CI is just a precision issue.
Summary
`Nx.Testing.assert_equal/2` and `Nx.Testing.assert_all_close/3` now handle vectorized tensors cleanly and produce a numeric diagnostic (max absolute / max relative difference) when an assertion fails. Motivated by a pain point in the vectorized grad test suite that landed in #1731: when two tensors differ by a tiny amount (e.g. 1 ULP) but both `inspect` to the same string, the old failure message showed the same tensor text twice with no way to see what actually differed.
What changes
* `tensor_equal?/2` explicitly guards vec-axes and shape mismatches before running the element-wise comparison. Two tensors with different vectorized axes or shapes always return `false` instead of raising or relying on broadcast semantics.
* `assert_all_close/3` gains the same vec-axes/shape guards at the top so mismatched inputs produce a structural error message instead of an opaque `Nx.all_close` failure.
* Both helpers now call a new `diagnose_difference/2` helper in the flunk block. It:
  - Describes structural mismatches directly (`vectorized_axes differ: ...`, `shapes differ: ...`).
  - For tensors that share structure, computes the max absolute and max relative difference across all elements, including vec axes.
* `assert_all_close/3` failure messages also include the `atol` and `rtol` values that were being checked against (previously hidden).
* If the diff computation raises, `diagnose_difference/2` catches the error and returns an empty string so the baseline inspect output is still shown; there is no failure mode where the diagnostic itself becomes an obstacle.
Before / after
Before: with a 1-ULP disagreement where both tensors inspect the same, you have to attach a debugger or `IO.inspect` the raw binaries to figure out what went wrong.
After: for the same failure, the tensors still match visually, but now you can see they differ by exactly one `f32` ULP.
Tests
New `nx/test/nx/testing_test.exs` with 12 tests, including:
* `assert_all_close` error message includes max-diff for vectorized tensors with close-but-not-equal values (the 1-ULP case).
* `assert_equal` handles NaN-equality correctly (pre-existing behavior, covered for regression).
All 12 pass locally.
Scope / non-goals
Out of scope: `assert_equal`'s NaN semantics, changes to `Nx.all_close` itself, or introducing new public helpers. This is strictly an improvement to the failure-reporting behavior of the two existing assertion helpers.
Related