File tree Expand file tree Collapse file tree 2 files changed +27
-2
lines changed
torch_xla/distributed/spmd Expand file tree Collapse file tree 2 files changed +27
-2
lines changed Original file line number Diff line number Diff line change @@ -212,6 +212,23 @@ def test_spec_invalidation_on_resharding(self):
212
212
assert resharded_tensor ._spec is not initial_spec
213
213
assert resharded_tensor ._spec .placements [1 ].dim == 1
214
214
215
+ def test_auto_wrapped_tensor_spec_failure (self ):
216
+ """Test that auto-wrapped tensors fail when accessing _spec property.
217
+
218
+ Auto-wrapped tensors are created through operations that trigger __torch_dispatch__
219
+ but don't yet have access to the sharding propagation done through open xla,
220
+ causing ._spec to fail.
221
+ """
222
+ device_count = xr .global_runtime_device_count ()
223
+ mesh = DeviceMesh ("xla" , torch .arange (device_count ))
224
+ tensor = torch .randn (4 , 4 )
225
+ sharded_tensor = distribute_tensor (tensor , mesh , [Shard (0 )])
226
+
227
+ auto_wrapped = sharded_tensor + sharded_tensor
228
+
229
+ with self .assertRaises (ValueError ):
230
+ _ = auto_wrapped ._spec
231
+
215
232
216
233
if __name__ == '__main__' :
217
234
test = unittest .main ()
Original file line number Diff line number Diff line change @@ -205,7 +205,11 @@ def _spec(self):
205
205
mesh = DeviceMesh ("xla" ,
206
206
torch .tensor (device_list ).reshape (self .mesh_shape ))
207
207
else :
208
- raise ValueError ("mesh_shape must be specified to create DTensorSpec" )
208
+ raise ValueError (
209
+ "mesh_shape must be specified to create DTensorSpec. "
210
+ "If this tensor was created through torch operations, it may be auto-wrapped. "
211
+ "Use wrap_as_sharded_tensor() to set mesh_shape before accessing _spec. "
212
+ )
209
213
210
214
# use existing partition_spec
211
215
if self .partition_spec is not None :
@@ -220,7 +224,11 @@ def _spec(self):
220
224
placements .append (
221
225
Shard (tensor_dim ) if tensor_dim is not None else Replicate ())
222
226
else :
223
- raise ValueError ("partition_spec must be specified to create DTensorSpec" )
227
+ raise ValueError (
228
+ "partition_spec must be specified to create DTensorSpec. "
229
+ "If this tensor was created through torch operations, it may be auto-wrapped. "
230
+ "Use wrap_as_sharded_tensor() to set partition_spec before accessing _spec. "
231
+ )
224
232
225
233
# tensor metadata
226
234
tensor_meta = TensorMeta (
You can’t perform that action at this time.
0 commit comments