diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index 22e778945f9..f7671cc3d82 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -56,6 +56,14 @@ function run_test { PJRT_DEVICE=NEURON NEURON_NUM_DEVICES=1 run_coverage "$@" } +function run_test_multi_device { + if ! test_is_selected "$1"; then + return + fi + echo "Running in PjRt runtime: $@" + PJRT_DEVICE=NEURON run_coverage "$@" +} + function run_test_without_functionalization { if ! test_is_selected "$1"; then return @@ -246,7 +254,8 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py" #run_test "$_TEST_DIR/spmd/test_dtensor_integration.py" #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py" - run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" + run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" + run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py" diff --git a/test/run_tests.sh b/test/run_tests.sh index 93f4cb33c06..b2cc8f751d2 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -254,6 +254,7 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py" run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py" run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py" diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py new file mode 100644 index 00000000000..81cb8a4aa2e --- /dev/null +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -0,0 +1,235 @@ +import os +import sys + +import torch +from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor +from torch.distributed.tensor.placement_types import Replicate + +import torch_xla +import torch_xla.runtime as xr +from torch_xla.distributed.spmd import XLAShardedTensor +from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor + +import unittest +import test_xla_sharding_base + + +class XLADTensorSpecConversionTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_sample_test_case(self): + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", torch.arange(world_size)) + big_tensor = torch.randn(100000, 88) + my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)]) + + assert my_dtensor._spec.mesh.device_type == mesh.device_type + assert my_dtensor._spec.placements == (Shard(0),) + + def test_xla_to_dtensor_spec_conversion(self): + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(device_count))) + + # Test different sharding patterns + test_cases = [ + (torch.randn(100, 50), [Shard(0)]), + (torch.randn(100, 50), [Shard(1)]), + (torch.randn(100, 50, 25), [Shard(0)]), + (torch.randn(100, 50), [Replicate()]), + ] + + for tensor, placements in test_cases: + xla_tensor = distribute_tensor(tensor, mesh, placements) + spec = xla_tensor._spec + + assert spec is not None + assert spec.mesh.device_type == "xla" + assert spec.tensor_meta.shape == tensor.shape + assert 
spec.tensor_meta.dtype == tensor.dtype + assert len(spec.placements) >= 1 + assert spec.placements == tuple(placements) + + def test_mesh_conversion(self): + device_count = xr.global_runtime_device_count() + original_mesh = DeviceMesh("xla", list(range(device_count))) + tensor = torch.randn(50, 50) + xla_tensor = distribute_tensor(tensor, original_mesh, [Shard(0)]) + + converted_spec = xla_tensor._spec + + assert converted_spec.mesh.device_type == "xla" + assert converted_spec.mesh.size() == device_count + # assert on mesh dimensions + assert converted_spec.mesh.shape == original_mesh.shape + + def test_spec_caching(self): + """Test that _spec property caches results + """ + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(device_count))) + tensor = torch.randn(100, 100) + xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) + + spec1 = xla_tensor._spec + + assert xla_tensor._cached_spec is not None + assert xla_tensor._cached_spec is spec1 + + spec2 = xla_tensor._spec + assert spec1 is spec2 + + def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements): + """Helper to create tensor and mesh for testing""" + device_count = xr.global_runtime_device_count() + if device_count < max(mesh_shape): + self.skipTest( + f"Need at least {max(mesh_shape)} devices, got {device_count}") + + mesh = DeviceMesh("xla", torch.arange(device_count).reshape(mesh_shape)) + tensor = torch.randn(*tensor_shape) + return distribute_tensor(tensor, mesh, placements), mesh + + def test_multi_dim_sharding_spec(self): + """Test _spec for multi-dimensional sharding""" + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for 2D mesh") + + mesh_shape = (2, device_count // 2) + xla_tensor, mesh = self._create_test_tensor_and_mesh( + (100, 50), mesh_shape, [Shard(0), Shard(1)]) + spec = xla_tensor._spec + + assert len(spec.placements) == 2 + assert spec.mesh.ndim == 2 + + def test_mixed_placement_spec(self): + """Test _spec for tensors with mixed shard/replicate placements""" + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for 2D mesh") + + mesh_shape = (2, device_count // 2) + xla_tensor, mesh = self._create_test_tensor_and_mesh( + (100, 50), mesh_shape, [Shard(0), Replicate()]) + spec = xla_tensor._spec + + assert len(spec.placements) == 2 + assert isinstance(spec.placements[0], Shard) + assert isinstance(spec.placements[1], Replicate) + + def test_sharding_info_acquisition(self): + """Test that non-XLAShardedTensor can acquire sharding information + + Tests case of 'elem is not an XLAShardedTensor but there exists + sharding information we want to acquire' + """ + + device_count = xr.global_runtime_device_count() + mesh_shape = (device_count,) + partition_spec = (0, None) + + regular_tensor = torch.randn(100, 50).to('xla') + + sharded_tensor = wrap_as_sharded_tensor( + regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec) + + # Verify the tensor acquired the sharding information + assert isinstance(sharded_tensor, XLAShardedTensor) + assert sharded_tensor.mesh_shape == mesh_shape + assert sharded_tensor.partition_spec == partition_spec + + def test_resharding_logic(self): + """ + Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t. 
+ """ + + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for resharding test") + + # Initial sharding + initial_mesh_shape = (device_count,) + initial_partition_spec = (0, None) + new_mesh_shape = (2, device_count // 2) + new_partition_spec = (0, 1) + + # Create tensor and verify resharding + tensor = torch.randn(100, 50).to('xla') + sharded_tensor = wrap_as_sharded_tensor( + tensor, + mesh_shape=initial_mesh_shape, + partition_spec=initial_partition_spec) + initial_spec = sharded_tensor._spec + + resharded_tensor = wrap_as_sharded_tensor( + sharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=new_partition_spec) + + # Verify resharding worked and cache was invalidated + assert resharded_tensor.mesh_shape == new_mesh_shape + assert resharded_tensor.partition_spec == new_partition_spec + assert resharded_tensor._spec is not initial_spec + + def test_spec_invalidation_on_resharding(self): + """Tests cases where the cached spec may become outdated. + """ + + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for resharding test") + + tensor = torch.randn(100, 50).to('xla') + initial_mesh_shape = (device_count,) + initial_partition_spec = (0, None) + new_mesh_shape = (2, device_count // 2) + new_partition_spec = (0, 1) + + sharded_tensor = wrap_as_sharded_tensor( + tensor, + mesh_shape=initial_mesh_shape, + partition_spec=initial_partition_spec) + initial_spec = sharded_tensor._spec + assert sharded_tensor._cached_spec is not None + + # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache + resharded_tensor = wrap_as_sharded_tensor( + sharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=initial_partition_spec) + assert resharded_tensor._spec is not initial_spec + assert resharded_tensor._spec.mesh.shape == new_mesh_shape + + initial_spec = resharded_tensor._spec + resharded_tensor = wrap_as_sharded_tensor( + resharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=new_partition_spec) + assert resharded_tensor._spec is not initial_spec + assert resharded_tensor._spec.placements[1].dim == 1 + + def test_auto_wrapped_tensor_spec_failure(self): + """Test that auto-wrapped tensors fail when accessing _spec property. + + Auto-wrapped tensors are created through operations that trigger __torch_dispatch__ + but don't yet have access to the sharding propagation done through open xla, + causing ._spec to fail. + """ + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", torch.arange(device_count)) + tensor = torch.randn(4, 4) + sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) + + auto_wrapped = sharded_tensor + sharded_tensor + + with self.assertRaises(ValueError): + _ = auto_wrapped._spec + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 7b1be7574a1..48b760f6e3f 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1162,10 +1162,6 @@ def test_mark_shard_scalar(self): self.assertIsInstance(shard.indices, type(Ellipsis)) self.assertEqual(shard.replica_id, i) - # It looks like mesh_shape attribute is never implemented. 
- with self.assertRaises(AttributeError): - xt.mesh_shape - def test_global_mesh(self): expected_mesh = self._get_mesh((1, self.n_devices)) xs.set_global_mesh(expected_mesh) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 1f6f5249b93..24f18d3bdcd 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -61,6 +61,7 @@ run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_fsdp_v2.py" run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" +run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test "$_TEST_DIR/test_gradient_accumulation.py" XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v run_test "$_TEST_DIR/test_autocast.py" diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index aedfd6a801e..a20d530f3fa 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -6,6 +6,11 @@ from typing import List, Tuple, Iterator, Union import contextlib import collections +import torch_xla.runtime as xr +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Shard, Replicate +from torch.utils._pytree import tree_map_only @dataclass @@ -91,10 +96,15 @@ class XLAShardedTensor(torch.Tensor): # >> assert len(input.shape) == len(partition_spec) partition_spec: Tuple[int, None] - __slots__ = ['global_tensor'] + __slots__ = ['global_tensor', 'mesh_shape', 'partition_spec', '_cached_spec'] @staticmethod - def __new__(cls, elem: torch.Tensor, *args, **kwargs): + def __new__(cls, + elem: torch.Tensor, + mesh_shape=None, + partition_spec=None, + *args, + **kwargs): # TODO(yeounoh) wrapper can take different arguments r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] cls, @@ -106,6 +116,13 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs): device=elem.device, requires_grad=kwargs.get("requires_grad", False)) r.global_tensor = elem.detach() if r.requires_grad else elem + + # Initialize mesh, partition, and spec information + r.mesh_shape = mesh_shape or (elem.mesh_shape if isinstance( + elem, XLAShardedTensor) else None) + r.partition_spec = partition_spec or (elem.partition_spec if isinstance( + elem, XLAShardedTensor) else None) + r._cached_spec = None return r # Shards on the devices are materialized/available after the lazy @@ -130,6 +147,9 @@ def load_local_shards_(self, shards: List[XLAShard]): devices = [s.shard_device for s in shards] torch_xla._XLAC._load_local_shards(self.global_tensor, data, devices) + # Invalidate cached spec since the global_tensor data has changed + self.invalidate_spec_cache() + @property def sharding_spec(self): return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor) @@ -169,6 +189,62 @@ def wrap(elem): func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) return rs + @property + def _spec(self): + """ + Convert XLA sharding information to DTensorSpec for DTensor interface compatibility. 
+ """ + # Return cached spec if available + if self._cached_spec is not None: + return self._cached_spec + + # use existing mesh_shape + if self.mesh_shape is not None: + device_count = xr.global_runtime_device_count() + device_list = list(range(device_count)) + mesh = DeviceMesh("xla", + torch.tensor(device_list).reshape(self.mesh_shape)) + else: + raise ValueError( + "mesh_shape must be specified to create DTensorSpec. " + "If this tensor was created through torch operations, it may be auto-wrapped. " + "Use wrap_as_sharded_tensor() to set mesh_shape before accessing _spec. " + ) + + # use existing partition_spec + if self.partition_spec is not None: + placements = [] + for mesh_dim in range(len(self.mesh_shape)): + # find tensor dimension sharded on this mesh dimension + tensor_dim = None + for t_dim, m_dim in enumerate(self.partition_spec): + if m_dim == mesh_dim: + tensor_dim = t_dim + break + placements.append( + Shard(tensor_dim) if tensor_dim is not None else Replicate()) + else: + raise ValueError( + "partition_spec must be specified to create DTensorSpec. " + "If this tensor was created through torch operations, it may be auto-wrapped. " + "Use wrap_as_sharded_tensor() to set partition_spec before accessing _spec. " + ) + + # tensor metadata + tensor_meta = TensorMeta( + shape=self.global_tensor.shape, + stride=self.global_tensor.stride(), + dtype=self.global_tensor.dtype) + + # Create and cache the spec + self._cached_spec = DTensorSpec( + mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta) + return self._cached_spec + + def invalidate_spec_cache(self): + """Invalidate the cached DTensorSpec.""" + self._cached_spec = None + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): return super().__torch_function__(func, types, args, kwargs) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 49229b17cff..5f4d4378e7d 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -543,7 +543,8 @@ def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh = get_global_mesh() if mesh is None else mesh t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec) t = torch_xla._XLAC._spmd_full_to_shard_shape(unwrap_sharded_tensor(t)) - return wrap_as_sharded_tensor(t) + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], @@ -560,7 +561,8 @@ def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], t = torch_xla._XLAC._spmd_shard_to_full_shape( unwrap_sharded_tensor(t), mesh.get_op_sharding(partition_spec), full_shape, t.dtype) - return wrap_as_sharded_tensor(t) + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def annotate_custom_sharding(t: Union[torch.Tensor, @@ -594,7 +596,8 @@ def annotate_custom_sharding(t: Union[torch.Tensor, op_sharding = mesh.get_op_sharding(partition_spec) annotate_func = torch_xla._XLAC._xla_annotate_custom_sharding annotate_func(unwrap_sharded_tensor(t), op_sharding) - return wrap_as_sharded_tensor(t) + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, @@ -651,7 +654,9 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, op_sharding = mesh.get_op_sharding(partition_spec) annotate_func = 
torch_xla._XLAC._xla_mark_sharding annotate_func(unwrap_sharded_tensor(t), op_sharding) - return wrap_as_sharded_tensor(t) + # Pass mesh and partition spec information for DTensor compatibility + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def mark_sharding_with_gradients( @@ -755,10 +760,31 @@ def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: return t -def wrap_as_sharded_tensor( - t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor: +def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor], + mesh_shape=None, + partition_spec=None) -> XLAShardedTensor: + # pass along mesh and partition spec information if not isinstance(t, XLAShardedTensor): - return XLAShardedTensor(t) + # Create a new XLAShardedTensor + return XLAShardedTensor( + t, mesh_shape=mesh_shape, partition_spec=partition_spec) + + # Update existing XLAShardedTensor if needed + needs_invalidate = False + + # Always set mesh_shape and partition_spec if provided + if mesh_shape is not None: + t.mesh_shape = mesh_shape + needs_invalidate = True + + if partition_spec is not None: + t.partition_spec = partition_spec + needs_invalidate = True + + # Invalidate cached spec if resharding occurred + if needs_invalidate: + t.invalidate_spec_cache() + return t
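
Usage sketch: the interop exercised by the new tests, in minimal form. This is illustrative only; the shapes, variable names, and the explicit xr.use_spmd() call are assumptions mirroring the SPMD test base rather than part of the patch, and it presumes an XLA runtime with at least one device.

    import torch
    from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

    import torch_xla
    import torch_xla.runtime as xr
    from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

    xr.use_spmd()  # SPMD mode, as enabled by the SPMD test base class

    world_size = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", torch.arange(world_size))

    # distribute_tensor on an "xla" mesh returns an XLAShardedTensor whose _spec
    # is a DTensorSpec rebuilt from the stored mesh_shape/partition_spec.
    dt = distribute_tensor(torch.randn(16, 8), mesh, [Shard(0)])
    assert dt._spec.mesh.device_type == "xla"
    assert dt._spec.placements == (Shard(0),)

    # Tensors wrapped directly need mesh_shape/partition_spec for _spec to resolve;
    # re-wrapping with a different mesh_shape or partition_spec invalidates the cache.
    st = wrap_as_sharded_tensor(
        torch.randn(16, 8).to('xla'),
        mesh_shape=(world_size,),
        partition_spec=(0, None))
    spec = st._spec  # built on first access, cached until resharding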