From 6eaff8be2dd755ef438d57b14a23289ff24ca62c Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 15 Jul 2025 22:26:45 +0000 Subject: [PATCH 1/5] Implement XLAShardedTensor._spec and test --- test/neuron/run_tests.sh | 11 +- test/run_tests.sh | 1 + test/spmd/test_xla_dtensor_spec_conversion.py | 149 ++++++++++++++++++ test/tpu/run_tests.sh | 1 + .../distributed/spmd/xla_sharded_tensor.py | 87 +++++++++- torch_xla/distributed/spmd/xla_sharding.py | 29 +++- 6 files changed, 266 insertions(+), 12 deletions(-) create mode 100644 test/spmd/test_xla_dtensor_spec_conversion.py diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index 22e778945f9c..f7671cc3d827 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -56,6 +56,14 @@ function run_test { PJRT_DEVICE=NEURON NEURON_NUM_DEVICES=1 run_coverage "$@" } +function run_test_multi_device { + if ! test_is_selected "$1"; then + return + fi + echo "Running in PjRt runtime: $@" + PJRT_DEVICE=NEURON run_coverage "$@" +} + function run_test_without_functionalization { if ! test_is_selected "$1"; then return @@ -246,7 +254,8 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py" #run_test "$_TEST_DIR/spmd/test_dtensor_integration.py" #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py" - run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" + run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" + run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py" diff --git a/test/run_tests.sh b/test/run_tests.sh index 93f4cb33c061..b2cc8f751d2c 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -254,6 +254,7 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py" run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py" run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py" diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py new file mode 100644 index 000000000000..12102f555c20 --- /dev/null +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -0,0 +1,149 @@ +import os +import sys + +import torch +from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor + +import torch_xla +import torch_xla.runtime as xr + +import unittest +import test_xla_sharding_base + + +class XLADTensorSpecConversionTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_sample_test_case(self): + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", torch.arange(world_size)) + big_tensor = torch.randn(100000, 88) + my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)]) + + assert my_dtensor._spec.mesh.device_type == mesh.device_type + assert my_dtensor._spec.placements == (Shard(0),) + + def test_xla_to_dtensor_spec_conversion(self): + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(device_count))) + + # Test different sharding patterns + from torch.distributed.tensor.placement_types import Replicate + test_cases = [ + (torch.randn(100, 50), [Shard(0)]), + (torch.randn(100, 50), [Shard(1)]), + (torch.randn(100, 50, 25), [Shard(0)]), + (torch.randn(100, 50), [Replicate()]), + ] + + for tensor, placements in test_cases: + xla_tensor = distribute_tensor(tensor, mesh, placements) + spec = xla_tensor._spec + + assert spec is not None + assert spec.mesh.device_type == "xla" + assert spec.tensor_meta.shape == tensor.shape + assert spec.tensor_meta.dtype == tensor.dtype + assert len(spec.placements) >= 1 + assert spec.placements == tuple(placements) + + def test_mesh_conversion(self): + device_count = xr.global_runtime_device_count() + original_mesh = DeviceMesh("xla", list(range(device_count))) + tensor = torch.randn(50, 50) + xla_tensor = distribute_tensor(tensor, original_mesh, [Shard(0)]) + + converted_spec = xla_tensor._spec + + assert converted_spec.mesh.device_type == "xla" + assert converted_spec.mesh.size() == device_count + # assert on mesh dimensions + assert converted_spec.mesh.shape == original_mesh.shape + + def test_spec_caching(self): + """Test that _spec property caches results for better performance""" + import time + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(device_count))) + tensor = torch.randn(1000, + 1000) # Large tensor to make spec creation noticeable + xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) + + # first access should create and cache the spec + start_time = time.time() + spec1 = xla_tensor._spec + first_access_time = time.time() - start_time + + # should be much faster due to caching + start_time = time.time() + spec2 = xla_tensor._spec + second_access_time = time.time() - start_time + + assert spec1 is spec2 + print( + f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s" + ) + assert second_access_time * 10 < first_access_time, \ + f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s" + + def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements): + """Helper to create tensor and mesh for testing""" + device_count = xr.global_runtime_device_count() + if device_count < max(mesh_shape): + self.skipTest( + f"Need at least {max(mesh_shape)} devices, got {device_count}") + + mesh = DeviceMesh("xla", torch.arange(device_count).reshape(mesh_shape)) + tensor = torch.randn(*tensor_shape) + return distribute_tensor(tensor, mesh, placements), mesh + + def test_multi_dim_sharding_spec(self): + """Test _spec for multi-dimensional sharding""" + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for 2D mesh") + + mesh_shape = (2, device_count // 2) + xla_tensor, mesh = self._create_test_tensor_and_mesh( + (100, 50), mesh_shape, [Shard(0), Shard(1)]) + spec = xla_tensor._spec + + assert len(spec.placements) == 2 + assert spec.mesh.ndim == 2 + + def test_tensor_operations_preserve_spec(self): + """Test that tensor operations preserve sharding metadata""" + xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,), + [Shard(0)]) + + result_add = xla_tensor + 1 + result_mul = xla_tensor * 2 + result_relu = torch.relu(xla_tensor) + + for result in [result_add, result_mul, result_relu]: + assert hasattr(result, '_spec') + assert result._spec.mesh.device_type == "xla" + + def test_mixed_placement_spec(self): + """Test _spec for tensors with mixed shard/replicate placements""" + from torch.distributed.tensor.placement_types import Replicate + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for 2D mesh") + + mesh_shape = (2, device_count // 2) + xla_tensor, mesh = self._create_test_tensor_and_mesh( + (100, 50), mesh_shape, [Shard(0), Replicate()]) + spec = xla_tensor._spec + + assert len(spec.placements) == 2 + assert isinstance(spec.placements[0], Shard) + assert isinstance(spec.placements[1], Replicate) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 1f6f5249b93b..24f18d3bdcda 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -61,6 +61,7 @@ run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_fsdp_v2.py" run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" +run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test "$_TEST_DIR/test_gradient_accumulation.py" XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v run_test "$_TEST_DIR/test_autocast.py" diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index aedfd6a801e3..2f0e432a5b2e 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -6,6 +6,10 @@ from typing import List, Tuple, Iterator, Union import contextlib import collections +import torch_xla.runtime as xr +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Shard, Replicate @dataclass @@ -91,10 +95,15 @@ class XLAShardedTensor(torch.Tensor): # >> assert len(input.shape) == len(partition_spec) partition_spec: Tuple[int, None] - __slots__ = ['global_tensor'] + __slots__ = ['global_tensor', 'mesh_shape', 'partition_spec', '_cached_spec'] @staticmethod - def __new__(cls, elem: torch.Tensor, *args, **kwargs): + def __new__(cls, + elem: torch.Tensor, + mesh_shape=None, + partition_spec=None, + *args, + **kwargs): # TODO(yeounoh) wrapper can take different arguments r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] cls, @@ -106,6 +115,11 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs): device=elem.device, requires_grad=kwargs.get("requires_grad", False)) r.global_tensor = elem.detach() if r.requires_grad else elem + # Store mesh and partition information for DTensor compatibility + if mesh_shape is not None: + r.mesh_shape = mesh_shape + if partition_spec is not None: + r.partition_spec = partition_spec return r # Shards on the devices are materialized/available after the lazy @@ -159,7 +173,27 @@ def unwrap(elem): return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem def wrap(elem): - return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem + if isinstance(elem, + torch.Tensor) and not isinstance(elem, XLAShardedTensor): + # Try to get mesh/partition info from any XLAShardedTensor in args + mesh_shape = None + partition_spec = None + + def find_sharded_info(x): + nonlocal mesh_shape, partition_spec + if isinstance(x, XLAShardedTensor): + if hasattr(x, 'mesh_shape') and x.mesh_shape: + mesh_shape = x.mesh_shape + if hasattr(x, 'partition_spec') and x.partition_spec: + partition_spec = x.partition_spec + + tree_map(find_sharded_info, args) + if kwargs: + tree_map(find_sharded_info, kwargs) + + return XLAShardedTensor( + elem, mesh_shape=mesh_shape, partition_spec=partition_spec) + return elem # no_dispatch is only needed if you use enable_python_mode. # It prevents infinite recursion. @@ -169,6 +203,53 @@ def wrap(elem): func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) return rs + @property + def _spec(self): + """ + Convert XLA sharding information to DTensorSpec for DTensor interface compatibility. + """ + # Return cached spec if available + if hasattr(self, '_cached_spec'): + return self._cached_spec + + # use existing mesh_shape + if hasattr(self, 'mesh_shape') and self.mesh_shape: + import torch_xla.runtime as xr + device_count = xr.global_runtime_device_count() + device_list = list(range(device_count)) + mesh = DeviceMesh("xla", + torch.tensor(device_list).reshape(self.mesh_shape)) + else: + raise ValueError("mesh_shape must be specified to create DTensorSpec") + + # use existing partition_spec + if hasattr(self, 'partition_spec') and self.partition_spec: + placements = [] + for mesh_dim in range( + len(self.mesh_shape + ) if hasattr(self, 'mesh_shape') and self.mesh_shape else 1): + # find tensor dimension sharded on this mesh dimension + tensor_dim = None + for t_dim, m_dim in enumerate(self.partition_spec): + if m_dim == mesh_dim: + tensor_dim = t_dim + break + placements.append( + Shard(tensor_dim) if tensor_dim is not None else Replicate()) + else: + raise ValueError("partition_spec must be specified to create DTensorSpec") + + # tensor metadata + tensor_meta = TensorMeta( + shape=self.global_tensor.shape, + stride=self.global_tensor.stride(), + dtype=self.global_tensor.dtype) + + # Create and cache the spec + self._cached_spec = DTensorSpec( + mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta) + return self._cached_spec + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): return super().__torch_function__(func, types, args, kwargs) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 49229b17cffe..751fe7e9a661 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -543,7 +543,8 @@ def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh = get_global_mesh() if mesh is None else mesh t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec) t = torch_xla._XLAC._spmd_full_to_shard_shape(unwrap_sharded_tensor(t)) - return wrap_as_sharded_tensor(t) + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], @@ -560,7 +561,8 @@ def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], t = torch_xla._XLAC._spmd_shard_to_full_shape( unwrap_sharded_tensor(t), mesh.get_op_sharding(partition_spec), full_shape, t.dtype) - return wrap_as_sharded_tensor(t) + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def annotate_custom_sharding(t: Union[torch.Tensor, @@ -594,7 +596,8 @@ def annotate_custom_sharding(t: Union[torch.Tensor, op_sharding = mesh.get_op_sharding(partition_spec) annotate_func = torch_xla._XLAC._xla_annotate_custom_sharding annotate_func(unwrap_sharded_tensor(t), op_sharding) - return wrap_as_sharded_tensor(t) + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, @@ -651,7 +654,9 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, op_sharding = mesh.get_op_sharding(partition_spec) annotate_func = torch_xla._XLAC._xla_mark_sharding annotate_func(unwrap_sharded_tensor(t), op_sharding) - return wrap_as_sharded_tensor(t) + # Pass mesh and partition spec information for DTensor compatibility + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def mark_sharding_with_gradients( @@ -755,11 +760,19 @@ def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: return t -def wrap_as_sharded_tensor( - t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor: +def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor], + mesh_shape=None, + partition_spec=None) -> XLAShardedTensor: + # pass along mesh and partition spec information if not isinstance(t, XLAShardedTensor): - return XLAShardedTensor(t) - return t + return XLAShardedTensor( + t, mesh_shape=mesh_shape, partition_spec=partition_spec) + else: + if mesh_shape is not None: + t.mesh_shape = mesh_shape + if partition_spec is not None: + t.partition_spec = partition_spec + return t def unwrap_sharded_tensor( From d231adb9db7819ba519c00707a44dd3eeaf1756e Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 22 Jul 2025 18:35:54 +0000 Subject: [PATCH 2/5] Removed auto wrapping sharding propagation, added cached spec invalidation --- test/spmd/test_xla_dtensor_spec_conversion.py | 138 ++++++++++++++---- .../distributed/spmd/xla_sharded_tensor.py | 52 +++---- torch_xla/distributed/spmd/xla_sharding.py | 25 +++- 3 files changed, 146 insertions(+), 69 deletions(-) diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py index 12102f555c20..2fe53613e94a 100644 --- a/test/spmd/test_xla_dtensor_spec_conversion.py +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -3,9 +3,12 @@ import torch from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor +from torch.distributed.tensor.placement_types import Replicate import torch_xla import torch_xla.runtime as xr +from torch_xla.distributed.spmd import XLAShardedTensor +from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor import unittest import test_xla_sharding_base @@ -31,7 +34,6 @@ def test_xla_to_dtensor_spec_conversion(self): mesh = DeviceMesh("xla", list(range(device_count))) # Test different sharding patterns - from torch.distributed.tensor.placement_types import Replicate test_cases = [ (torch.randn(100, 50), [Shard(0)]), (torch.randn(100, 50), [Shard(1)]), @@ -64,30 +66,27 @@ def test_mesh_conversion(self): assert converted_spec.mesh.shape == original_mesh.shape def test_spec_caching(self): - """Test that _spec property caches results for better performance""" - import time + """Test that _spec property caches results + + Addresses PR comment: "These sorts of tests that rely on the wall clock often lead to + annoying flakes in my experience. I think it's sufficient to just test that + self._cached_spec has a permanent value after the first call." + """ device_count = xr.global_runtime_device_count() mesh = DeviceMesh("xla", list(range(device_count))) - tensor = torch.randn(1000, - 1000) # Large tensor to make spec creation noticeable + tensor = torch.randn(100, 100) xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) - # first access should create and cache the spec - start_time = time.time() + # First access should create and cache the spec spec1 = xla_tensor._spec - first_access_time = time.time() - start_time - # should be much faster due to caching - start_time = time.time() - spec2 = xla_tensor._spec - second_access_time = time.time() - start_time + # Verify the spec is cached + assert xla_tensor._cached_spec is not None + assert xla_tensor._cached_spec is spec1 + # Second access should return the cached spec + spec2 = xla_tensor._spec assert spec1 is spec2 - print( - f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s" - ) - assert second_access_time * 10 < first_access_time, \ - f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s" def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements): """Helper to create tensor and mesh for testing""" @@ -114,22 +113,8 @@ def test_multi_dim_sharding_spec(self): assert len(spec.placements) == 2 assert spec.mesh.ndim == 2 - def test_tensor_operations_preserve_spec(self): - """Test that tensor operations preserve sharding metadata""" - xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,), - [Shard(0)]) - - result_add = xla_tensor + 1 - result_mul = xla_tensor * 2 - result_relu = torch.relu(xla_tensor) - - for result in [result_add, result_mul, result_relu]: - assert hasattr(result, '_spec') - assert result._spec.mesh.device_type == "xla" - def test_mixed_placement_spec(self): """Test _spec for tensors with mixed shard/replicate placements""" - from torch.distributed.tensor.placement_types import Replicate device_count = xr.global_runtime_device_count() if device_count < 4: self.skipTest("Need at least 4 devices for 2D mesh") @@ -143,6 +128,97 @@ def test_mixed_placement_spec(self): assert isinstance(spec.placements[0], Shard) assert isinstance(spec.placements[1], Replicate) + def test_sharding_info_acquisition(self): + """Test that non-XLAShardedTensor can acquire sharding information + + Tests case of 'elem is not an XLAShardedTensor but there exists + sharding information we want to acquire' + """ + + device_count = xr.global_runtime_device_count() + mesh_shape = (device_count,) + partition_spec = (0, None) + + regular_tensor = torch.randn(100, 50).to('xla') + + sharded_tensor = wrap_as_sharded_tensor( + regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec) + + # Verify the tensor acquired the sharding information + assert isinstance(sharded_tensor, XLAShardedTensor) + assert sharded_tensor.mesh_shape == mesh_shape + assert sharded_tensor.partition_spec == partition_spec + + def test_resharding_logic(self): + """ + Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t. + """ + + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for resharding test") + + # Initial sharding + initial_mesh_shape = (device_count,) + initial_partition_spec = (0, None) + new_mesh_shape = (2, device_count // 2) + new_partition_spec = (0, 1) + + # Create tensor and verify resharding + tensor = torch.randn(100, 50).to('xla') + sharded_tensor = wrap_as_sharded_tensor( + tensor, + mesh_shape=initial_mesh_shape, + partition_spec=initial_partition_spec) + initial_spec = sharded_tensor._spec + + resharded_tensor = wrap_as_sharded_tensor( + sharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=new_partition_spec) + + # Verify resharding worked and cache was invalidated + assert resharded_tensor.mesh_shape == new_mesh_shape + assert resharded_tensor.partition_spec == new_partition_spec + assert resharded_tensor._spec is not initial_spec + + def test_spec_invalidation_on_resharding(self): + """Tests cases where the cached spec may become outdated. + """ + + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for resharding test") + + tensor = torch.randn(100, 50).to('xla') + initial_mesh_shape = (device_count,) + initial_partition_spec = (0, None) + new_mesh_shape = (2, device_count // 2) + new_partition_spec = (0, 1) + + sharded_tensor = wrap_as_sharded_tensor( + tensor, + mesh_shape=initial_mesh_shape, + partition_spec=initial_partition_spec) + initial_spec = sharded_tensor._spec + assert sharded_tensor._cached_spec is not None + + # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache + resharded_tensor = wrap_as_sharded_tensor( + sharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=initial_partition_spec) + assert resharded_tensor._spec is not initial_spec + assert resharded_tensor._spec.mesh.shape == new_mesh_shape + + initial_spec = resharded_tensor._spec + resharded_tensor = wrap_as_sharded_tensor( + resharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=new_partition_spec) + assert resharded_tensor._spec is not initial_spec + assert resharded_tensor._spec.placements[1].dim == 1 + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 2f0e432a5b2e..3e90c4467afe 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -10,6 +10,7 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.placement_types import Shard, Replicate +from torch.utils._pytree import tree_map_only @dataclass @@ -115,11 +116,13 @@ def __new__(cls, device=elem.device, requires_grad=kwargs.get("requires_grad", False)) r.global_tensor = elem.detach() if r.requires_grad else elem - # Store mesh and partition information for DTensor compatibility - if mesh_shape is not None: - r.mesh_shape = mesh_shape - if partition_spec is not None: - r.partition_spec = partition_spec + + # Initialize mesh, partition, and spec information + r.mesh_shape = mesh_shape or (elem.mesh_shape if isinstance( + elem, XLAShardedTensor) else None) + r.partition_spec = partition_spec or (elem.partition_spec if isinstance( + elem, XLAShardedTensor) else None) + r._cached_spec = None return r # Shards on the devices are materialized/available after the lazy @@ -144,6 +147,9 @@ def load_local_shards_(self, shards: List[XLAShard]): devices = [s.shard_device for s in shards] torch_xla._XLAC._load_local_shards(self.global_tensor, data, devices) + # Invalidate cached spec since the global_tensor data has changed + self.invalidate_spec_cache() + @property def sharding_spec(self): return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor) @@ -173,27 +179,7 @@ def unwrap(elem): return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem def wrap(elem): - if isinstance(elem, - torch.Tensor) and not isinstance(elem, XLAShardedTensor): - # Try to get mesh/partition info from any XLAShardedTensor in args - mesh_shape = None - partition_spec = None - - def find_sharded_info(x): - nonlocal mesh_shape, partition_spec - if isinstance(x, XLAShardedTensor): - if hasattr(x, 'mesh_shape') and x.mesh_shape: - mesh_shape = x.mesh_shape - if hasattr(x, 'partition_spec') and x.partition_spec: - partition_spec = x.partition_spec - - tree_map(find_sharded_info, args) - if kwargs: - tree_map(find_sharded_info, kwargs) - - return XLAShardedTensor( - elem, mesh_shape=mesh_shape, partition_spec=partition_spec) - return elem + return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem # no_dispatch is only needed if you use enable_python_mode. # It prevents infinite recursion. @@ -209,11 +195,11 @@ def _spec(self): Convert XLA sharding information to DTensorSpec for DTensor interface compatibility. """ # Return cached spec if available - if hasattr(self, '_cached_spec'): + if self._cached_spec is not None: return self._cached_spec # use existing mesh_shape - if hasattr(self, 'mesh_shape') and self.mesh_shape: + if self.mesh_shape is not None: import torch_xla.runtime as xr device_count = xr.global_runtime_device_count() device_list = list(range(device_count)) @@ -223,11 +209,9 @@ def _spec(self): raise ValueError("mesh_shape must be specified to create DTensorSpec") # use existing partition_spec - if hasattr(self, 'partition_spec') and self.partition_spec: + if self.partition_spec is not None: placements = [] - for mesh_dim in range( - len(self.mesh_shape - ) if hasattr(self, 'mesh_shape') and self.mesh_shape else 1): + for mesh_dim in range(len(self.mesh_shape)): # find tensor dimension sharded on this mesh dimension tensor_dim = None for t_dim, m_dim in enumerate(self.partition_spec): @@ -250,6 +234,10 @@ def _spec(self): mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta) return self._cached_spec + def invalidate_spec_cache(self): + """Invalidate the cached DTensorSpec.""" + self._cached_spec = None + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): return super().__torch_function__(func, types, args, kwargs) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 751fe7e9a661..5f4d4378e7d2 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -765,14 +765,27 @@ def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor], partition_spec=None) -> XLAShardedTensor: # pass along mesh and partition spec information if not isinstance(t, XLAShardedTensor): + # Create a new XLAShardedTensor return XLAShardedTensor( t, mesh_shape=mesh_shape, partition_spec=partition_spec) - else: - if mesh_shape is not None: - t.mesh_shape = mesh_shape - if partition_spec is not None: - t.partition_spec = partition_spec - return t + + # Update existing XLAShardedTensor if needed + needs_invalidate = False + + # Always set mesh_shape and partition_spec if provided + if mesh_shape is not None: + t.mesh_shape = mesh_shape + needs_invalidate = True + + if partition_spec is not None: + t.partition_spec = partition_spec + needs_invalidate = True + + # Invalidate cached spec if resharding occurred + if needs_invalidate: + t.invalidate_spec_cache() + + return t def unwrap_sharded_tensor( From 60cd54296b73f91785bebf7ead8122a822185cfd Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 22 Jul 2025 18:40:12 +0000 Subject: [PATCH 3/5] Removing lazy import --- test/spmd/test_xla_dtensor_spec_conversion.py | 7 ------- torch_xla/distributed/spmd/xla_sharded_tensor.py | 1 - 2 files changed, 8 deletions(-) diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py index 2fe53613e94a..e48a323bd151 100644 --- a/test/spmd/test_xla_dtensor_spec_conversion.py +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -67,24 +67,17 @@ def test_mesh_conversion(self): def test_spec_caching(self): """Test that _spec property caches results - - Addresses PR comment: "These sorts of tests that rely on the wall clock often lead to - annoying flakes in my experience. I think it's sufficient to just test that - self._cached_spec has a permanent value after the first call." """ device_count = xr.global_runtime_device_count() mesh = DeviceMesh("xla", list(range(device_count))) tensor = torch.randn(100, 100) xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) - # First access should create and cache the spec spec1 = xla_tensor._spec - # Verify the spec is cached assert xla_tensor._cached_spec is not None assert xla_tensor._cached_spec is spec1 - # Second access should return the cached spec spec2 = xla_tensor._spec assert spec1 is spec2 diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 3e90c4467afe..ed590e22fd87 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -200,7 +200,6 @@ def _spec(self): # use existing mesh_shape if self.mesh_shape is not None: - import torch_xla.runtime as xr device_count = xr.global_runtime_device_count() device_list = list(range(device_count)) mesh = DeviceMesh("xla", From 91714d4809806eb7f1110da97b00c383b730a613 Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 22 Jul 2025 20:43:48 +0000 Subject: [PATCH 4/5] Added test for catching thrown error in spec --- test/spmd/test_xla_dtensor_spec_conversion.py | 17 +++++++++++++++++ .../distributed/spmd/xla_sharded_tensor.py | 12 ++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py index e48a323bd151..81cb8a4aa2e4 100644 --- a/test/spmd/test_xla_dtensor_spec_conversion.py +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -212,6 +212,23 @@ def test_spec_invalidation_on_resharding(self): assert resharded_tensor._spec is not initial_spec assert resharded_tensor._spec.placements[1].dim == 1 + def test_auto_wrapped_tensor_spec_failure(self): + """Test that auto-wrapped tensors fail when accessing _spec property. + + Auto-wrapped tensors are created through operations that trigger __torch_dispatch__ + but don't yet have access to the sharding propagation done through open xla, + causing ._spec to fail. + """ + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", torch.arange(device_count)) + tensor = torch.randn(4, 4) + sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) + + auto_wrapped = sharded_tensor + sharded_tensor + + with self.assertRaises(ValueError): + _ = auto_wrapped._spec + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index ed590e22fd87..a20d530f3faa 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -205,7 +205,11 @@ def _spec(self): mesh = DeviceMesh("xla", torch.tensor(device_list).reshape(self.mesh_shape)) else: - raise ValueError("mesh_shape must be specified to create DTensorSpec") + raise ValueError( + "mesh_shape must be specified to create DTensorSpec. " + "If this tensor was created through torch operations, it may be auto-wrapped. " + "Use wrap_as_sharded_tensor() to set mesh_shape before accessing _spec. " + ) # use existing partition_spec if self.partition_spec is not None: @@ -220,7 +224,11 @@ def _spec(self): placements.append( Shard(tensor_dim) if tensor_dim is not None else Replicate()) else: - raise ValueError("partition_spec must be specified to create DTensorSpec") + raise ValueError( + "partition_spec must be specified to create DTensorSpec. " + "If this tensor was created through torch operations, it may be auto-wrapped. " + "Use wrap_as_sharded_tensor() to set partition_spec before accessing _spec. " + ) # tensor metadata tensor_meta = TensorMeta( From 5a61b1dd2cdeb5d78db143df9e628e9414c8be03 Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Wed, 23 Jul 2025 18:20:50 +0000 Subject: [PATCH 5/5] Update test_mark_shard_scalar to reflect mesh_shape implementation --- test/spmd/test_xla_sharding.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 7b1be7574a1f..48b760f6e3f0 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1162,10 +1162,6 @@ def test_mark_shard_scalar(self): self.assertIsInstance(shard.indices, type(Ellipsis)) self.assertEqual(shard.replica_id, i) - # It looks like mesh_shape attribute is never implemented. - with self.assertRaises(AttributeError): - xt.mesh_shape - def test_global_mesh(self): expected_mesh = self._get_mesh((1, self.n_devices)) xs.set_global_mesh(expected_mesh)