
sharding_test 2D mesh assertion has wrong expected shapes (exposed by #1733) #1734

@blasphemetheus

Description


Summary

Two tests in exla/test/exla/defn/sharding_test.exs have wrong expected-value literals. The actual per-partition outputs are correct; the assertions have been silently passing because Nx.Testing.assert_equal used Nx.equal's broadcasting semantics to compare {4,1} against {4,2}.

Exposed by #1733, which added strict shape checking to assert_equal.
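The following pure-Elixir sketch models why the old check passed. It is an illustration under stated assumptions, not Nx.Testing's code. Tensors are modeled as 2D nested lists, and the module and function names (`ShapeCheckSketch`, `lenient_equal?/2`, `strict_equal?/2`) are hypothetical. The "lenient" comparison broadcasts both operands to a common shape before comparing values, which is the behavior this issue attributes to going through `Nx.equal`. The "strict" comparison rejects any shape mismatch first, as #1733 does.

```elixir
defmodule ShapeCheckSketch do
  # Hypothetical helpers for illustration only; not Nx.Testing's implementation.
  # A "tensor" here is a 2D nested list: a list of rows, each a list of columns.

  # Shape of a 2D nested list as {rows, cols}.
  def shape(rows), do: {length(rows), length(hd(rows))}

  # Broadcast a 2D list to {r, c}: an axis of size 1 is repeated,
  # any other size must already match the target.
  def broadcast(rows, {r, c}) do
    rows = if length(rows) == 1, do: List.duplicate(hd(rows), r), else: rows
    if length(rows) != r, do: raise(ArgumentError, "cannot broadcast rows")

    Enum.map(rows, fn row ->
      cond do
        length(row) == c -> row
        length(row) == 1 -> List.duplicate(hd(row), c)
        true -> raise ArgumentError, "cannot broadcast columns"
      end
    end)
  end

  # Lenient: broadcast both operands to a common shape, then compare values.
  # A {4, 1} tensor is silently widened to {4, 2} before the comparison.
  def lenient_equal?(a, b) do
    {ra, ca} = shape(a)
    {rb, cb} = shape(b)
    target = {max(ra, rb), max(ca, cb)}
    broadcast(a, target) == broadcast(b, target)
  end

  # Strict: shapes must match exactly before values are compared.
  def strict_equal?(a, b), do: shape(a) == shape(b) and a == b
end
```

With `actual = [[100], [102], [104], [106]]` and the old expectation `[[100, 100], [102, 102], [104, 104], [106, 106]]`, `lenient_equal?/2` returns `true` while `strict_equal?/2` returns `false`. That is the pattern the two failing tests below fell into.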

Failing tests

1. sharding_test.exs:138 — "generates correct MLIR with simple 2D mesh and sharding"

Test comments document the correct per-partition shape:

```elixir
# First input: shape {8, 2} sharded as [[0], [1]] -> each partition gets {4, 1}
# Second input: shape {8, 1} sharded as [[0], []] -> each partition gets {4, 1}
```
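The shape arithmetic in those comments can be checked with a small helper. This is a hypothetical sketch, not an EXLA API. It reads `[[0], [1]]` as "tensor dim 0 is split along mesh axis 0, tensor dim 1 along mesh axis 1", and it assumes a `{2, 2}` mesh. The mesh shape is inferred from the four per-partition results and the halving of both dimensions. It is not stated explicitly in the test comments.

```elixir
defmodule PartitionShapeSketch do
  # Hypothetical helper. dim_axes has one entry per tensor dimension, listing
  # the mesh axes that dimension is split along (the [[0], [1]] notation).
  # Assumes every split divides the dimension evenly.
  def per_partition_shape(global_shape, mesh_shape, dim_axes) do
    global_shape
    |> Tuple.to_list()
    |> Enum.zip(dim_axes)
    |> Enum.map(fn {size, axes} ->
      # Number of pieces this dimension is cut into (1 when axes is []).
      parts = axes |> Enum.map(&elem(mesh_shape, &1)) |> Enum.product()
      if rem(size, parts) != 0, do: raise(ArgumentError, "uneven split")
      div(size, parts)
    end)
    |> List.to_tuple()
  end
end
```

Under the assumed `{2, 2}` mesh, both inputs come out as `{4, 1}` per partition, so an element-wise `x + y` cannot produce a `{4, 2}` result.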

Per-partition x + y is {4,1} + {4,1} = {4,1}, but lines 179/181/183/185 expect {4,2} tensors in which each row's value is duplicated across two columns.

```diff
-assert_equal(result0, Nx.tensor([[100, 100], [102, 102], [104, 104], [106, 106]]))
+assert_equal(result0, Nx.tensor([[100], [102], [104], [106]]))
-assert_equal(result1, Nx.tensor([[110, 110], [112, 112], [114, 114], [116, 116]]))
+assert_equal(result1, Nx.tensor([[110], [112], [114], [116]]))
-assert_equal(result2, Nx.tensor([[108, 108], [110, 110], [112, 112], [114, 114]]))
+assert_equal(result2, Nx.tensor([[108], [110], [112], [114]]))
-assert_equal(result3, Nx.tensor([[118, 118], [120, 120], [122, 122], [124, 124]]))
+assert_equal(result3, Nx.tensor([[118], [120], [122], [124]]))
```

2. sharding_test.exs:12 — "output sharding with tuple outputs"

The y*2 assertions for result2_d0..d3 expect {4,2} tensors, but the y input is {4,1} per device (lines 36-53), so y*2 is also {4,1} per device. Comments on lines 66-68 claim "y broadcasts to {4,2}", but nothing in the sharding config triggers that broadcast.

```diff
-assert_equal(result2_d0, Nx.tensor([[200, 200], [202, 202], [204, 204], [206, 206]]))
+assert_equal(result2_d0, Nx.tensor([[200], [202], [204], [206]]))
-assert_equal(result2_d1, Nx.tensor([[200, 200], [202, 202], [204, 204], [206, 206]]))
+assert_equal(result2_d1, Nx.tensor([[200], [202], [204], [206]]))
-assert_equal(result2_d2, Nx.tensor([[208, 208], [210, 210], [212, 212], [214, 214]]))
+assert_equal(result2_d2, Nx.tensor([[208], [210], [212], [214]]))
-assert_equal(result2_d3, Nx.tensor([[208, 208], [210, 210], [212, 212], [214, 214]]))
+assert_equal(result2_d3, Nx.tensor([[208], [210], [212], [214]]))
```
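To double-check the claim that y*2 stays {4,1}, here is a hypothetical result-shape helper. It follows NumPy-style broadcasting: shapes are right-aligned, and each pair of sizes must be equal or one of them must be 1. I'm assuming this matches Nx's implicit broadcasting for these cases. I'm also assuming that the `2` in y*2 is a scalar, which is modeled as the empty shape `{}`. Only an operand that is already `{4, 2}` would widen y.

```elixir
defmodule BroadcastShapeSketch do
  # Hypothetical helper; NumPy-style rule assumed to match Nx here:
  # right-align the shapes, then each size pair must be equal or contain a 1.
  def broadcast_shape(a, b) do
    la = Tuple.to_list(a)
    lb = Tuple.to_list(b)
    n = max(length(la), length(lb))
    # Left-pad the lower-rank shape with 1s so both have rank n.
    pad = fn l -> List.duplicate(1, n - length(l)) ++ l end

    Enum.zip(pad.(la), pad.(lb))
    |> Enum.map(fn
      {x, x} -> x
      {1, y} -> y
      {x, 1} -> x
      {x, y} -> raise ArgumentError, "incompatible sizes #{x} and #{y}"
    end)
    |> List.to_tuple()
  end
end
```

`broadcast_shape({4, 1}, {})` (y times a scalar) and `broadcast_shape({4, 1}, {4, 1})` (x + y in the first test) both return `{4, 1}`. Only `broadcast_shape({4, 1}, {4, 2})` returns `{4, 2}`.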

Scope

After fixing both tests, no other EXLA tests rely on broadcast-hidden shape mismatches. The only other failure under #1733's strict check is an unrelated f64 rsqrt precision doctest in backend_test.exs:28, which is a separate upstream bug being fixed by openxla/xla#40844.
