Commit ca47198

Implement XLAShardedTensor._spec and test (#9488)
1 parent 299a16b commit ca47198

7 files changed: +358 -14 lines
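
The commit adds a _spec property to XLAShardedTensor, so a tensor sharded through the DTensor API on an "xla" mesh reports a DTensor-style spec (mesh, placements, tensor metadata). A minimal usage sketch, assuming a multi-device XLA runtime and mirroring what the new tests below exercise; the tensor shape and printed values are illustrative only:

import torch
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

import torch_xla.runtime as xr

# Build a 1-D "xla" device mesh over all available devices, shard a tensor on
# dim 0, then read back the spec exposed by the new property.
world_size = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", torch.arange(world_size))
sharded = distribute_tensor(torch.randn(16, 8), mesh, [Shard(0)])

spec = sharded._spec
print(spec.mesh.device_type)   # "xla"
print(spec.placements)         # (Shard(dim=0),)
print(spec.tensor_meta.shape)  # torch.Size([16, 8])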

test/neuron/run_tests.sh

Lines changed: 10 additions & 1 deletion
@@ -56,6 +56,14 @@ function run_test {
   PJRT_DEVICE=NEURON NEURON_NUM_DEVICES=1 run_coverage "$@"
 }

+function run_test_multi_device {
+  if ! test_is_selected "$1"; then
+    return
+  fi
+  echo "Running in PjRt runtime: $@"
+  PJRT_DEVICE=NEURON run_coverage "$@"
+}
+
 function run_test_without_functionalization {
   if ! test_is_selected "$1"; then
     return
@@ -246,7 +254,8 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
-  run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -254,6 +254,7 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
   run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
test/spmd/test_xla_dtensor_spec_conversion.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
+import os
+import sys
+
+import torch
+from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
+from torch.distributed.tensor.placement_types import Replicate
+
+import torch_xla
+import torch_xla.runtime as xr
+from torch_xla.distributed.spmd import XLAShardedTensor
+from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor
+
+import unittest
+import test_xla_sharding_base
+
+
+class XLADTensorSpecConversionTest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  def test_sample_test_case(self):
+    world_size = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", torch.arange(world_size))
+    big_tensor = torch.randn(100000, 88)
+    my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
+
+    assert my_dtensor._spec.mesh.device_type == mesh.device_type
+    assert my_dtensor._spec.placements == (Shard(0),)
+
+  def test_xla_to_dtensor_spec_conversion(self):
+    device_count = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", list(range(device_count)))
+
+    # Test different sharding patterns
+    test_cases = [
+        (torch.randn(100, 50), [Shard(0)]),
+        (torch.randn(100, 50), [Shard(1)]),
+        (torch.randn(100, 50, 25), [Shard(0)]),
+        (torch.randn(100, 50), [Replicate()]),
+    ]
+
+    for tensor, placements in test_cases:
+      xla_tensor = distribute_tensor(tensor, mesh, placements)
+      spec = xla_tensor._spec
+
+      assert spec is not None
+      assert spec.mesh.device_type == "xla"
+      assert spec.tensor_meta.shape == tensor.shape
+      assert spec.tensor_meta.dtype == tensor.dtype
+      assert len(spec.placements) >= 1
+      assert spec.placements == tuple(placements)
+
+  def test_mesh_conversion(self):
+    device_count = xr.global_runtime_device_count()
+    original_mesh = DeviceMesh("xla", list(range(device_count)))
+    tensor = torch.randn(50, 50)
+    xla_tensor = distribute_tensor(tensor, original_mesh, [Shard(0)])
+
+    converted_spec = xla_tensor._spec
+
+    assert converted_spec.mesh.device_type == "xla"
+    assert converted_spec.mesh.size() == device_count
+    # assert on mesh dimensions
+    assert converted_spec.mesh.shape == original_mesh.shape
+
+  def test_spec_caching(self):
+    """Test that _spec property caches results
+    """
+    device_count = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", list(range(device_count)))
+    tensor = torch.randn(100, 100)
+    xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
+
+    spec1 = xla_tensor._spec
+
+    assert xla_tensor._cached_spec is not None
+    assert xla_tensor._cached_spec is spec1
+
+    spec2 = xla_tensor._spec
+    assert spec1 is spec2
+
+  def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
+    """Helper to create tensor and mesh for testing"""
+    device_count = xr.global_runtime_device_count()
+    if device_count < max(mesh_shape):
+      self.skipTest(
+          f"Need at least {max(mesh_shape)} devices, got {device_count}")
+
+    mesh = DeviceMesh("xla", torch.arange(device_count).reshape(mesh_shape))
+    tensor = torch.randn(*tensor_shape)
+    return distribute_tensor(tensor, mesh, placements), mesh
+
+  def test_multi_dim_sharding_spec(self):
+    """Test _spec for multi-dimensional sharding"""
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for 2D mesh")
+
+    mesh_shape = (2, device_count // 2)
+    xla_tensor, mesh = self._create_test_tensor_and_mesh(
+        (100, 50), mesh_shape, [Shard(0), Shard(1)])
+    spec = xla_tensor._spec
+
+    assert len(spec.placements) == 2
+    assert spec.mesh.ndim == 2
+
+  def test_mixed_placement_spec(self):
+    """Test _spec for tensors with mixed shard/replicate placements"""
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for 2D mesh")
+
+    mesh_shape = (2, device_count // 2)
+    xla_tensor, mesh = self._create_test_tensor_and_mesh(
+        (100, 50), mesh_shape, [Shard(0), Replicate()])
+    spec = xla_tensor._spec
+
+    assert len(spec.placements) == 2
+    assert isinstance(spec.placements[0], Shard)
+    assert isinstance(spec.placements[1], Replicate)
+
+  def test_sharding_info_acquisition(self):
+    """Test that non-XLAShardedTensor can acquire sharding information
+
+    Tests case of 'elem is not an XLAShardedTensor but there exists
+    sharding information we want to acquire'
+    """
+
+    device_count = xr.global_runtime_device_count()
+    mesh_shape = (device_count,)
+    partition_spec = (0, None)
+
+    regular_tensor = torch.randn(100, 50).to('xla')
+
+    sharded_tensor = wrap_as_sharded_tensor(
+        regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec)
+
+    # Verify the tensor acquired the sharding information
+    assert isinstance(sharded_tensor, XLAShardedTensor)
+    assert sharded_tensor.mesh_shape == mesh_shape
+    assert sharded_tensor.partition_spec == partition_spec
+
+  def test_resharding_logic(self):
+    """
+    Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t.
+    """
+
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for resharding test")
+
+    # Initial sharding
+    initial_mesh_shape = (device_count,)
+    initial_partition_spec = (0, None)
+    new_mesh_shape = (2, device_count // 2)
+    new_partition_spec = (0, 1)
+
+    # Create tensor and verify resharding
+    tensor = torch.randn(100, 50).to('xla')
+    sharded_tensor = wrap_as_sharded_tensor(
+        tensor,
+        mesh_shape=initial_mesh_shape,
+        partition_spec=initial_partition_spec)
+    initial_spec = sharded_tensor._spec
+
+    resharded_tensor = wrap_as_sharded_tensor(
+        sharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=new_partition_spec)
+
+    # Verify resharding worked and cache was invalidated
+    assert resharded_tensor.mesh_shape == new_mesh_shape
+    assert resharded_tensor.partition_spec == new_partition_spec
+    assert resharded_tensor._spec is not initial_spec
+
+  def test_spec_invalidation_on_resharding(self):
+    """Tests cases where the cached spec may become outdated.
+    """
+
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for resharding test")
+
+    tensor = torch.randn(100, 50).to('xla')
+    initial_mesh_shape = (device_count,)
+    initial_partition_spec = (0, None)
+    new_mesh_shape = (2, device_count // 2)
+    new_partition_spec = (0, 1)
+
+    sharded_tensor = wrap_as_sharded_tensor(
+        tensor,
+        mesh_shape=initial_mesh_shape,
+        partition_spec=initial_partition_spec)
+    initial_spec = sharded_tensor._spec
+    assert sharded_tensor._cached_spec is not None
+
+    # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache
+    resharded_tensor = wrap_as_sharded_tensor(
+        sharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=initial_partition_spec)
+    assert resharded_tensor._spec is not initial_spec
+    assert resharded_tensor._spec.mesh.shape == new_mesh_shape
+
+    initial_spec = resharded_tensor._spec
+    resharded_tensor = wrap_as_sharded_tensor(
+        resharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=new_partition_spec)
+    assert resharded_tensor._spec is not initial_spec
+    assert resharded_tensor._spec.placements[1].dim == 1
+
+  def test_auto_wrapped_tensor_spec_failure(self):
+    """Test that auto-wrapped tensors fail when accessing _spec property.
+
+    Auto-wrapped tensors are created through operations that trigger __torch_dispatch__
+    but don't yet have access to the sharding propagation done through open xla,
+    causing ._spec to fail.
+    """
+    device_count = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", torch.arange(device_count))
+    tensor = torch.randn(4, 4)
+    sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
+
+    auto_wrapped = sharded_tensor + sharded_tensor
+
+    with self.assertRaises(ValueError):
+      _ = auto_wrapped._spec
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
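
The xla_sharded_tensor.py change that actually implements _spec is not shown in this excerpt. Purely as a hypothetical sketch of the caching contract the tests above verify (_cached_spec is filled on the first _spec access, invalidated when the sharding metadata changes, and a ValueError is raised when no sharding information is available), with every class and method name below invented for illustration:

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToySpec:
  """Illustrative stand-in for the spec object returned by _spec."""
  mesh_shape: Tuple[int, ...]
  partition_spec: Tuple[Optional[int], ...]


class LazySpecSketch:
  """Toy wrapper demonstrating a lazily built, invalidatable _spec cache."""

  def __init__(self, mesh_shape=None, partition_spec=None):
    self.mesh_shape = mesh_shape
    self.partition_spec = partition_spec
    self._cached_spec = None  # built lazily on the first _spec access

  @property
  def _spec(self):
    if self._cached_spec is not None:
      return self._cached_spec
    if self.mesh_shape is None or self.partition_spec is None:
      # Mirrors the auto-wrapped case in the tests: no sharding info yet.
      raise ValueError("mesh_shape and partition_spec are required to build a spec")
    self._cached_spec = ToySpec(tuple(self.mesh_shape), tuple(self.partition_spec))
    return self._cached_spec

  def reshard(self, mesh_shape, partition_spec):
    # Analogous to re-wrapping via wrap_as_sharded_tensor: new metadata
    # invalidates the cached spec so the next _spec access rebuilds it.
    self.mesh_shape = mesh_shape
    self.partition_spec = partition_spec
    self._cached_spec = None
    return self


t = LazySpecSketch(mesh_shape=(4,), partition_spec=(0, None))
first = t._spec
assert t._cached_spec is first  # cached on first access
assert t.reshard((2, 2), (0, 1))._spec is not first  # cache invalidated

In the commit itself the spec is a DTensor spec built from the XLA mesh and placements; the toy class above only mirrors the cache-and-invalidate behaviour that test_spec_caching and test_spec_invalidation_on_resharding check.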

test/spmd/test_xla_sharding.py

Lines changed: 0 additions & 4 deletions
@@ -1162,10 +1162,6 @@ def test_mark_shard_scalar(self):
       self.assertIsInstance(shard.indices, type(Ellipsis))
       self.assertEqual(shard.replica_id, i)

-    # It looks like mesh_shape attribute is never implemented.
-    with self.assertRaises(AttributeError):
-      xt.mesh_shape
-
   def test_global_mesh(self):
     expected_mesh = self._get_mesh((1, self.n_devices))
     xs.set_global_mesh(expected_mesh)
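
The assertion deleted here existed because mesh_shape was previously unimplemented on XLAShardedTensor; with this commit the attribute is carried by tensors wrapped through wrap_as_sharded_tensor (see test_sharding_info_acquisition in the new test file). A small follow-up check, assuming an XLA runtime with at least one device; the tensor shape is illustrative:

import torch
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import XLAShardedTensor
from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

# mesh_shape / partition_spec mirror the values used in the new test file.
device_count = xr.global_runtime_device_count()
t = wrap_as_sharded_tensor(
    torch.randn(8, 4).to('xla'),
    mesh_shape=(device_count,),
    partition_spec=(0, None))

assert isinstance(t, XLAShardedTensor)
assert t.mesh_shape == (device_count,)  # would have raised AttributeError before
assert t.partition_spec == (0, None)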

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
 run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
 run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
 run_test "$_TEST_DIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
 run_test "$_TEST_DIR/test_autocast.py"
