
Commit bb4eb3b

Implement XLAShardedTensor._spec and test
1 parent 93a5e58 commit bb4eb3b

6 files changed: +239 additions, −8 deletions

test/neuron/run_tests.sh

Lines changed: 10 additions & 1 deletion
@@ -56,6 +56,14 @@ function run_test {
   PJRT_DEVICE=NEURON NEURON_NUM_DEVICES=1 run_coverage "$@"
 }

+function run_test_multi_device {
+  if ! test_is_selected "$1"; then
+    return
+  fi
+  echo "Running in PjRt runtime: $@"
+  PJRT_DEVICE=NEURON run_coverage "$@"
+}
+
 function run_test_without_functionalization {
   if ! test_is_selected "$1"; then
     return
@@ -246,7 +254,8 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
-  run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -254,6 +254,7 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
   run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"

test/spmd/test_xla_dtensor_spec_conversion.py

Lines changed: 140 additions & 0 deletions

@@ -0,0 +1,140 @@
+import os
+import sys
+
+import torch
+from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
+
+import torch_xla
+import torch_xla.runtime as xr
+
+import unittest
+import test_xla_sharding_base
+
+
+class XLADTensorSpecConversionTest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  def test_sample_test_case(self):
+    world_size = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", torch.arange(world_size))
+    big_tensor = torch.randn(100000, 88)
+    my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
+
+    assert my_dtensor._spec.mesh.device_type == mesh.device_type
+    assert my_dtensor._spec.placements == (Shard(0),)
+
+  def test_xla_to_dtensor_spec_conversion(self):
+    device_count = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", list(range(device_count)))
+
+    # Test different sharding patterns
+    from torch.distributed.tensor.placement_types import Replicate
+    test_cases = [
+        (torch.randn(100, 50), [Shard(0)]),
+        (torch.randn(100, 50), [Shard(1)]),
+        (torch.randn(100, 50, 25), [Shard(0)]),
+        (torch.randn(100, 50), [Replicate()]),
+    ]
+
+    for tensor, placements in test_cases:
+      xla_tensor = distribute_tensor(tensor, mesh, placements)
+      spec = xla_tensor._spec
+
+      assert spec is not None
+      assert spec.mesh.device_type == "xla"
+      assert spec.tensor_meta.shape == tensor.shape
+      assert spec.tensor_meta.dtype == tensor.dtype
+      assert len(spec.placements) >= 1
+      assert spec.placements == tuple(placements)
+
+  def test_mesh_conversion(self):
+    device_count = xr.global_runtime_device_count()
+    original_mesh = DeviceMesh("xla", list(range(device_count)))
+    tensor = torch.randn(50, 50)
+    xla_tensor = distribute_tensor(tensor, original_mesh, [Shard(0)])
+
+    converted_spec = xla_tensor._spec
+
+    assert converted_spec.mesh.device_type == "xla"
+    assert converted_spec.mesh.size() == device_count
+
+  def test_spec_caching(self):
+    """Test that _spec property caches results for better performance"""
+    import time
+    device_count = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", list(range(device_count)))
+    tensor = torch.randn(1000, 1000)  # Large tensor to make spec creation noticeable
+    xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
+
+    # first access should create and cache the spec
+    start_time = time.time()
+    spec1 = xla_tensor._spec
+    first_access_time = time.time() - start_time
+
+    # should be much faster due to caching
+    start_time = time.time()
+    spec2 = xla_tensor._spec
+    second_access_time = time.time() - start_time
+
+    assert spec1 is spec2
+    print(f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s")
+    assert second_access_time * 10 < first_access_time, \
+        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"
+
+  def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
+    """Helper to create tensor and mesh for testing"""
+    device_count = xr.global_runtime_device_count()
+    if device_count < max(mesh_shape):
+      self.skipTest(f"Need at least {max(mesh_shape)} devices, got {device_count}")
+
+    mesh = DeviceMesh("xla", torch.arange(device_count).reshape(mesh_shape))
+    tensor = torch.randn(*tensor_shape)
+    return distribute_tensor(tensor, mesh, placements), mesh
+
+  def test_multi_dim_sharding_spec(self):
+    """Test _spec for multi-dimensional sharding"""
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for 2D mesh")
+
+    mesh_shape = (2, device_count // 2)
+    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), mesh_shape, [Shard(0), Shard(1)])
+    spec = xla_tensor._spec
+
+    assert len(spec.placements) == 2
+    assert spec.mesh.ndim == 2
+
+  def test_tensor_operations_preserve_spec(self):
+    """Test that tensor operations preserve sharding metadata"""
+    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,), [Shard(0)])
+
+    result_add = xla_tensor + 1
+    result_mul = xla_tensor * 2
+    result_relu = torch.relu(xla_tensor)
+
+    for result in [result_add, result_mul, result_relu]:
+      assert hasattr(result, '_spec')
+      assert result._spec.mesh.device_type == "xla"
+
+  def test_mixed_placement_spec(self):
+    """Test _spec for tensors with mixed shard/replicate placements"""
+    from torch.distributed.tensor.placement_types import Replicate
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for 2D mesh")
+
+    mesh_shape = (2, device_count // 2)
+    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), mesh_shape, [Shard(0), Replicate()])
+    spec = xla_tensor._spec
+
+    assert len(spec.placements) == 2
+    assert isinstance(spec.placements[0], Shard)
+    assert isinstance(spec.placements[1], Replicate)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
 run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
 run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
 run_test "$_TEST_DIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
 run_test "$_TEST_DIR/test_autocast.py"

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 76 additions & 3 deletions
@@ -91,10 +91,10 @@ class XLAShardedTensor(torch.Tensor):
   # >> assert len(input.shape) == len(partition_spec)
   partition_spec: Tuple[int, None]

-  __slots__ = ['global_tensor']
+  __slots__ = ['global_tensor', 'mesh_shape', 'partition_spec', '_cached_spec']

   @staticmethod
-  def __new__(cls, elem: torch.Tensor, *args, **kwargs):
+  def __new__(cls, elem: torch.Tensor, mesh_shape=None, partition_spec=None, *args, **kwargs):
     # TODO(yeounoh) wrapper can take different arguments
     r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
         cls,
@@ -106,6 +106,11 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
         device=elem.device,
         requires_grad=kwargs.get("requires_grad", False))
     r.global_tensor = elem.detach() if r.requires_grad else elem
+    # Store mesh and partition information for DTensor compatibility
+    if mesh_shape is not None:
+      r.mesh_shape = mesh_shape
+    if partition_spec is not None:
+      r.partition_spec = partition_spec
     return r

   # Shards on the devices are materialized/available after the lazy
@@ -159,7 +164,25 @@ def unwrap(elem):
       return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem

     def wrap(elem):
-      return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem
+      if isinstance(elem, torch.Tensor) and not isinstance(elem, XLAShardedTensor):
+        # Try to get mesh/partition info from any XLAShardedTensor in args
+        mesh_shape = None
+        partition_spec = None
+
+        def find_sharded_info(x):
+          nonlocal mesh_shape, partition_spec
+          if isinstance(x, XLAShardedTensor):
+            if hasattr(x, 'mesh_shape') and x.mesh_shape:
+              mesh_shape = x.mesh_shape
+            if hasattr(x, 'partition_spec') and x.partition_spec:
+              partition_spec = x.partition_spec
+
+        tree_map(find_sharded_info, args)
+        if kwargs:
+          tree_map(find_sharded_info, kwargs)
+
+        return XLAShardedTensor(elem, mesh_shape=mesh_shape, partition_spec=partition_spec)
+      return elem

     # no_dispatch is only needed if you use enable_python_mode.
     # It prevents infinite recursion.
@@ -169,6 +192,56 @@ def wrap(elem):
           func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
     return rs

+  @property
+  def _spec(self):
+    """
+    Convert XLA sharding information to DTensorSpec for DTensor interface compatibility.
+    """
+    # Return cached spec if available
+    if hasattr(self, '_cached_spec'):
+      return self._cached_spec
+
+    from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+    from torch.distributed.device_mesh import DeviceMesh
+    from torch.distributed.tensor.placement_types import Shard, Replicate
+
+    # use existing mesh_shape
+    if hasattr(self, 'mesh_shape') and self.mesh_shape:
+      import torch_xla.runtime as xr
+      device_count = xr.global_runtime_device_count()
+      device_list = list(range(device_count))
+      mesh = DeviceMesh("xla", torch.tensor(device_list).reshape(self.mesh_shape))
+    else:
+      # default to 1D mesh
+      import torch_xla.runtime as xr
+      device_count = xr.global_runtime_device_count()
+      mesh = DeviceMesh("xla", list(range(device_count)))
+
+    # use existing partition_spec
+    if hasattr(self, 'partition_spec') and self.partition_spec:
+      placements = []
+      for mesh_dim in range(len(self.mesh_shape) if hasattr(self, 'mesh_shape') and self.mesh_shape else 1):
+        # find tensor dimension sharded on this mesh dimension
+        tensor_dim = None
+        for t_dim, m_dim in enumerate(self.partition_spec):
+          if m_dim == mesh_dim:
+            tensor_dim = t_dim
+            break
+        placements.append(Shard(tensor_dim) if tensor_dim is not None else Replicate())
+    else:
+      placements = [Replicate()]
+
+    # tensor metadata
+    tensor_meta = TensorMeta(
+        shape=self.global_tensor.shape,
+        stride=self.global_tensor.stride(),
+        dtype=self.global_tensor.dtype
+    )
+
+    # Create and cache the spec
+    self._cached_spec = DTensorSpec(mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta)
+    return self._cached_spec
+
   @classmethod
   def __torch_function__(cls, func, types, args=(), kwargs=None):
     return super().__torch_function__(func, types, args, kwargs)
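
For context, a minimal usage sketch of the new property from the torch_xla side (assuming an SPMD-enabled multi-device XLA runtime; the tensor size, 1D mesh, and variable names below are illustrative, not part of the commit):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch.distributed.tensor.placement_types import Shard

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(n)), (n,))          # 1D mesh over all devices

t = torch.randn(16, 8).to(xm.xla_device())
st = xs.mark_sharding(t, mesh, (0, None))     # shard dim 0, replicate dim 1

spec = st._spec                               # DTensorSpec built from mesh_shape/partition_spec
assert spec.mesh.device_type == "xla"
assert spec.placements == (Shard(0),)

With this commit, mark_sharding threads mesh_shape and partition_spec into the returned XLAShardedTensor, so the DTensorSpec can be constructed lazily (and cached) on first access to _spec.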

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 11 additions & 4 deletions
@@ -651,7 +651,8 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   op_sharding = mesh.get_op_sharding(partition_spec)
   annotate_func = torch_xla._XLAC._xla_mark_sharding
   annotate_func(unwrap_sharded_tensor(t), op_sharding)
-  return wrap_as_sharded_tensor(t)
+  # Pass mesh and partition spec information for DTensor compatibility
+  return wrap_as_sharded_tensor(t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec)


 def mark_sharding_with_gradients(
@@ -756,10 +757,16 @@ def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:


 def wrap_as_sharded_tensor(
-    t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor:
+    t: Union[torch.Tensor, XLAShardedTensor], mesh_shape=None, partition_spec=None) -> XLAShardedTensor:
+  # pass along mesh and partition spec information
   if not isinstance(t, XLAShardedTensor):
-    return XLAShardedTensor(t)
-  return t
+    return XLAShardedTensor(t, mesh_shape=mesh_shape, partition_spec=partition_spec)
+  else:
+    if mesh_shape is not None:
+      t.mesh_shape = mesh_shape
+    if partition_spec is not None:
+      t.partition_spec = partition_spec
+    return t


 def unwrap_sharded_tensor(
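
As a side note on the conversion rule itself: in the new _spec property, each mesh dimension becomes Shard(tensor_dim) if some entry of partition_spec assigns that tensor dimension to it, and Replicate() otherwise. The helper below is a hypothetical standalone restatement of that loop for illustration only; it is not part of this commit.

from torch.distributed.tensor.placement_types import Replicate, Shard

def placements_from_partition_spec(partition_spec, mesh_ndim):
  # Hypothetical helper mirroring the loop in XLAShardedTensor._spec:
  # for each mesh dim, shard along the tensor dim assigned to it, else replicate.
  placements = []
  for mesh_dim in range(mesh_ndim):
    tensor_dim = next(
        (t for t, m in enumerate(partition_spec) if m == mesh_dim), None)
    placements.append(Shard(tensor_dim) if tensor_dim is not None else Replicate())
  return tuple(placements)

# A rank-2 tensor sharded on dim 0 over mesh axis 0 and replicated on mesh axis 1:
print(placements_from_partition_spec((0, None), mesh_ndim=2))
# -> (Shard(dim=0), Replicate())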
