|
| 1 | +import os |
| 2 | +import sys |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor |
| 6 | +from torch.distributed.tensor.placement_types import Replicate |
| 7 | + |
| 8 | +import torch_xla |
| 9 | +import torch_xla.runtime as xr |
| 10 | +from torch_xla.distributed.spmd import XLAShardedTensor |
| 11 | +from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor |
| 12 | + |
| 13 | +import unittest |
| 14 | +import test_xla_sharding_base |
| 15 | + |
| 16 | + |
| 17 | +class XLADTensorSpecConversionTest(test_xla_sharding_base.XlaShardingTest): |
| 18 | + |
| 19 | + @classmethod |
| 20 | + def setUpClass(cls): |
| 21 | + super().setUpClass() |
| 22 | + |
| 23 | + def test_sample_test_case(self): |
| 24 | + world_size = xr.global_runtime_device_count() |
| 25 | + mesh = DeviceMesh("xla", torch.arange(world_size)) |
| 26 | + big_tensor = torch.randn(100000, 88) |
| 27 | + my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)]) |
| 28 | + |
| 29 | + assert my_dtensor._spec.mesh.device_type == mesh.device_type |
| 30 | + assert my_dtensor._spec.placements == (Shard(0),) |
| 31 | + |
| 32 | + def test_xla_to_dtensor_spec_conversion(self): |
| 33 | + device_count = xr.global_runtime_device_count() |
| 34 | + mesh = DeviceMesh("xla", list(range(device_count))) |
| 35 | + |
| 36 | + # Test different sharding patterns |
| 37 | + test_cases = [ |
| 38 | + (torch.randn(100, 50), [Shard(0)]), |
| 39 | + (torch.randn(100, 50), [Shard(1)]), |
| 40 | + (torch.randn(100, 50, 25), [Shard(0)]), |
| 41 | + (torch.randn(100, 50), [Replicate()]), |
| 42 | + ] |
| 43 | + |
| 44 | + for tensor, placements in test_cases: |
| 45 | + xla_tensor = distribute_tensor(tensor, mesh, placements) |
| 46 | + spec = xla_tensor._spec |
| 47 | + |
| 48 | + assert spec is not None |
| 49 | + assert spec.mesh.device_type == "xla" |
| 50 | + assert spec.tensor_meta.shape == tensor.shape |
| 51 | + assert spec.tensor_meta.dtype == tensor.dtype |
| 52 | + assert len(spec.placements) >= 1 |
| 53 | + assert spec.placements == tuple(placements) |
| 54 | + |
| 55 | + def test_mesh_conversion(self): |
| 56 | + device_count = xr.global_runtime_device_count() |
| 57 | + original_mesh = DeviceMesh("xla", list(range(device_count))) |
| 58 | + tensor = torch.randn(50, 50) |
| 59 | + xla_tensor = distribute_tensor(tensor, original_mesh, [Shard(0)]) |
| 60 | + |
| 61 | + converted_spec = xla_tensor._spec |
| 62 | + |
| 63 | + assert converted_spec.mesh.device_type == "xla" |
| 64 | + assert converted_spec.mesh.size() == device_count |
| 65 | + # assert on mesh dimensions |
| 66 | + assert converted_spec.mesh.shape == original_mesh.shape |
| 67 | + |
| 68 | + def test_spec_caching(self): |
| 69 | + """Test that _spec property caches results |
| 70 | + """ |
| 71 | + device_count = xr.global_runtime_device_count() |
| 72 | + mesh = DeviceMesh("xla", list(range(device_count))) |
| 73 | + tensor = torch.randn(100, 100) |
| 74 | + xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) |
| 75 | + |
| 76 | + spec1 = xla_tensor._spec |
| 77 | + |
| 78 | + assert xla_tensor._cached_spec is not None |
| 79 | + assert xla_tensor._cached_spec is spec1 |
| 80 | + |
| 81 | + spec2 = xla_tensor._spec |
| 82 | + assert spec1 is spec2 |
| 83 | + |
| 84 | + def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements): |
| 85 | + """Helper to create tensor and mesh for testing""" |
| 86 | + device_count = xr.global_runtime_device_count() |
| 87 | + if device_count < max(mesh_shape): |
| 88 | + self.skipTest( |
| 89 | + f"Need at least {max(mesh_shape)} devices, got {device_count}") |
| 90 | + |
| 91 | + mesh = DeviceMesh("xla", torch.arange(device_count).reshape(mesh_shape)) |
| 92 | + tensor = torch.randn(*tensor_shape) |
| 93 | + return distribute_tensor(tensor, mesh, placements), mesh |
| 94 | + |
| 95 | + def test_multi_dim_sharding_spec(self): |
| 96 | + """Test _spec for multi-dimensional sharding""" |
| 97 | + device_count = xr.global_runtime_device_count() |
| 98 | + if device_count < 4: |
| 99 | + self.skipTest("Need at least 4 devices for 2D mesh") |
| 100 | + |
| 101 | + mesh_shape = (2, device_count // 2) |
| 102 | + xla_tensor, mesh = self._create_test_tensor_and_mesh( |
| 103 | + (100, 50), mesh_shape, [Shard(0), Shard(1)]) |
| 104 | + spec = xla_tensor._spec |
| 105 | + |
| 106 | + assert len(spec.placements) == 2 |
| 107 | + assert spec.mesh.ndim == 2 |
| 108 | + |
| 109 | + def test_mixed_placement_spec(self): |
| 110 | + """Test _spec for tensors with mixed shard/replicate placements""" |
| 111 | + device_count = xr.global_runtime_device_count() |
| 112 | + if device_count < 4: |
| 113 | + self.skipTest("Need at least 4 devices for 2D mesh") |
| 114 | + |
| 115 | + mesh_shape = (2, device_count // 2) |
| 116 | + xla_tensor, mesh = self._create_test_tensor_and_mesh( |
| 117 | + (100, 50), mesh_shape, [Shard(0), Replicate()]) |
| 118 | + spec = xla_tensor._spec |
| 119 | + |
| 120 | + assert len(spec.placements) == 2 |
| 121 | + assert isinstance(spec.placements[0], Shard) |
| 122 | + assert isinstance(spec.placements[1], Replicate) |
| 123 | + |
| 124 | + def test_sharding_info_acquisition(self): |
| 125 | + """Test that non-XLAShardedTensor can acquire sharding information |
| 126 | +
|
| 127 | + Tests case of 'elem is not an XLAShardedTensor but there exists |
| 128 | + sharding information we want to acquire' |
| 129 | + """ |
| 130 | + |
| 131 | + device_count = xr.global_runtime_device_count() |
| 132 | + mesh_shape = (device_count,) |
| 133 | + partition_spec = (0, None) |
| 134 | + |
| 135 | + regular_tensor = torch.randn(100, 50).to('xla') |
| 136 | + |
| 137 | + sharded_tensor = wrap_as_sharded_tensor( |
| 138 | + regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec) |
| 139 | + |
| 140 | + # Verify the tensor acquired the sharding information |
| 141 | + assert isinstance(sharded_tensor, XLAShardedTensor) |
| 142 | + assert sharded_tensor.mesh_shape == mesh_shape |
| 143 | + assert sharded_tensor.partition_spec == partition_spec |
| 144 | + |
| 145 | + def test_resharding_logic(self): |
| 146 | + """ |
| 147 | + Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t. |
| 148 | + """ |
| 149 | + |
| 150 | + device_count = xr.global_runtime_device_count() |
| 151 | + if device_count < 4: |
| 152 | + self.skipTest("Need at least 4 devices for resharding test") |
| 153 | + |
| 154 | + # Initial sharding |
| 155 | + initial_mesh_shape = (device_count,) |
| 156 | + initial_partition_spec = (0, None) |
| 157 | + new_mesh_shape = (2, device_count // 2) |
| 158 | + new_partition_spec = (0, 1) |
| 159 | + |
| 160 | + # Create tensor and verify resharding |
| 161 | + tensor = torch.randn(100, 50).to('xla') |
| 162 | + sharded_tensor = wrap_as_sharded_tensor( |
| 163 | + tensor, |
| 164 | + mesh_shape=initial_mesh_shape, |
| 165 | + partition_spec=initial_partition_spec) |
| 166 | + initial_spec = sharded_tensor._spec |
| 167 | + |
| 168 | + resharded_tensor = wrap_as_sharded_tensor( |
| 169 | + sharded_tensor, |
| 170 | + mesh_shape=new_mesh_shape, |
| 171 | + partition_spec=new_partition_spec) |
| 172 | + |
| 173 | + # Verify resharding worked and cache was invalidated |
| 174 | + assert resharded_tensor.mesh_shape == new_mesh_shape |
| 175 | + assert resharded_tensor.partition_spec == new_partition_spec |
| 176 | + assert resharded_tensor._spec is not initial_spec |
| 177 | + |
| 178 | + def test_spec_invalidation_on_resharding(self): |
| 179 | + """Tests cases where the cached spec may become outdated. |
| 180 | + """ |
| 181 | + |
| 182 | + device_count = xr.global_runtime_device_count() |
| 183 | + if device_count < 4: |
| 184 | + self.skipTest("Need at least 4 devices for resharding test") |
| 185 | + |
| 186 | + tensor = torch.randn(100, 50).to('xla') |
| 187 | + initial_mesh_shape = (device_count,) |
| 188 | + initial_partition_spec = (0, None) |
| 189 | + new_mesh_shape = (2, device_count // 2) |
| 190 | + new_partition_spec = (0, 1) |
| 191 | + |
| 192 | + sharded_tensor = wrap_as_sharded_tensor( |
| 193 | + tensor, |
| 194 | + mesh_shape=initial_mesh_shape, |
| 195 | + partition_spec=initial_partition_spec) |
| 196 | + initial_spec = sharded_tensor._spec |
| 197 | + assert sharded_tensor._cached_spec is not None |
| 198 | + |
| 199 | + # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache |
| 200 | + resharded_tensor = wrap_as_sharded_tensor( |
| 201 | + sharded_tensor, |
| 202 | + mesh_shape=new_mesh_shape, |
| 203 | + partition_spec=initial_partition_spec) |
| 204 | + assert resharded_tensor._spec is not initial_spec |
| 205 | + assert resharded_tensor._spec.mesh.shape == new_mesh_shape |
| 206 | + |
| 207 | + initial_spec = resharded_tensor._spec |
| 208 | + resharded_tensor = wrap_as_sharded_tensor( |
| 209 | + resharded_tensor, |
| 210 | + mesh_shape=new_mesh_shape, |
| 211 | + partition_spec=new_partition_spec) |
| 212 | + assert resharded_tensor._spec is not initial_spec |
| 213 | + assert resharded_tensor._spec.placements[1].dim == 1 |
| 214 | + |
| 215 | + def test_auto_wrapped_tensor_spec_failure(self): |
| 216 | + """Test that auto-wrapped tensors fail when accessing _spec property. |
| 217 | + |
| 218 | + Auto-wrapped tensors are created through operations that trigger __torch_dispatch__ |
| 219 | + but don't yet have access to the sharding propagation done through open xla, |
| 220 | + causing ._spec to fail. |
| 221 | + """ |
| 222 | + device_count = xr.global_runtime_device_count() |
| 223 | + mesh = DeviceMesh("xla", torch.arange(device_count)) |
| 224 | + tensor = torch.randn(4, 4) |
| 225 | + sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) |
| 226 | + |
| 227 | + auto_wrapped = sharded_tensor + sharded_tensor |
| 228 | + |
| 229 | + with self.assertRaises(ValueError): |
| 230 | + _ = auto_wrapped._spec |
| 231 | + |
| 232 | + |
| 233 | +if __name__ == '__main__': |
| 234 | + test = unittest.main() |
| 235 | + sys.exit(0 if test.result.wasSuccessful() else 1) |
0 commit comments