Commit 28c4f6f

Implement XLAShardedTensor._spec and test

1 parent 8b95c5d commit 28c4f6f

File tree

6 files changed: +261 -9 lines changed

test/neuron/run_tests.sh

Lines changed: 10 additions & 1 deletion
@@ -56,6 +56,14 @@ function run_test {
   PJRT_DEVICE=NEURON NEURON_NUM_DEVICES=1 run_coverage "$@"
 }
 
+function run_test_multi_device {
+  if ! test_is_selected "$1"; then
+    return
+  fi
+  echo "Running in PjRt runtime: $@"
+  PJRT_DEVICE=NEURON run_coverage "$@"
+}
+
 function run_test_without_functionalization {
   if ! test_is_selected "$1"; then
     return
@@ -246,7 +254,8 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
-  run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -254,6 +254,7 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
   run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
test/spmd/test_xla_dtensor_spec_conversion.py (new file)

Lines changed: 147 additions & 0 deletions
import os
import sys

import torch
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

import torch_xla
import torch_xla.runtime as xr

import unittest
import test_xla_sharding_base


class XLADTensorSpecConversionTest(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    super().setUpClass()

  def test_sample_test_case(self):
    world_size = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", torch.arange(world_size))
    big_tensor = torch.randn(100000, 88)
    my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])

    assert my_dtensor._spec.mesh.device_type == mesh.device_type
    assert my_dtensor._spec.placements == (Shard(0),)

  def test_xla_to_dtensor_spec_conversion(self):
    device_count = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(device_count)))

    # Test different sharding patterns
    from torch.distributed.tensor.placement_types import Replicate
    test_cases = [
        (torch.randn(100, 50), [Shard(0)]),
        (torch.randn(100, 50), [Shard(1)]),
        (torch.randn(100, 50, 25), [Shard(0)]),
        (torch.randn(100, 50), [Replicate()]),
    ]

    for tensor, placements in test_cases:
      xla_tensor = distribute_tensor(tensor, mesh, placements)
      spec = xla_tensor._spec

      assert spec is not None
      assert spec.mesh.device_type == "xla"
      assert spec.tensor_meta.shape == tensor.shape
      assert spec.tensor_meta.dtype == tensor.dtype
      assert len(spec.placements) >= 1
      assert spec.placements == tuple(placements)

  def test_mesh_conversion(self):
    device_count = xr.global_runtime_device_count()
    original_mesh = DeviceMesh("xla", list(range(device_count)))
    tensor = torch.randn(50, 50)
    xla_tensor = distribute_tensor(tensor, original_mesh, [Shard(0)])

    converted_spec = xla_tensor._spec

    assert converted_spec.mesh.device_type == "xla"
    assert converted_spec.mesh.size() == device_count

  def test_spec_caching(self):
    """Test that _spec property caches results for better performance"""
    import time
    device_count = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(device_count)))
    tensor = torch.randn(1000,
                         1000)  # Large tensor to make spec creation noticeable
    xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])

    # first access should create and cache the spec
    start_time = time.time()
    spec1 = xla_tensor._spec
    first_access_time = time.time() - start_time

    # should be much faster due to caching
    start_time = time.time()
    spec2 = xla_tensor._spec
    second_access_time = time.time() - start_time

    assert spec1 is spec2
    print(
        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
    )
    assert second_access_time * 10 < first_access_time, \
        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"

  def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
    """Helper to create tensor and mesh for testing"""
    device_count = xr.global_runtime_device_count()
    if device_count < max(mesh_shape):
      self.skipTest(
          f"Need at least {max(mesh_shape)} devices, got {device_count}")

    mesh = DeviceMesh("xla", torch.arange(device_count).reshape(mesh_shape))
    tensor = torch.randn(*tensor_shape)
    return distribute_tensor(tensor, mesh, placements), mesh

  def test_multi_dim_sharding_spec(self):
    """Test _spec for multi-dimensional sharding"""
    device_count = xr.global_runtime_device_count()
    if device_count < 4:
      self.skipTest("Need at least 4 devices for 2D mesh")

    mesh_shape = (2, device_count // 2)
    xla_tensor, mesh = self._create_test_tensor_and_mesh(
        (100, 50), mesh_shape, [Shard(0), Shard(1)])
    spec = xla_tensor._spec

    assert len(spec.placements) == 2
    assert spec.mesh.ndim == 2

  def test_tensor_operations_preserve_spec(self):
    """Test that tensor operations preserve sharding metadata"""
    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
                                                         [Shard(0)])

    result_add = xla_tensor + 1
    result_mul = xla_tensor * 2
    result_relu = torch.relu(xla_tensor)

    for result in [result_add, result_mul, result_relu]:
      assert hasattr(result, '_spec')
      assert result._spec.mesh.device_type == "xla"

  def test_mixed_placement_spec(self):
    """Test _spec for tensors with mixed shard/replicate placements"""
    from torch.distributed.tensor.placement_types import Replicate
    device_count = xr.global_runtime_device_count()
    if device_count < 4:
      self.skipTest("Need at least 4 devices for 2D mesh")

    mesh_shape = (2, device_count // 2)
    xla_tensor, mesh = self._create_test_tensor_and_mesh(
        (100, 50), mesh_shape, [Shard(0), Replicate()])
    spec = xla_tensor._spec

    assert len(spec.placements) == 2
    assert isinstance(spec.placements[0], Shard)
    assert isinstance(spec.placements[1], Replicate)


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
 run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
 run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
+run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
 run_test "$_TEST_DIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
 run_test "$_TEST_DIR/test_autocast.py"

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 87 additions & 3 deletions
@@ -91,10 +91,15 @@ class XLAShardedTensor(torch.Tensor):
   # >> assert len(input.shape) == len(partition_spec)
   partition_spec: Tuple[int, None]
 
-  __slots__ = ['global_tensor']
+  __slots__ = ['global_tensor', 'mesh_shape', 'partition_spec', '_cached_spec']
 
   @staticmethod
-  def __new__(cls, elem: torch.Tensor, *args, **kwargs):
+  def __new__(cls,
+              elem: torch.Tensor,
+              mesh_shape=None,
+              partition_spec=None,
+              *args,
+              **kwargs):
     # TODO(yeounoh) wrapper can take different arguments
     r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
         cls,
@@ -106,6 +111,11 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
         device=elem.device,
         requires_grad=kwargs.get("requires_grad", False))
     r.global_tensor = elem.detach() if r.requires_grad else elem
+    # Store mesh and partition information for DTensor compatibility
+    if mesh_shape is not None:
+      r.mesh_shape = mesh_shape
+    if partition_spec is not None:
+      r.partition_spec = partition_spec
     return r
 
   # Shards on the devices are materialized/available after the lazy
@@ -159,7 +169,27 @@ def unwrap(elem):
       return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem
 
     def wrap(elem):
-      return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem
+      if isinstance(elem,
+                    torch.Tensor) and not isinstance(elem, XLAShardedTensor):
+        # Try to get mesh/partition info from any XLAShardedTensor in args
+        mesh_shape = None
+        partition_spec = None
+
+        def find_sharded_info(x):
+          nonlocal mesh_shape, partition_spec
+          if isinstance(x, XLAShardedTensor):
+            if hasattr(x, 'mesh_shape') and x.mesh_shape:
+              mesh_shape = x.mesh_shape
+            if hasattr(x, 'partition_spec') and x.partition_spec:
+              partition_spec = x.partition_spec
+
+        tree_map(find_sharded_info, args)
+        if kwargs:
+          tree_map(find_sharded_info, kwargs)
+
+        return XLAShardedTensor(
+            elem, mesh_shape=mesh_shape, partition_spec=partition_spec)
+      return elem
 
     # no_dispatch is only needed if you use enable_python_mode.
     # It prevents infinite recursion.
@@ -169,6 +199,60 @@ def wrap(elem):
           func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
       return rs
 
+  @property
+  def _spec(self):
+    """
+    Convert XLA sharding information to DTensorSpec for DTensor interface compatibility.
+    """
+    # Return cached spec if available
+    if hasattr(self, '_cached_spec'):
+      return self._cached_spec
+
+    from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
+    from torch.distributed.device_mesh import DeviceMesh
+    from torch.distributed.tensor.placement_types import Shard, Replicate
+
+    # use existing mesh_shape
+    if hasattr(self, 'mesh_shape') and self.mesh_shape:
+      import torch_xla.runtime as xr
+      device_count = xr.global_runtime_device_count()
+      device_list = list(range(device_count))
+      mesh = DeviceMesh("xla",
+                        torch.tensor(device_list).reshape(self.mesh_shape))
+    else:
+      # default to 1D mesh
+      import torch_xla.runtime as xr
+      device_count = xr.global_runtime_device_count()
+      mesh = DeviceMesh("xla", list(range(device_count)))
+
+    # use existing partition_spec
+    if hasattr(self, 'partition_spec') and self.partition_spec:
+      placements = []
+      for mesh_dim in range(
+          len(self.mesh_shape
+             ) if hasattr(self, 'mesh_shape') and self.mesh_shape else 1):
+        # find tensor dimension sharded on this mesh dimension
+        tensor_dim = None
+        for t_dim, m_dim in enumerate(self.partition_spec):
+          if m_dim == mesh_dim:
+            tensor_dim = t_dim
+            break
+        placements.append(
+            Shard(tensor_dim) if tensor_dim is not None else Replicate())
+    else:
+      placements = [Replicate()]
+
+    # tensor metadata
+    tensor_meta = TensorMeta(
+        shape=self.global_tensor.shape,
+        stride=self.global_tensor.stride(),
+        dtype=self.global_tensor.dtype)
+
+    # Create and cache the spec
+    self._cached_spec = DTensorSpec(
+        mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta)
+    return self._cached_spec
+
   @classmethod
   def __torch_function__(cls, func, types, args=(), kwargs=None):
     return super().__torch_function__(func, types, args, kwargs)
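
Below is a minimal usage sketch (not part of this commit) showing how the new _spec property can be read after marking a tensor as sharded. It assumes an SPMD-enabled process; the device helper, mesh shape, and tensor sizes are illustrative only.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD mode

n_dev = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(n_dev)), (n_dev,), ('data',))  # 1D mesh over all devices

t = torch.randn(128, 64).to(xm.xla_device())
st = xs.mark_sharding(t, mesh, (0, None))  # shard tensor dim 0 over mesh dim 0

spec = st._spec                # DTensorSpec derived from mesh_shape/partition_spec
print(spec.mesh.device_type)   # "xla"
print(spec.placements)         # expected: (Shard(0),)
print(spec.tensor_meta.shape)  # torch.Size([128, 64])

A second read of st._spec returns the cached object, which is what test_spec_caching above relies on.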

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 15 additions & 5 deletions
@@ -651,7 +651,9 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   op_sharding = mesh.get_op_sharding(partition_spec)
   annotate_func = torch_xla._XLAC._xla_mark_sharding
   annotate_func(unwrap_sharded_tensor(t), op_sharding)
-  return wrap_as_sharded_tensor(t)
+  # Pass mesh and partition spec information for DTensor compatibility
+  return wrap_as_sharded_tensor(
+      t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec)
 
 
 def mark_sharding_with_gradients(
@@ -755,11 +757,19 @@ def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
   return t
 
 
-def wrap_as_sharded_tensor(
-    t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor:
+def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor],
+                           mesh_shape=None,
+                           partition_spec=None) -> XLAShardedTensor:
+  # pass along mesh and partition spec information
   if not isinstance(t, XLAShardedTensor):
-    return XLAShardedTensor(t)
-  return t
+    return XLAShardedTensor(
+        t, mesh_shape=mesh_shape, partition_spec=partition_spec)
+  else:
+    if mesh_shape is not None:
+      t.mesh_shape = mesh_shape
+    if partition_spec is not None:
+      t.partition_spec = partition_spec
+    return t
 
 
 def unwrap_sharded_tensor(
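
To make the partition_spec-to-placements mapping concrete, here is a small standalone sketch (not in the commit) that mirrors the conversion loop in XLAShardedTensor._spec: each mesh dimension gets Shard(tensor_dim) if some tensor dimension is assigned to it in the partition spec, otherwise Replicate(). The helper name spec_to_placements is hypothetical.

from torch.distributed.tensor.placement_types import Replicate, Shard


def spec_to_placements(partition_spec, mesh_ndim):
  """One placement per mesh dim, mirroring the loop in XLAShardedTensor._spec."""
  placements = []
  for mesh_dim in range(mesh_ndim):
    # the tensor dimension sharded on this mesh dimension, if any
    tensor_dim = next(
        (t for t, m in enumerate(partition_spec) if m == mesh_dim), None)
    placements.append(
        Shard(tensor_dim) if tensor_dim is not None else Replicate())
  return tuple(placements)


print(spec_to_placements((0, None), mesh_ndim=1))  # (Shard(dim=0),)
print(spec_to_placements((0, None), mesh_ndim=2))  # (Shard(dim=0), Replicate())
print(spec_to_placements((None, 1), mesh_ndim=2))  # (Replicate(), Shard(dim=1))

This is also why wrap_as_sharded_tensor now threads mesh_shape and partition_spec through mark_sharding: without them, _spec would fall back to a default 1D mesh with Replicate placements.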

0 commit comments

Comments
 (0)