Commit 91714d4

Added test for catching thrown error in spec

1 parent 60cd542

File tree

2 files changed (+27, -2)

test/spmd/test_xla_dtensor_spec_conversion.py
Lines changed: 17 additions & 0 deletions

@@ -212,6 +212,23 @@ def test_spec_invalidation_on_resharding(self):
     assert resharded_tensor._spec is not initial_spec
     assert resharded_tensor._spec.placements[1].dim == 1
 
+  def test_auto_wrapped_tensor_spec_failure(self):
+    """Test that auto-wrapped tensors fail when accessing the _spec property.
+
+    Auto-wrapped tensors are created through operations that trigger
+    __torch_dispatch__ but don't yet have access to the sharding propagation
+    done through OpenXLA, causing ._spec to fail.
+    """
+    device_count = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", torch.arange(device_count))
+    tensor = torch.randn(4, 4)
+    sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
+
+    auto_wrapped = sharded_tensor + sharded_tensor
+
+    with self.assertRaises(ValueError):
+      _ = auto_wrapped._spec
+
 
 if __name__ == '__main__':
   test = unittest.main()

torch_xla/distributed/spmd/xla_sharded_tensor.py
Lines changed: 10 additions & 2 deletions

@@ -205,7 +205,11 @@ def _spec(self):
       mesh = DeviceMesh("xla",
                         torch.tensor(device_list).reshape(self.mesh_shape))
     else:
-      raise ValueError("mesh_shape must be specified to create DTensorSpec")
+      raise ValueError(
+          "mesh_shape must be specified to create DTensorSpec. "
+          "If this tensor was created through torch operations, it may be auto-wrapped. "
+          "Use wrap_as_sharded_tensor() to set mesh_shape before accessing _spec."
+      )
 
     # use existing partition_spec
     if self.partition_spec is not None:
@@ -220,7 +224,11 @@ def _spec(self):
         placements.append(
             Shard(tensor_dim) if tensor_dim is not None else Replicate())
     else:
-      raise ValueError("partition_spec must be specified to create DTensorSpec")
+      raise ValueError(
+          "partition_spec must be specified to create DTensorSpec. "
+          "If this tensor was created through torch operations, it may be auto-wrapped. "
+          "Use wrap_as_sharded_tensor() to set partition_spec before accessing _spec."
+      )
 
     # tensor metadata
     tensor_meta = TensorMeta(
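Both new error messages point the caller at wrap_as_sharded_tensor() as the recovery path. Below is a minimal sketch of that path, assuming wrap_as_sharded_tensor is importable from torch_xla.distributed.spmd and accepts mesh_shape and partition_spec keyword arguments; neither the import path nor the exact signature is shown in this commit.

# Sketch only: wrap_as_sharded_tensor's import path and keyword names
# (mesh_shape, partition_spec) are assumptions, not confirmed by this commit.
import torch
import torch_xla.runtime as xr
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch_xla.distributed.spmd import wrap_as_sharded_tensor

device_count = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", torch.arange(device_count))
sharded = distribute_tensor(torch.randn(4, 4), mesh, [Shard(0)])

# Elementwise ops route through __torch_dispatch__ and return an auto-wrapped
# tensor whose mesh_shape/partition_spec are unset, so ._spec raises ValueError.
auto_wrapped = sharded + sharded

# Re-attach the sharding metadata before touching _spec.
fixed = wrap_as_sharded_tensor(
    auto_wrapped, mesh_shape=(device_count,), partition_spec=(0,))
spec = fixed._spec  # succeeds once mesh_shape and partition_spec are set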
