Summary
Two tests in exla/test/exla/defn/sharding_test.exs have wrong expected-value literals. The actual per-partition outputs are correct; the assertions have been silently passing because Nx.Testing.assert_equal relied on Nx.equal's broadcasting semantics, which stretched the {4,1} actual result to match the {4,2} expected literal, so every element compared equal.
Exposed by #1733, which added strict shape checking to assert_equal.
Failing tests
1. sharding_test.exs:138 — "generates correct MLIR with simple 2D mesh and sharding"
Test comments document the correct per-partition shape:
# First input: shape {8, 2} sharded as [[0], [1]] -> each partition gets {4, 1}
# Second input: shape {8, 1} sharded as [[0], []] -> each partition gets {4, 1}
Per-partition x + y is {4,1} + {4,1} = {4,1}, but lines 179/181/183/185 expect {4,2} tensors with each column duplicated.
-assert_equal(result0, Nx.tensor([[100, 100], [102, 102], [104, 104], [106, 106]]))
+assert_equal(result0, Nx.tensor([[100], [102], [104], [106]]))
-assert_equal(result1, Nx.tensor([[110, 110], [112, 112], [114, 114], [116, 116]]))
+assert_equal(result1, Nx.tensor([[110], [112], [114], [116]]))
-assert_equal(result2, Nx.tensor([[108, 108], [110, 110], [112, 112], [114, 114]]))
+assert_equal(result2, Nx.tensor([[108], [110], [112], [114]]))
-assert_equal(result3, Nx.tensor([[118, 118], [120, 120], [122, 122], [124, 124]]))
+assert_equal(result3, Nx.tensor([[118], [120], [122], [124]]))
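To see why the old {4,2} literals passed against a {4,1} result: Nx follows NumPy-style broadcasting rules, so a language-agnostic NumPy sketch reproduces the pitfall (using result0's corrected values):

```python
import numpy as np

# Correct per-partition result for result0: shape (4, 1).
actual = np.array([[100], [102], [104], [106]])
# The wrong expected literal: shape (4, 2), each column duplicated.
expected = np.array([[100, 100], [102, 102], [104, 104], [106, 106]])

# Elementwise equality with broadcasting (Nx.equal behaves the same way)
# stretches (4, 1) across the second axis, so every element matches:
eq = np.equal(actual, expected)
print(eq.shape)  # (4, 2)
print(eq.all())  # True
```

Because the duplicated columns happen to match the broadcast values, reducing the elementwise comparison to a single boolean (as the pre-#1733 assert_equal effectively did) yields True despite the shape mismatch.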
2. sharding_test.exs:12 — "output sharding with tuple outputs"
The y*2 assertions for result2_d0..d3 expect {4,2} tensors, but the y input is {4,1} per device (lines 36-53), so y*2 is {4,1} per device. Comments on lines 66-68 claim "y broadcasts to {4,2}" but nothing in the sharding config triggers that broadcast.
-assert_equal(result2_d0, Nx.tensor([[200, 200], [202, 202], [204, 204], [206, 206]]))
+assert_equal(result2_d0, Nx.tensor([[200], [202], [204], [206]]))
-assert_equal(result2_d1, Nx.tensor([[200, 200], [202, 202], [204, 204], [206, 206]]))
+assert_equal(result2_d1, Nx.tensor([[200], [202], [204], [206]]))
-assert_equal(result2_d2, Nx.tensor([[208, 208], [210, 210], [212, 212], [214, 214]]))
+assert_equal(result2_d2, Nx.tensor([[208], [210], [212], [214]]))
-assert_equal(result2_d3, Nx.tensor([[208, 208], [210, 210], [212, 212], [214, 214]]))
+assert_equal(result2_d3, Nx.tensor([[208], [210], [212], [214]]))
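A quick NumPy sketch of the shape argument (the device-0 shard values below are an assumption, back-derived from the corrected result2_d0 literal): multiplying a {4,1} shard by a scalar keeps it {4,1}; no step in the computation widens it to {4,2}.

```python
import numpy as np

# Hypothetical device-0 shard of y, back-derived from the corrected
# literal [[200], [202], [204], [206]]: a (4, 1) slice per device.
y_d0 = np.array([[100], [101], [102], [103]])

out = y_d0 * 2  # scalar multiply preserves shape
print(out.shape)  # (4, 1)
```

Broadcasting against a {4,2} operand is the only thing that could produce a {4,2} output, and the test's sharding config introduces no such operand.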
Scope
After fixing both tests, no other EXLA tests rely on broadcast-hidden shape mismatches. The only other failure under #1733's strict check is an unrelated f64 rsqrt precision doctest in backend_test.exs:28, which is a separate upstream bug being fixed by openxla/xla#40844.