diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 3d276f2dc26..d536bf17aa9 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -42,6 +42,34 @@ class XLAShardingTest : public AtenXlaTensorTestBase { } }; +TEST_F(XLAShardingTest, NormalizeTileAssignment) { + // Test with an empty tile assignment + std::vector empty_tile_assignment = {}; + auto normalized = + ShardingUtil::NormalizeTileAssignment(empty_tile_assignment); + EXPECT_TRUE(normalized.empty()); + + // Test with positive values + std::vector positive_tile_assignment = {3, 1, 4, 2}; + normalized = ShardingUtil::NormalizeTileAssignment(positive_tile_assignment); + EXPECT_EQ(normalized, std::vector({2, 0, 3, 1})); + + // Test with all identical values + std::vector identical_tile_assignment = {5, 5, 5, 5}; + normalized = ShardingUtil::NormalizeTileAssignment(identical_tile_assignment); + EXPECT_EQ(normalized, std::vector({0, 0, 0, 0})); + + // Test with negative values + std::vector negative_tile_assignment = {-3, -1, -4, -2}; + EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(negative_tile_assignment), + std::runtime_error); + + // Test with mixed positive and negative values + std::vector mixed_tile_assignment = {3, -1, 4, 2}; + EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(mixed_tile_assignment), + std::runtime_error); +} + TEST_F(XLAShardingTest, GetShardShape) { auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); xla::Shape tensor_shape = @@ -50,7 +78,9 @@ TEST_F(XLAShardingTest, GetShardShape) { {0, 1}, {2, 3}, }); - auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + std::vector denormalized_tile_assignment = {0, 1, 2, 3}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); auto sharding_spec = std::make_shared(sharding, tensor_shape); @@ -58,7 +88,9 @@ TEST_F(XLAShardingTest, GetShardShape) { // For tiled sharding, each dimension should be halved EXPECT_EQ(shard_shape, std::vector({4, 4})); - sharding_spec->sharding = xla::HloSharding::Replicate().ToProto(); + xla_sharding = xla::HloSharding::Replicate().ToProto(); + sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); + sharding_spec->sharding = sharding; shard_shape = ShardingUtil::GetShardShape(sharding_spec); // For replicated sharding, each dimension should be preserved EXPECT_EQ(shard_shape, std::vector({8, 7})); @@ -74,7 +106,9 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { {0, 1}, {2, 3}, }); - auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + std::vector denormalized_tile_assignment = {0, 1, 2, 3}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); auto sharding_spec = std::make_shared(sharding, tensor_shape); auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); @@ -103,7 +137,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { EXPECT_EQ(slice.step(), 1); } } - sharding = xla::HloSharding::Replicate().ToProto(); + xla_sharding = xla::HloSharding::Replicate().ToProto(); + sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); sharding_spec->sharding = sharding; shard_shape = ShardingUtil::GetShardShape(sharding_spec); replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices( @@ -121,16 +156,18 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { TEST_F(XLAShardingTest, ShardTensor) { std::vector devices = 
{"TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; // 1D tiled at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat)); xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); - xla::OpSharding sharding = + xla::OpSharding xla_sharding = xla::HloSharding::Tile1D( CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()), devices.size()) .ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); auto sharding_spec = std::make_shared(sharding, tensor_shape); auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, @@ -148,7 +185,8 @@ TEST_F(XLAShardingTest, ShardTensor) { {0, 1, 2, 3}, {4, 5, 6, 7}, }); - sharding = xla::HloSharding::Tile(mesh).ToProto(); + xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); sharding_spec = std::make_shared(sharding, tensor_shape); shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, @@ -160,7 +198,9 @@ TEST_F(XLAShardingTest, ShardTensor) { // 3D tiled, the first dim is replicated and the last halved. The last shard // size should be smaller in dim=1 because it's not evenly divisible. xla::Array3D cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}); - sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto(); + xla_sharding = xla::HloSharding::Tile(cube).ToProto(); + sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); + sharding_spec->sharding = sharding; shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); @@ -168,7 +208,9 @@ TEST_F(XLAShardingTest, ShardTensor) { EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 1, 2})); // Replicated, all shards should be identical. 
- sharding_spec->sharding = xla::HloSharding::Replicate().ToProto(); + xla_sharding = xla::HloSharding::Replicate().ToProto(); + sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); + sharding_spec->sharding = sharding; shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 8); @@ -182,7 +224,8 @@ TEST_F(XLAShardingTest, ShardTensor) { tensor_shape = CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); xla::Array4D tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}}); - sharding = xla::HloSharding::Tile(tesseract).ToProto(); + xla_sharding = xla::HloSharding::Tile(tesseract).ToProto(); + sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); sharding_spec = std::make_shared(sharding, tensor_shape); shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, @@ -206,7 +249,8 @@ TEST_F(XLAShardingTest, ShardTensor) { CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); xla::Array hypercube(std::vector{1, 1, 2, 2, 2}); hypercube.FillIota(0); - sharding = xla::HloSharding::Tile(hypercube).ToProto(); + xla_sharding = xla::HloSharding::Tile(hypercube).ToProto(); + sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); sharding_spec = std::make_shared(sharding, tensor_shape); shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, @@ -234,7 +278,9 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { {4, 5, 0, 1}, {6, 7, 2, 3}, }); - auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + std::vector denormalized_tile_assignment = {4, 5, 0, 1, 6, 7, 2, 3}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); auto sharding_spec = std::make_shared(sharding, tensor_shape); // For devices at the start of the mesh, all shards should have the same @@ -251,7 +297,10 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { {0, 1, 4, 5}, {2, 3, 6, 7}, }); - sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto(); + xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + denormalized_tile_assignment = {0, 1, 4, 5, 2, 3, 6, 7}; + sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); + sharding_spec->sharding = sharding; shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 4); @@ -278,7 +327,9 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) { {{7}}, }); - auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); auto sharding_spec = std::make_shared( sharding, global_shape, /*minibatch=*/true); auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec, @@ -292,17 +343,21 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) { auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); - XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({ - {0, 1, 2, 3}, - {4, 5, 6, 7}, - }) - .ToProto(), - tensor_shape); - XLATensor::ShardingSpec tiled_3d( - xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(), - tensor_shape); - XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(), - tensor_shape); + auto xla_sharding = 
xla::HloSharding::Tile({ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + }) + .ToProto(); + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape); + xla_sharding = + xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(); + sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); + XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape); + xla_sharding = xla::HloSharding::Replicate().ToProto(); + sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); + XLATensor::ShardingSpec replicated(sharding, tensor_shape); EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d)); EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d)); EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated)); @@ -323,12 +378,17 @@ TEST_F(XLAShardingTest, CreateTensorsData) { std::vector devices(3); std::fill_n(devices.begin(), devices.size(), bridge::GetDefaultDevice()->toString()); + auto replicate_xla_sharding = xla::HloSharding::Replicate().ToProto(); + auto unknown_xla_sharding = xla::HloSharding::Unknown().ToProto(); + torch_xla::OpSharding replicate_sharding(replicate_xla_sharding, + std::nullopt); + torch_xla::OpSharding unknown_sharding(unknown_xla_sharding, std::nullopt); std::vector shardings = { nullptr, - std::make_shared( - xla::HloSharding::Replicate().ToProto(), tensor_shape), - std::make_shared( - xla::HloSharding::Unknown().ToProto(), tensor_shape)}; + std::make_shared(replicate_sharding, + tensor_shape), + std::make_shared(unknown_sharding, + tensor_shape)}; std::vector tensors_data = CreateTensorsData(tensors, shardings, devices); @@ -387,13 +447,29 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { auto y = xla::Add(x, xla::ConstantR0(&b, 3)); xla::XlaComputation xla_computation = GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false)); + + std::vector tensors{XLATensor::Create( + torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder( + bridge::GetDefaultDevice()->toString(), std::move(shape)))}; + std::vector> denormalized_tile_assignments; + for (auto tensor : tensors) { + auto sharding_spec = tensor->sharding_spec(); + if (sharding_spec) { + denormalized_tile_assignments.push_back( + sharding_spec->sharding.GetDenormalizedTileAssignment()); + } + } + std::vector instances; - instances.push_back({std::move(xla_computation), - bridge::GetDefaultDevice()->toString(), - {bridge::GetDefaultDevice()->toString()}, - &shape, - /*should_wrap_parameter=*/false, - /*is_sharded=*/true}); + instances.push_back( + {std::move(xla_computation), + bridge::GetDefaultDevice()->toString(), + {bridge::GetDefaultDevice()->toString()}, + &shape, + /*should_wrap_parameter=*/false, + /*is_sharded=*/true, + /*allow_spmd_sharding_propagation_to_output=*/true, + /*denormalized_tile_assignments=*/denormalized_tile_assignments}); std::vector< std::shared_ptr> @@ -404,9 +480,6 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { "add", std::move(computations[0]->move_computation())); // Prepare output sharding propagation, expect a sharded output placeholder. 
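// Note (annotation, inferred from this hunk): the tensor placeholder creation that
// previously lived here (removed below) now happens before Compile(), so that each
// tensor's denormalized tile assignment can be collected and forwarded to the
// CompileInstance via the new denormalized_tile_assignments argument.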
- std::vector tensors{XLATensor::Create( - torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder( - bridge::GetDefaultDevice()->toString(), std::move(shape)))}; std::vector data_placeholders; std::vector sharding_specs; ShardingUtil::PrepareOutputShardingPropagation( @@ -417,11 +490,12 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { if (n_devices > 1) { // Tiled sharding requires multiple devices. EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization( - tiled, sharding_specs[0]->sharding)); + tiled, sharding_specs[0]->sharding.GetXlaOpSharding())); } else { // Sincle device execution defaults to replication sharding. EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization( - xla::HloSharding::Replicate().ToProto(), sharding_specs[0]->sharding)); + xla::HloSharding::Replicate().ToProto(), + sharding_specs[0]->sharding.GetXlaOpSharding())); } // Check if the placeholder is on a SPMD device (sharded) with no real values. diff --git a/test/run_tests.sh b/test/run_tests.sh index b2cc8f751d2..177c186eead 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -261,6 +261,8 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py" "$@" --skip-gradient-checkpointing run_test "$_TEST_DIR/test_gradient_accumulation.py" run_save_tensor_hlo run_test "$_TEST_DIR/spmd/test_spmd_lowering_context.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_submesh_zero_indexed.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_submesh_non_zero_indexed.py" run_test "$_TEST_DIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY run_test "$_TEST_DIR/test_input_output_aliases.py" run_test_without_functionalization "$_TEST_DIR/test_input_output_aliases.py" diff --git a/test/spmd/test_submesh_non_zero_indexed.py b/test/spmd/test_submesh_non_zero_indexed.py new file mode 100644 index 00000000000..d9117bfb186 --- /dev/null +++ b/test/spmd/test_submesh_non_zero_indexed.py @@ -0,0 +1,385 @@ +import unittest + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr +import torch_xla.distributed.spmd as xs +from torch_xla.distributed.spmd import Mesh +import torch_xla.core.xla_env_vars as xenv +import torch_xla.utils.utils as xu + +import test_xla_sharding_base + + +class SubmeshNonZeroIndexedTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def _create_non_zero_indexed_submesh_2dev(self): + """Create 2-device submesh starting from non-zero index with no overlap""" + if self.n_devices >= 4: + device_ids = [2, 3] # Use devices 2-3 when we have 4+ devices + else: + raise unittest.SkipTest( + f"Need at least 4 devices for non-overlapping 2-device submesh, got {self.n_devices}" + ) + mesh_shape = (1, 2) + axis_names = ('x', 'y') + return Mesh(device_ids, mesh_shape, axis_names) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for non-zero-indexed 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern1_2dev(self): + """shard both tensors -> compute -> cpu() -> sync()""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + t1 = torch.randn(4, 4, device='cpu') + t2 = torch.randn(4, 4, device='cpu') + expected = torch.matmul(t1, t2) + + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + 
result_cpu = result.cpu() + torch_xla.sync() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for non-zero-indexed 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern1_2dev_direct_device(self): + """direct device creation: shard both tensors -> compute -> cpu() -> sync()""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu(), xt2.cpu()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + result_cpu = result.cpu() + torch_xla.sync() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for non-zero-indexed 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern2_2dev_direct_device(self): + """direct device creation: shard both tensors -> compute -> sync() -> cpu()""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu(), xt2.cpu()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for non-zero-indexed 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern3_2dev_direct_device(self): + """direct device creation: shard one tensor -> compute -> cpu() -> sync()""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu(), xt2.cpu()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + result_cpu = result.cpu() + torch_xla.sync() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for non-zero-indexed 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern4_2dev_direct_device(self): + """direct device creation: shard one tensor -> compute -> sync() -> cpu()""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu(), xt2.cpu()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() 
>= 4, + "Need at least 4 devices for non-zero-indexed 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern5_2dev_direct_device(self): + """direct device creation: modify tensor -> shard one tensor -> compute -> sync() -> cpu()""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu() + 2, xt2.cpu()) + + xt1 += 2 + xs.mark_sharding(xt1, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for non-zero-indexed 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern6_2dev_direct_device(self): + """direct device creation: modify tensor -> shard both tensors -> compute -> sync() -> cpu()""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu() + 2, xt2.cpu()) + + xt1 += 2 + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for non-zero-indexed 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern7_2dev_direct_device(self): + """direct device creation: modify tensor -> shard one tensor -> compute -> cpu() -> sync()""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu() + 2, xt2.cpu()) + + xt1 += 2 + xs.mark_sharding(xt1, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + result_cpu = result.cpu() + torch_xla.sync() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for non-zero-indexed 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_single_tensor_addition_2dev_direct_device(self): + """Single tensor addition test with direct device creation and non-zero-indexed 2-device submesh""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + xt = torch.randn(4, 4, device=torch_xla.device()) + expected = xt.cpu() + 4.2 + + xs.mark_sharding(xt, mesh, ('x', None)) + + result = xt + 4.2 + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for advanced direct device tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_complex_operations_direct_device(self): + 
"""Test complex tensor operations with direct device creation and non-zero-indexed submesh""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + # Create tensors directly on device + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + xt3 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + t1_cpu = xt1.cpu() + t2_cpu = xt2.cpu() + t3_cpu = xt3.cpu() + expected = torch.matmul(t1_cpu + t2_cpu, t3_cpu) + + # Apply sharding + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + xs.mark_sharding(xt3, mesh, ('x', 'y')) + + # Perform complex operations + result = torch.matmul(xt1 + xt2, xt3) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for advanced direct device tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_inplace_operations_direct_device(self): + """Test in-place operations with direct device creation and non-zero-indexed submesh""" + mesh = self._create_non_zero_indexed_submesh_2dev() + + # Create tensors directly on device + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Store original values for expected calculation + t1_orig = xt1.cpu().clone() + t2_orig = xt2.cpu().clone() + + # Apply sharding + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + # Perform in-place operations + xt1 *= 2.0 + xt1 += xt2 + + # Calculate expected result + expected = t1_orig * 2.0 + t2_orig + + torch_xla.sync() + result_cpu = xt1.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + # Error validation test cases for xs.Mesh constructor + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for error validation tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", + "Error validation tests require CPU") + def test_mesh_axis_names_length_mismatch(self): + """Test error when axis names length doesn't match mesh dimensions""" + device_ids = [2, 3] # Non-zero indexed devices + mesh_shape = (1, 2) # 2 dimensions + axis_names = ('data', 'model', 'extra') # 3 names - mismatch! + + with self.assertRaisesRegex( + AssertionError, "Number of axis names .* must match mesh dimensions"): + Mesh(device_ids, mesh_shape, axis_names) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for error validation tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", + "Error validation tests require CPU") + def test_mesh_duplicate_axis_names(self): + """Test error when axis names are not unique""" + device_ids = [2, 3] # Non-zero indexed devices + mesh_shape = (1, 2) + axis_names = ('data', 'data') # Duplicate names! 
+ + with self.assertRaisesRegex(AssertionError, "Axis names must be unique"): + Mesh(device_ids, mesh_shape, axis_names) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for error validation tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", + "Error validation tests require CPU") + def test_mesh_device_count_mismatch(self): + """Test error when device IDs count doesn't match mesh size""" + device_ids = [2, 3] # 2 devices + mesh_shape = (2, 2) # mesh size = 4, but only 2 device IDs! + + with self.assertRaisesRegex(AssertionError, + "Number of device IDs .* must match mesh size"): + Mesh(device_ids, mesh_shape) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for error validation tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", + "Error validation tests require CPU") + def test_mesh_duplicate_device_ids(self): + """Test error when device IDs are not unique""" + device_ids = [2, 2] # Duplicate device IDs! + mesh_shape = (1, 2) + + with self.assertRaisesRegex(AssertionError, "Device IDs must be unique"): + Mesh(device_ids, mesh_shape) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for error validation tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", + "Error validation tests require CPU") + def test_mesh_device_ids_out_of_bounds(self): + """Test error when device IDs are outside addressable device range""" + # Assuming we have 4 devices (0,1,2,3), use invalid IDs like 10,11 + device_ids = [10, 11] # Out of bounds device IDs! + mesh_shape = (1, 2) + + with self.assertRaisesRegex( + AssertionError, + "Device IDs has to be subset of addressable_devices; got:*."): + Mesh(device_ids, mesh_shape) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/spmd/test_submesh_zero_indexed.py b/test/spmd/test_submesh_zero_indexed.py new file mode 100644 index 00000000000..07e850c9df5 --- /dev/null +++ b/test/spmd/test_submesh_zero_indexed.py @@ -0,0 +1,440 @@ +import unittest + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr +import torch_xla.distributed.spmd as xs +from torch_xla.distributed.spmd import Mesh +import torch_xla.core.xla_env_vars as xenv +import torch_xla.utils.utils as xu +import sys + +import test_xla_sharding_base + + +class SubmeshZeroIndexedTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def _create_zero_indexed_submesh_2dev(self): + """Create 2-device submesh starting from device 0: [0,1]""" + device_ids = [0, 1] + mesh_shape = (1, 2) + axis_names = ('x', 'y') + return Mesh(device_ids, mesh_shape, axis_names) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern1_2dev(self): + """shard both tensors -> compute -> cpu() -> sync()""" + mesh = self._create_zero_indexed_submesh_2dev() + + t1 = torch.randn(4, 4, device='cpu') + t2 = torch.randn(4, 4, device='cpu') + expected = torch.matmul(t1, t2) + + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + result_cpu = result.cpu() + torch_xla.sync() + + 
self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern1_2dev_direct_device(self): + """direct device creation: shard both tensors -> compute -> cpu() -> sync()""" + mesh = self._create_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu(), xt2.cpu()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + result_cpu = result.cpu() + torch_xla.sync() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern2_2dev(self): + """shard both tensors -> compute -> sync() -> cpu()""" + mesh = self._create_zero_indexed_submesh_2dev() + + t1 = torch.randn(4, 4, device='cpu') + t2 = torch.randn(4, 4, device='cpu') + expected = torch.matmul(t1, t2) + + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern2_2dev_direct_device(self): + """direct device creation: shard both tensors -> compute -> sync() -> cpu()""" + mesh = self._create_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu(), xt2.cpu()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern3_2dev(self): + """shard one tensor -> compute -> cpu() -> sync()""" + mesh = self._create_zero_indexed_submesh_2dev() + + t1 = torch.randn(4, 4, device='cpu') + t2 = torch.randn(4, 4, device='cpu') + expected = torch.matmul(t1, t2) + + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + result_cpu = result.cpu() + torch_xla.sync() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern3_2dev_direct_device(self): + """direct 
device creation: shard one tensor -> compute -> cpu() -> sync()""" + mesh = self._create_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu(), xt2.cpu()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + result_cpu = result.cpu() + torch_xla.sync() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern4_2dev_direct_device(self): + """direct device creation: shard one tensor -> compute -> sync() -> cpu()""" + mesh = self._create_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu(), xt2.cpu()) + + xs.mark_sharding(xt1, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern5_2dev_direct_device(self): + """direct device creation: modify tensor -> shard one tensor -> compute -> sync() -> cpu()""" + mesh = self._create_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu() + 2, xt2.cpu()) + + xt1 += 2 + xs.mark_sharding(xt1, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern6_2dev_direct_device(self): + """direct device creation: modify tensor -> shard both tensors -> compute -> sync() -> cpu()""" + mesh = self._create_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu() + 2, xt2.cpu()) + + xt1 += 2 + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_pattern7_2dev_direct_device(self): + """direct device creation: modify tensor -> shard one tensor -> compute -> cpu() -> sync()""" + mesh = self._create_zero_indexed_submesh_2dev() + + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # 
Create expected result on CPU for comparison + expected = torch.matmul(xt1.cpu() + 2, xt2.cpu()) + + xt1 += 2 + xs.mark_sharding(xt1, mesh, ('x', 'y')) + + result = torch.matmul(xt1, xt2) + result_cpu = result.cpu() + torch_xla.sync() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_single_tensor_addition_2dev(self): + """Single tensor addition test with zero-indexed 2-device submesh""" + mesh = self._create_zero_indexed_submesh_2dev() + + t = torch.randn(4, 4, device='cpu') + expected = t + 4.2 + + xt = t.to(torch_xla.device()) + xs.mark_sharding(xt, mesh, ('x', None)) + + result = xt + 4.2 + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for 2-device submesh tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_single_tensor_addition_2dev_direct_device(self): + """Single tensor addition test with direct device creation and zero-indexed 2-device submesh""" + mesh = self._create_zero_indexed_submesh_2dev() + + xt = torch.randn(4, 4, device=torch_xla.device()) + expected = xt.cpu() + 4.2 + + xs.mark_sharding(xt, mesh, ('x', None)) + + result = xt + 4.2 + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for advanced direct device tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_complex_operations_direct_device(self): + """Test complex tensor operations with direct device creation""" + mesh = self._create_zero_indexed_submesh_2dev() + + # Create tensors directly on device + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + xt3 = torch.randn(4, 4, device=torch_xla.device()) + + # Create expected result on CPU for comparison + t1_cpu = xt1.cpu() + t2_cpu = xt2.cpu() + t3_cpu = xt3.cpu() + expected = torch.matmul(t1_cpu + t2_cpu, t3_cpu) + + # Apply sharding + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + xs.mark_sharding(xt3, mesh, ('x', 'y')) + + # Perform complex operations + result = torch.matmul(xt1 + xt2, xt3) + torch_xla.sync() + result_cpu = result.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + @unittest.skipUnless( + xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for advanced direct device tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", "Submesh tests require CPU") + def test_inplace_operations_direct_device(self): + """Test in-place operations with direct device creation""" + mesh = self._create_zero_indexed_submesh_2dev() + + # Create tensors directly on device + xt1 = torch.randn(4, 4, device=torch_xla.device()) + xt2 = torch.randn(4, 4, device=torch_xla.device()) + + # Store original values for expected calculation + t1_orig = xt1.cpu().clone() + t2_orig = xt2.cpu().clone() + + # Apply sharding + xs.mark_sharding(xt1, mesh, ('x', 'y')) + xs.mark_sharding(xt2, mesh, ('x', 'y')) + + # Perform in-place operations + xt1 
*= 2.0 + xt1 += xt2 + + # Calculate expected result + expected = t1_orig * 2.0 + t2_orig + + torch_xla.sync() + result_cpu = xt1.cpu() + + self.assertTrue(torch.allclose(expected, result_cpu, atol=1e-5)) + + # Error validation test cases for xs.Mesh constructor + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for error validation tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", + "Error validation tests require CPU") + def test_mesh_axis_names_length_mismatch(self): + """Test error when axis names length doesn't match mesh dimensions""" + device_ids = [0, 1] + mesh_shape = (1, 2) # 2 dimensions + axis_names = ('data', 'model', 'extra') # 3 names - mismatch! + + with self.assertRaisesRegex( + AssertionError, "Number of axis names .* must match mesh dimensions"): + Mesh(device_ids, mesh_shape, axis_names) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for error validation tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", + "Error validation tests require CPU") + def test_mesh_duplicate_axis_names(self): + """Test error when axis names are not unique""" + device_ids = [0, 1] + mesh_shape = (1, 2) + axis_names = ('data', 'data') # Duplicate names! + + with self.assertRaisesRegex(AssertionError, "Axis names must be unique"): + Mesh(device_ids, mesh_shape, axis_names) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for error validation tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", + "Error validation tests require CPU") + def test_mesh_device_count_mismatch(self): + """Test error when device IDs count doesn't match mesh size""" + device_ids = [0, 1] # 2 devices + mesh_shape = (2, 2) # mesh size = 4, but only 2 device IDs! + + with self.assertRaisesRegex(AssertionError, + "Number of device IDs .* must match mesh size"): + Mesh(device_ids, mesh_shape) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for error validation tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", + "Error validation tests require CPU") + def test_mesh_duplicate_device_ids(self): + """Test error when device IDs are not unique""" + device_ids = [0, 0] # Duplicate device IDs! + mesh_shape = (1, 2) + + with self.assertRaisesRegex(AssertionError, "Device IDs must be unique"): + Mesh(device_ids, mesh_shape) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Need at least 4 devices for error validation tests") + @unittest.skipUnless( + xu.getenv_as(xenv.PJRT_DEVICE, str) == "CPU", + "Error validation tests require CPU") + def test_mesh_device_ids_out_of_bounds(self): + """Test error when device IDs are outside addressable device range""" + # Assuming we have 4 devices, use invalid IDs like 10,11 + device_ids = [10, 11] # Out of bounds device IDs! 
+ mesh_shape = (1, 2) + + with self.assertRaisesRegex( + AssertionError, + "Device IDs has to be subset of addressable_devices; got:*."): + Mesh(device_ids, mesh_shape) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 48b760f6e3f..0761a79bc0f 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -761,13 +761,15 @@ def test_hybrid_mesh_shape(self): @unittest.skipIf(xr.device_type() == 'TPU' and tpu.version() < 3, "Crash on TPU v2") + @patch('torch_xla.runtime.addressable_runtime_device_count') @patch('torch_xla.runtime.global_runtime_device_attributes') @patch('torch_xla.core.xla_model.xla_device_hw') @patch('torch_xla.runtime.global_runtime_device_count') def test_hybrid_mesh(self, device_count_mock, xla_device_mock, - device_attributes_mock): + device_attributes_mock, addressable_device_count_mock): # mock device attributes for 2 slices of v4-8 num_slices = 2 + addressable_device_count_mock.return_value = 8 device_count_mock.return_value = 8 xla_device_mock.return_value = "TPU" device_attributes_mock.return_value = [{ @@ -1620,8 +1622,9 @@ def test_device_ids_out_of_bounds(self): mesh_shape = (1, self.n_devices) invalid_ids = np.arange(self.n_devices + 1, self.n_devices * 2 + 1) - with self.assertRaisesRegex(AssertionError, - "Device IDs must be less than mesh size"): + with self.assertRaisesRegex( + AssertionError, + "Device IDs has to be subset of addressable_devices; got:*."): xs.Mesh(invalid_ids, mesh_shape) def test_mesh_size(self): @@ -1645,7 +1648,8 @@ def test_mismatch_global_devices(self): mesh_shape = (1, partial_num_devices) with self.assertRaisesRegex( AssertionError, - "Number of device IDs .* must match the global number of devices"): + "Number of device IDs .* must be less than the global number of devices" + ): xs.Mesh(device_ids, mesh_shape) @unittest.skipIf(xr.global_runtime_device_count() == 1, diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 6c34eca1450..2a521bce9d2 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -126,6 +126,7 @@ ptxla_cc_library( ":shape_builder", ":shape_helper", ":status", + ":torch_xla_op_sharding", ":version", "//torch_xla/csrc:hash_util", "//torch_xla/csrc:thread_pool", @@ -313,6 +314,7 @@ ptxla_cc_library( ":shape_helper", ":status", ":unwrap_data", + ":torch_xla_op_sharding", "//torch_xla/csrc/runtime:cache", "//torch_xla/csrc/runtime:computation_client", "@com_google_absl//absl/log:absl_check", @@ -381,3 +383,13 @@ cc_library( "@com_google_absl//absl/status:statusor", ], ) + +cc_library( + name = "torch_xla_op_sharding", + srcs = ["torch_xla_op_sharding.cpp"], + hdrs = ["torch_xla_op_sharding.h"], + deps = [ + "//torch_xla/csrc/runtime:debug_macros", + "@xla//xla/hlo/builder:xla_builder", + ], +) diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 812f7122efa..49e56e9164a 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -21,6 +21,7 @@ #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/status.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "torch_xla/csrc/xla_graph_executor.h" namespace torch_xla { @@ -218,7 +219,8 @@ void DebugUtil::SaveOutputShardingInfo(std::vector* tensors, auto xtensor = (*tensors)[indices[i]]; ss << xtensor->shape().get().ToString() << " "; if (xtensor->sharding_spec()) { - ss << 
xla::HloSharding::FromProto(xtensor->sharding_spec()->sharding) + ss << xla::HloSharding::FromProto( + xtensor->sharding_spec()->sharding.GetXlaOpSharding()) ->ToString(); } else { ss << xla::HloSharding::FromProto(xla::HloSharding::Unknown().ToProto()) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 5b62d95efd5..0309718a1aa 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -69,6 +69,7 @@ #include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "torch_xla/csrc/version.h" #include "torch_xla/csrc/xla_backend_impl.h" #include "torch_xla/csrc/xla_graph_executor.h" @@ -706,7 +707,8 @@ std::string GetTensorsHloGraph(const std::vector& tensors, std::string GetXLAShardingSpec(const XLATensorPtr xtensor) { auto sharding_spec = xtensor->sharding_spec(); if (sharding_spec != nullptr) { - auto hlo_sharding = xla::HloSharding::FromProto(sharding_spec->sharding); + auto hlo_sharding = + xla::HloSharding::FromProto(sharding_spec->sharding.GetXlaOpSharding()); return hlo_sharding->ToString(); } return std::string(); @@ -1503,7 +1505,7 @@ void InitXlaModuleBindings(py::module m) { runtime::ComputationClient::ComputationPtr>(m, "XlaComputation"); // Define the _XLAC.OpSharding class. - PythonScope>(m, "OpSharding") + PythonScope>(m, "OpSharding") .def_init([](const py::list& tile_assignment, const py::list& group_assignment, const py::list& replication_groups, int sharding_type) { @@ -2504,16 +2506,16 @@ void InitXlaModuleBindings(py::module m) { } }) .def("_xla_mark_sharding", - [](const at::Tensor& input, xla::OpSharding sharding) { + [](const at::Tensor& input, torch_xla::OpSharding sharding) { ShardingUtil::XlaMarkSharding(input, sharding); }) .def("_xla_annotate_custom_sharding", - [](const at::Tensor& input, xla::OpSharding sharding) { + [](const at::Tensor& input, torch_xla::OpSharding sharding) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); ShardingUtil::XlaAnnotateCustomSharding(xtensor, sharding); }) .def("_mark_manual_sharding", - [](const at::Tensor& input, xla::OpSharding sharding) { + [](const at::Tensor& input, torch_xla::OpSharding sharding) { XLA_CHECK(IsNonDeviceDataIR(input)) << "Marking any data tensors as manual is not supported"; ShardingUtil::XlaMarkSharding(input, sharding); @@ -2533,13 +2535,14 @@ void InitXlaModuleBindings(py::module m) { xtensor->CreateFrom(torch_xla::MakeNode( xtensor->GetIrValue(), shard_shape, CustomSharding::Type::kSPMDFullToShardShape)); - output->SetShardingSpec(XLATensor::ShardingSpec( - xla::HloSharding::Manual().ToProto(), shard_shape)); + torch_xla::OpSharding sharding(xla::HloSharding::Manual().ToProto(), + sharding_spec->sharding.GetDenormalizedTileAssignment()); + output->SetShardingSpec(XLATensor::ShardingSpec(sharding, shard_shape)); return bridge::AtenFromXlaTensor(output); }) .def( "_spmd_shard_to_full_shape", - [](const at::Tensor& input, const xla::OpSharding& sharding, + [](const at::Tensor& input, const torch_xla::OpSharding& sharding, const std::vector& output_shape, const py::object& output_dtype) -> at::Tensor { XLATensorPtr xtensor = bridge::GetXlaTensor(input); @@ -2573,7 +2576,7 @@ void InitXlaModuleBindings(py::module m) { return GetXLAShardingSpec(xtensor); }) .def("_get_xla_op_sharding", - [](const at::Tensor& input) -> std::optional { + [](const at::Tensor& input) -> std::optional { XLATensorPtr 
xtensor = bridge::GetXlaTensor(input); XLATensor::ShardingSpecPtr sharding_spec = xtensor ? xtensor->sharding_spec() : nullptr; @@ -2613,7 +2616,7 @@ void InitXlaModuleBindings(py::module m) { // `torch_xla.runtime.local_runtime_devices()`. "_global_tensor_from_cpu_shards", [](const std::vector& shards, - const xla::OpSharding& sharding, + const torch_xla::OpSharding& sharding, std::optional>& global_shape) -> at::Tensor { XLA_CHECK(UseVirtualDevice()) << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 2e4338f50b7..8fa0ff2fc72 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -13,6 +13,7 @@ #include "torch_xla/csrc/runtime/cache.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" namespace torch_xla { namespace { @@ -167,12 +168,12 @@ torch::lazy::hash_t XlaNode::GetOpHash(torch::lazy::OpKind op, return torch::lazy::HashCombine(h, hash_seed); } -void XlaNode::SetSharding(const xla::OpSharding& sharding, size_t index) { +void XlaNode::SetSharding(const torch_xla::OpSharding& sharding, size_t index) { if (output_shardings_.size() == 0) { - output_shardings_ = - std::vector>(num_outputs(), nullptr); + output_shardings_ = std::vector>( + num_outputs(), nullptr); } - output_shardings_[index] = std::make_shared(sharding); + output_shardings_[index] = std::make_shared(sharding); // TODO(JackCaoG): fix this hashing UpdateShardingHash(); } @@ -207,7 +208,10 @@ void XlaNode::UpdateShardingHash() { for (size_t i = 0; i < output_shardings_.size(); i++) { // keep the index as part of the hash sharding_hash_ = torch::lazy::HashCombine(sharding_hash_, (uint32_t)i); - std::shared_ptr sharding = output_shardings_[i]; + std::shared_ptr sharding = + std::make_shared( + output_shardings_[i]->GetXlaOpSharding(), + output_shardings_[i]->GetDenormalizedTileAssignment()); // skip the hash compute for empty sharding if (!sharding) { continue; diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index c2595a433a9..d5c96dc1057 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -21,6 +21,7 @@ #include "absl/types/span.h" #include "torch_xla/csrc/dynamic_shape_detector.h" #include "torch_xla/csrc/runtime/types.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "xla/hlo/builder/xla_builder.h" namespace torch_xla { @@ -133,14 +134,19 @@ class XlaNode : public torch::lazy::Node { torch::lazy::hash_t shardingHash() const { return sharding_hash_; } // The node's outputs get assigned the same HLO sharding - const std::shared_ptr GetSharding(size_t index) const { + const std::shared_ptr GetSharding(size_t index) const { if (output_shardings_.size() == 0) { return nullptr; } return output_shardings_[index]; } - void SetSharding(const xla::OpSharding& sharding, size_t index); + const std::vector> GetShardings() + const { + return output_shardings_; + } + + void SetSharding(const torch_xla::OpSharding& sharding, size_t index); void ClearSharding() { output_shardings_.clear(); @@ -180,7 +186,7 @@ class XlaNode : public torch::lazy::Node { torch::lazy::hash_t sharding_hash_ = 0; // Experimental sharding annotations attached to the IR node. 
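// Note (annotation, inferred from this hunk): the per-output shardings below now use the
// torch_xla::OpSharding wrapper, which carries the denormalized tile assignment alongside
// the underlying xla::OpSharding proto (see GetXlaOpSharding/GetDenormalizedTileAssignment).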
- std::vector> output_shardings_; + std::vector> output_shardings_; }; inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) { diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 65f438643f1..c79ab727d3c 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -19,6 +19,7 @@ #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/stack_frame_index_builder.h" #include "torch_xla/csrc/status.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" namespace torch_xla { @@ -133,9 +134,9 @@ xla::XlaOp LoweringContext::GetParameter( const std::string param_name = absl::StrCat("p", param_index); xla::XlaOp param; if (data->HasSharding()) { - const xla::OpSharding sharding = data->GetSharding(); - const xla::XlaScopedShardingAssignment scoped_sharding(builder(), - sharding); + const torch_xla::OpSharding sharding = data->GetSharding(); + const xla::XlaScopedShardingAssignment scoped_sharding( + builder(), sharding.GetXlaOpSharding()); param = xla::Parameter(builder(), param_index, shape, param_name); } else { param = xla::Parameter(builder(), param_index, shape, param_name); @@ -237,6 +238,18 @@ xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) { return it->second; } +void LoweringContext::ExtractShardingAndSetDenormalizedTileAssignments( + std::vector> shardings) { + for (auto sharding : shardings) { + std::vector denormalized_tile_assignment = + sharding->GetDenormalizedTileAssignment(); + if (!denormalized_tile_assignment.empty()) { + denormalized_tile_assignments_.push_back( + sharding->GetDenormalizedTileAssignment()); + } + } +} + XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node& node) { XlaOpVector result_ops; try { @@ -244,6 +257,12 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node& node) { const XlaNode* const casted = dynamic_cast(&node); result_ops = casted->Lower(this); + auto shardings = casted->GetShardings(); + if (!shardings.empty()) { + ExtractShardingAndSetDenormalizedTileAssignments(shardings); + } + // save the denormalized_tile_assignment from all nodes and then use it + // during Compile if (!casted->dynamic_dims().empty()) { const xla::internal::XlaBuilderFriend builder_friend; auto* const inst = builder_friend.GetInstruction(result_ops[0]); diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index f0545534155..136af6c599a 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -117,6 +117,14 @@ class LoweringContext : public torch::lazy::LoweringContext { int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source, int64_t parent_id); + void ExtractShardingAndSetDenormalizedTileAssignments( + std::vector>); + + const std::vector>& GetDenormalizedTileAssignments() + const { + return denormalized_tile_assignments_; + } + private: struct Parameter { xla::XlaOp param; @@ -135,6 +143,7 @@ class LoweringContext : public torch::lazy::LoweringContext { std::string name_; std::shared_ptr stack_frame_index_builder_; + std::vector> denormalized_tile_assignments_; }; // namespace torch_xla } // namespace torch_xla diff --git a/torch_xla/csrc/ops/device_data.cpp b/torch_xla/csrc/ops/device_data.cpp index a5f5536b5b6..93fb4901920 100644 --- a/torch_xla/csrc/ops/device_data.cpp +++ b/torch_xla/csrc/ops/device_data.cpp @@ -6,6 +6,7 @@ #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/runtime.h" #include 
"torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" namespace torch_xla { @@ -16,7 +17,7 @@ DeviceData::DeviceData(std::shared_ptr data) /*num_outputs=*/1, /*hash_seed=*/(uint32_t)101), data_(std::move(data)) { - std::optional op_sharding = + std::optional op_sharding = torch_xla::runtime::GetComputationClientOrDie()->GetDataSharding( std::dynamic_pointer_cast(data_)); if (op_sharding.has_value()) { diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index c4760783f4d..0408e818245 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -56,6 +56,7 @@ cc_library( "//torch_xla/csrc:device", "//torch_xla/csrc:dtype", "//torch_xla/csrc:status", + "//torch_xla/csrc:torch_xla_op_sharding", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -123,6 +124,7 @@ cc_library( ":tf_logging", ":xla_coordinator", "//torch_xla/csrc:status", + "//torch_xla/csrc:torch_xla_op_sharding", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index c2f9389a4a0..6e8c30e24bc 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -25,6 +25,7 @@ #include "torch_xla/csrc/runtime/types.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/status.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal_util.h" @@ -79,7 +80,7 @@ class ComputationClient { virtual bool HasSharding() const = 0; - virtual xla::OpSharding GetSharding() const = 0; + virtual torch_xla::OpSharding GetSharding() const = 0; private: std::string xla_device_; @@ -227,6 +228,7 @@ class ComputationClient { std::vector devices, const xla::Shape* output_shape, bool parameter_is_tupled_arguments = false, bool is_sharded = false, bool allow_spmd_sharding_propagation_to_output = true, + std::vector> denormalized_tile_assignments = {}, bool use_auto_spmd_partitioning = false, std::vector auto_spmd_mesh_shape = {}, std::vector auto_spmd_mesh_ids = {}, bool eager_mode = false) @@ -238,6 +240,7 @@ class ComputationClient { is_sharded(is_sharded), allow_spmd_sharding_propagation_to_output( allow_spmd_sharding_propagation_to_output), + denormalized_tile_assignments(denormalized_tile_assignments), use_auto_spmd_partitioning(use_auto_spmd_partitioning), auto_spmd_mesh_shape(auto_spmd_mesh_shape), auto_spmd_mesh_ids(auto_spmd_mesh_ids), @@ -247,6 +250,7 @@ class ComputationClient { std::string compilation_device; std::vector devices; const xla::Shape* output_shape = nullptr; + std::vector> denormalized_tile_assignments; bool parameter_is_tupled_arguments; bool is_sharded; bool allow_spmd_sharding_propagation_to_output; @@ -272,7 +276,7 @@ class ComputationClient { // will be populated in an asynchrounous fashion. virtual DataPtr CreateDataPlaceholder( std::string device, xla::Shape shape, - std::optional sharding = std::nullopt) = 0; + std::optional sharding = std::nullopt) = 0; // Returns data shards. We expect this to be called on PjRtShardedData to // retrieve the shards. If other data type is passed, it returns the input @@ -285,11 +289,12 @@ class ComputationClient { // Returns wrapped data shards as PjRtShardedData. 
virtual DataPtr WrapDataShards(absl::Span shards, std::string device, xla::Shape shape, - xla::OpSharding sharding) = 0; + torch_xla::OpSharding sharding) = 0; // Returns OpSharding attached to PjRtShardedData. The returned optional // structure will be empty if there is no sharding, like with PjRtData. - virtual std::optional GetDataSharding(DataPtr handle) = 0; + virtual std::optional GetDataSharding( + DataPtr handle) = 0; virtual std::string PjRtDeviceToString( xla::PjRtDevice* const device) const = 0; @@ -302,13 +307,13 @@ class ComputationClient { // input sharding spec is identical to the target `sharding` sharding spec. virtual std::vector ReshardData( absl::Span handles, - absl::Span shardings) = 0; + absl::Span shardings) = 0; // Transfers local sharded tensor values to the TPU devices and returns a // `PjRtShardedData`. virtual DataPtr TransferShardsToDevice( absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) = 0; + std::string device, xla::Shape shape, torch_xla::OpSharding sharding) = 0; // Copies `data->buffer` to `dst` device buffer. virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp index f5a6af1b267..915abf75fa7 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp @@ -194,14 +194,14 @@ void IfrtComputationClient::IfrtData::Assign( } } -xla::OpSharding IfrtComputationClient::IfrtData::GetSharding() const { +torch_xla::OpSharding IfrtComputationClient::IfrtData::GetSharding() const { XLA_CHECK(HasSharding()) << "Check HasSharding first"; return *sharding_; } ComputationClient::DataPtr IfrtComputationClient::CreateDataPlaceholder( std::string device, xla::Shape shape, - std::optional sharding) { + std::optional sharding) { return std::make_shared(std::move(device), std::move(shape), tsl::RCReference(), std::move(sharding)); @@ -240,7 +240,7 @@ ComputationClient::DataPtr IfrtComputationClient::GetDataShard( ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( absl::Span shards, std::string device, xla::Shape shape, - xla::OpSharding sharding) { + torch_xla::OpSharding sharding) { XLA_CHECK_EQ(shards.size(), client_->addressable_device_count()); std::vector> arrays; std::vector shard_shapes; @@ -276,7 +276,7 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( return std::make_shared(device, shape, sharded_array, sharding); } -std::optional IfrtComputationClient::GetDataSharding( +std::optional IfrtComputationClient::GetDataSharding( DataPtr handle) { auto ifrt_data = std::dynamic_pointer_cast(handle); return ifrt_data->sharding_; @@ -323,7 +323,7 @@ std::vector IfrtComputationClient::TransferToDevice( ComputationClient::DataPtr IfrtComputationClient::TransferShardsToDevice( absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) { + std::string device, xla::Shape shape, torch_xla::OpSharding sharding) { tsl::profiler::TraceMe activity( "IfrtComputationClient::TransferShardsToDevice", tsl::profiler::TraceMeLevel::kInfo); @@ -382,7 +382,7 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( // TODO: handle replicated data xla::XlaBuilder builder("ReplicateShardedData"); xla::Shape shape = handle->shape(); - builder.SetSharding(handle->GetSharding()); + builder.SetSharding(handle->GetSharding().GetXlaOpSharding()); // perform a simple identity calculation to 
reassemble the input as // replicated output. @@ -481,6 +481,11 @@ std::vector IfrtComputationClient::Compile( client_->addressable_devices().end()}); for (auto& instance : instances) { + std::vector denormalized_tile_assignment; + if (!instance.denormalized_tile_assignments.empty()) { + denormalized_tile_assignment = instance.denormalized_tile_assignments[0]; + } + xla::CompileOptions compile_options; if (instance.is_sharded) { // TODO(yeounoh) multi-host, multi-slice configurations @@ -527,7 +532,8 @@ std::vector IfrtComputationClient::Compile( std::shared_ptr ifrt_computation = std::make_shared( std::move(xla::XlaComputation(hlo_modules[0]->ToProto())), - instance.devices, std::move(executable)); + instance.devices, std::move(executable), + denormalized_tile_assignment); computations.push_back(ifrt_computation); @@ -607,11 +613,13 @@ IfrtComputationClient::ExecuteReplicated( auto outputs = result.outputs; - const std::vector& output_shardings = - ifrt_computation.output_shardings_ + const std::vector& output_shardings = + ifrt_computation.output_shardings_.has_value() ? *ifrt_computation.output_shardings_ : std::vector(outputs.size(), - xla::HloSharding::Replicate().ToProto()); + torch_xla::OpSharding( + xla::HloSharding::Replicate().ToProto(), + ifrt_computation.denormalized_tile_assignment_)); XLA_CHECK_EQ(output_shardings.size(), outputs.size()); std::vector data_handles(outputs.size()); diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 46b6343dc10..c348ff30823 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -12,6 +12,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" @@ -42,23 +43,24 @@ class IfrtComputationClient : public ComputationClient { DataPtr CreateDataPlaceholder( std::string device, xla::Shape shape, - std::optional sharding = std::nullopt) override; + std::optional sharding = std::nullopt) override; std::vector GetDataShards(DataPtr data) override; DataPtr GetDataShard(DataPtr data, size_t index) override; DataPtr WrapDataShards(absl::Span shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) override; + xla::Shape shape, + torch_xla::OpSharding sharding) override; - std::optional GetDataSharding(DataPtr handle) override; + std::optional GetDataSharding(DataPtr handle) override; std::vector TransferToDevice( absl::Span> tensors) override; std::vector ReshardData( absl::Span handles, - absl::Span shardings) override { + absl::Span shardings) override { XLA_ERROR() << __FUNCTION__ << " not implemented"; } @@ -71,7 +73,8 @@ class IfrtComputationClient : public ComputationClient { DataPtr TransferShardsToDevice( absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) override; + std::string device, xla::Shape shape, + torch_xla::OpSharding sharding) override; DataPtr CopyToDevice(DataPtr data, std::string dst) override; @@ -207,13 +210,13 @@ class IfrtComputationClient : public ComputationClient { IfrtData(std::string device, xla::Shape device_shape, tsl::RCReference buffer, - std::optional sharding = std::nullopt) + std::optional sharding = std::nullopt) : Data(std::move(device), std::move(device_shape)), buffer(buffer), 
sharding_(sharding) {} IfrtData(std::string device, tsl::RCReference buffer, - std::optional sharding = std::nullopt) + std::optional sharding = std::nullopt) : Data(std::move(device), xla::ShapeUtil::MakeShape( xla::ifrt::ToPrimitiveType(buffer->dtype()).value(), @@ -236,7 +239,7 @@ class IfrtComputationClient : public ComputationClient { bool HasSharding() const override { return sharding_.has_value(); } - xla::OpSharding GetSharding() const override; + torch_xla::OpSharding GetSharding() const override; std::string ToString() const override { std::stringstream ss; @@ -246,7 +249,9 @@ class IfrtComputationClient : public ComputationClient { ss << " Data Device: " << device() << "\n"; ss << " Data Shape: " << shape().ToString() << "\n"; ss << " OpSharding: " - << xla::HloSharding::FromProto(*sharding_)->ToString() << "\n"; + << xla::HloSharding::FromProto(sharding_.value().GetXlaOpSharding()) + ->ToString() + << "\n"; ss << " NumShards: " << buffer->sharding().devices()->size() << "\n"; } else { ss << "XLAData: \n"; @@ -262,7 +267,7 @@ class IfrtComputationClient : public ComputationClient { return ss.str(); } - std::optional sharding_; + std::optional sharding_; tsl::RCReference buffer; }; @@ -270,16 +275,33 @@ class IfrtComputationClient : public ComputationClient { const std::shared_ptr handle); struct IfrtComputation : public Computation { - IfrtComputation(xla::XlaComputation computation, - std::vector devices, - std::shared_ptr executable) + IfrtComputation( + xla::XlaComputation computation, std::vector devices, + std::shared_ptr executable, + std::optional> denormalized_tile_assignment) : Computation(std::move(computation), std::move(devices)), - executable(std::move(executable)) { - output_shardings_ = this->executable->GetOutputShardings(); + executable(std::move(executable)), + denormalized_tile_assignment_(std::move( + denormalized_tile_assignment.value_or(std::vector{}))) { + xla_output_shardings_ = this->executable->GetOutputShardings(); + if (xla_output_shardings_.has_value()) { + output_shardings_ = std::vector{}; + output_shardings_->reserve(xla_output_shardings_.value().size()); + for (const auto& sharding : xla_output_shardings_.value()) { + // convert each into torch_xla::OpSharding object + torch_xla::OpSharding torch_xla_op_sharding( + sharding, denormalized_tile_assignment_); + output_shardings_.value().push_back(torch_xla_op_sharding); + } + } else { + output_shardings_ = std::nullopt; + } } std::shared_ptr executable; - std::optional> output_shardings_; + std::optional> xla_output_shardings_; + std::optional> output_shardings_; + std::vector denormalized_tile_assignment_; }; }; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index dd4950d87f5..c76838e8b2c 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -188,7 +188,7 @@ void PjRtComputationClient::PjRtData::Assign( ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder( std::string device, xla::Shape shape, - std::optional sharding) { + std::optional sharding) { if (sharding.has_value()) { return std::make_shared( std::move(device), std::move(shape), std::move(*sharding)); @@ -240,7 +240,7 @@ ComputationClient::DataPtr PjRtComputationClient::GetDataShard( ComputationClient::DataPtr PjRtComputationClient::WrapDataShards( absl::Span shards, std::string device, xla::Shape shape, - xla::OpSharding sharding) { + torch_xla::OpSharding sharding) { 
XLA_CHECK_EQ(shards.size(), client_->addressable_devices().size()); std::vector> pjrt_data_shards; pjrt_data_shards.reserve(shards.size()); @@ -254,12 +254,12 @@ ComputationClient::DataPtr PjRtComputationClient::WrapDataShards( sharding); } -std::optional PjRtComputationClient::GetDataSharding( +std::optional PjRtComputationClient::GetDataSharding( DataPtr handle) { if (auto sharded_data = dynamic_cast(handle.get())) { return sharded_data->GetSharding(); } - return std::optional(); + return std::optional(); } std::vector PjRtComputationClient::TransferToDevice( @@ -299,7 +299,7 @@ std::vector PjRtComputationClient::TransferToDevice( ComputationClient::DataPtr PjRtComputationClient::TransferShardsToDevice( absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) { + std::string device, xla::Shape shape, torch_xla::OpSharding sharding) { tsl::profiler::TraceMe activity( "PjRtComputationClient::TransferShardsToDevice", tsl::profiler::TraceMeLevel::kInfo); @@ -355,7 +355,7 @@ PjRtComputationClient::ReplicateShardedData( } xla::XlaBuilder builder("ReplicateShardedData"); xla::Shape shape = sharded_data->shape(); - builder.SetSharding(sharded_data->GetSharding()); + builder.SetSharding(sharded_data->GetSharding().GetXlaOpSharding()); // perform a simple identity calculation to reassemble the input as // replicated output. @@ -373,13 +373,19 @@ PjRtComputationClient::ReplicateShardedData( GetValueOrThrow(computation.GetProgramShape()); std::string device = GetDefaultDevice(); + std::vector denormalized_tile_assignment = + sharded_data->GetSharding().GetDenormalizedTileAssignment(); + std::vector> denormalized_tile_assignments = { + denormalized_tile_assignment}; std::vector instances; - instances.push_back({std::move(computation), device, - GetCompilationDevices(device, {}), &shape, - /*should_wrap_parameter=*/false, - /*is_sharded=*/true, - /*allow_spmd_sharding_propagation_to_output=*/false}); + instances.push_back( + {std::move(computation), device, GetCompilationDevices(device, {}), + &shape, + /*should_wrap_parameter=*/false, + /*is_sharded=*/true, + /*allow_spmd_sharding_propagation_to_output=*/false, + /*denormalized_tile_assignments=*/denormalized_tile_assignments}); std::vector< std::shared_ptr> computations = Compile(std::move(instances)); @@ -404,7 +410,7 @@ PjRtComputationClient::ReplicateShardedData( std::vector PjRtComputationClient::ReshardData( absl::Span handles, - absl::Span shardings) { + absl::Span shardings) { tsl::profiler::TraceMe activity("ReshardData", tsl::profiler::TraceMeLevel::kInfo); XLA_COUNTER("ReshardData", 1); @@ -421,27 +427,28 @@ std::vector PjRtComputationClient::ReshardData( hlo_shardings.reserve(handles.size()); std::vector param_ops; param_ops.reserve(handles.size()); + PjRtShardedData* sharded_data; for (int i = 0; i < handles.size(); ++i) { - PjRtShardedData* sharded_data = - dynamic_cast(handles[i].get()); + sharded_data = dynamic_cast(handles[i].get()); XLA_CHECK_NE(sharded_data, nullptr) << "Resharding requires PjRtShardedData on SPMD virtual device, " << "current device: " << handles[i]->device(); shapes.push_back(sharded_data->shape()); - const xla::OpSharding& sharding = shardings[i]; + const torch_xla::OpSharding& sharding = shardings[i]; XLA_CHECK_NE(sharding.type(), xla::OpSharding::UNKNOWN) << "Resharding by UNKNOWN sharding type is not allowed."; - hlo_shardings.push_back( - GetValueOrThrow(xla::HloSharding::FromProto(sharding))); + hlo_shardings.push_back(GetValueOrThrow( + 
xla::HloSharding::FromProto(sharding.GetXlaOpSharding()))); xla::OpSharding fallback_sharding; fallback_sharding.set_type(xla::OpSharding::REPLICATED); xla::XlaScopedShardingAssignment assign( - &builder, sharded_data->GetSharding().type() == xla::OpSharding::UNKNOWN + &builder, sharded_data->GetSharding().GetXlaOpSharding().type() == + xla::OpSharding::UNKNOWN ? fallback_sharding - : sharded_data->GetSharding()); + : sharded_data->GetSharding().GetXlaOpSharding()); param_ops.push_back( xla::Parameter(&builder, i, shapes[i], absl::StrCat("p.", i))); } @@ -462,13 +469,18 @@ std::vector PjRtComputationClient::ReshardData( GetValueOrThrow(xla_computation.GetProgramShape()); std::string device = GetDefaultDevice(); + std::vector denormalized_tile_assignment = + sharded_data->GetSharding().GetDenormalizedTileAssignment(); + std::vector> denormalized_tile_assignments = { + denormalized_tile_assignment}; std::vector instances; - instances.push_back({std::move(xla_computation), device, - GetCompilationDevices(device, {}), - &program_shape.result(), - /*should_wrap_parameter=*/false, - /*is_sharded=*/true, - /*allow_spmd_sharding_propagation_to_output=*/false}); + instances.push_back( + {std::move(xla_computation), device, GetCompilationDevices(device, {}), + &program_shape.result(), + /*should_wrap_parameter=*/false, + /*is_sharded=*/true, + /*allow_spmd_sharding_propagation_to_output=*/false, + /*denormalized_tile_assignments=*/denormalized_tile_assignments}); std::shared_ptr computation = Compile(std::move(instances)).front(); @@ -552,7 +564,54 @@ std::vector PjRtComputationClient::Compile( static bool enable_cm_in_mp = runtime::sys_util::GetEnvBool("ENABLE_COLLECTIVE_MATMUL_IN_MP", false); + std::vector addressable_devices_str = GetLocalDevices(); + std::set addressable_devices; + for (auto& device : addressable_devices_str) { + std::vector device_spec_parts = absl::StrSplit(device, ':'); + XLA_CHECK_EQ(device_spec_parts.size(), 2) + << "Invalid device specification: " << device; + addressable_devices.insert(std::stoi(device_spec_parts[1])); + } for (auto& instance : instances) { + std::vector denormalized_tile_assignment; + std::vector sorted_denormalized_tile_assignment; + if (!instance.denormalized_tile_assignments.empty()) { + denormalized_tile_assignment = instance.denormalized_tile_assignments[0]; + sorted_denormalized_tile_assignment = denormalized_tile_assignment; + std::sort(sorted_denormalized_tile_assignment.begin(), + sorted_denormalized_tile_assignment.end()); + } + // Validate that all instances.denormalized_tile_assignments have the same + // tile assignment + if (!sorted_denormalized_tile_assignment.empty()) { + for (auto& to_check_tile_assignment : + instance.denormalized_tile_assignments) { + std::sort(to_check_tile_assignment.begin(), + to_check_tile_assignment.end()); + if ((!to_check_tile_assignment.empty()) && + (to_check_tile_assignment != sorted_denormalized_tile_assignment)) { + XLA_ERROR() << "the tile assignments - " << to_check_tile_assignment + << " - i.e. 
mesh are not same for all tensors (it should " + "be this - " + << sorted_denormalized_tile_assignment << " )"; + } + } + } + // Validate that tile assignment devices are addressable + if (!denormalized_tile_assignment.empty()) { + bool is_subset = std::all_of(denormalized_tile_assignment.begin(), + denormalized_tile_assignment.end(), + [&addressable_devices](int64_t num) { + return addressable_devices.find(num) != + addressable_devices.end(); + }); + if (!is_subset) { + XLA_ERROR() + << "tile_assignment - " << denormalized_tile_assignment + << " has device_ids not included in the addressable devices - " + << addressable_devices; + } + } xla::CompileOptions compile_options; if (enable_cm_in_mp) { compile_options.executable_build_options.set_use_spmd_partitioning(true); @@ -572,7 +631,11 @@ std::vector PjRtComputationClient::Compile( .set_allow_spmd_sharding_propagation_to_output( {instance.allow_spmd_sharding_propagation_to_output}); - int num_partitions = client_->device_count(); + // Use denormalized_tile_assignment size if available, otherwise fall back + // to device_count + int num_partitions = denormalized_tile_assignment.size() > 0 + ? denormalized_tile_assignment.size() + : client_->device_count(); compile_options.executable_build_options.set_num_partitions( num_partitions); compile_options.executable_build_options.set_num_replicas(1); @@ -602,12 +665,21 @@ std::vector PjRtComputationClient::Compile( } // TODO(244391366) verify this is correct for the collectives ops - xla::DeviceAssignment device_assignment(1, client_->device_count()); - // DeviceAssignment values must be the PjRtDevice ID, so we need to - // unwind the global ordinal mapping. - for (const auto& [device_id, global_ordinal] : global_ordinals_) { - device_assignment(0, global_ordinal) = device_id; + xla::DeviceAssignment device_assignment(1, num_partitions); + if (!denormalized_tile_assignment.empty()) { + // Use the denormalized_tile_assignment to assign the devices for + // computation + for (int i = denormalized_tile_assignment.size() - 1; i >= 0; --i) { + device_assignment(0, i) = denormalized_tile_assignment[i]; + } + } else { + TF_VLOG(5) << "Fall back to original logic since " + "denormalized_tile_assignment is empty"; + for (const auto& [device_id, global_ordinal] : global_ordinals_) { + device_assignment(0, global_ordinal) = device_id; + } } + compile_options.executable_build_options.set_device_assignment( device_assignment); } else { @@ -665,7 +737,8 @@ std::vector PjRtComputationClient::Compile( std::shared_ptr pjrt_computation = std::make_shared( std::move(xla::XlaComputation(hlo_modules[0]->ToProto())), - instance.devices, std::move(executable)); + instance.devices, std::move(executable), + denormalized_tile_assignment); computations.push_back(pjrt_computation); @@ -712,7 +785,8 @@ ComputationClient::ComputationPtr PjRtComputationClient::DeserializeComputation( std::vector devices = {UseVirtualDevice() ? 
spmd_device_str : GetDefaultDevice()}; return std::make_shared(std::move(computation), devices, - std::move(loaded_executable)); + std::move(loaded_executable), + std::nullopt); } torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() { @@ -794,6 +868,91 @@ PjRtComputationClient::ExecuteComputation( TF_VLOG(1) << "Returning " << datas.size() << " results"; return datas; } +namespace { +// Add this function somewhere in the file (e.g., before ExecuteReplicated +// function) +void PrintArgumentHandles2D( + const std::vector>& argument_handles, + const std::string& label = "argument_handles") { + std::cout << "\n=== " << label << " (2D Matrix) ===" << std::endl; + std::cout << "Dimensions: " << argument_handles.size() + << " devices (rows) x "; + if (!argument_handles.empty()) { + std::cout << argument_handles[0].size() << " arguments (cols)" << std::endl; + } else { + std::cout << "0 arguments (cols)" << std::endl; + return; + } + + // Print column headers + std::cout << std::setw(12) << "Device\\Arg"; + for (size_t col = 0; col < argument_handles[0].size(); ++col) { + std::cout << std::setw(15) << ("Arg[" + std::to_string(col) + "]"); + } + std::cout << std::endl; + + // Print separator line + std::cout << std::setw(12) << std::setfill('-') << ""; + for (size_t col = 0; col < argument_handles[0].size(); ++col) { + std::cout << std::setw(15) << std::setfill('-') << ""; + } + std::cout << std::setfill(' ') << std::endl; + + // Print each row (device) + for (size_t row = 0; row < argument_handles.size(); ++row) { + std::cout << std::setw(11) << ("Dev[" + std::to_string(row) + "]") << "|"; + for (size_t col = 0; col < argument_handles[row].size(); ++col) { + if (argument_handles[row][col] != nullptr) { + // Print pointer address (you can modify this to print other info) + std::cout << std::setw(14) << std::hex << argument_handles[row][col] + << std::dec; + } else { + std::cout << std::setw(14) << "nullptr"; + } + std::cout << " "; + } + std::cout << std::endl; + } + std::cout << "=== End " << label << " ===" << std::endl << std::endl; +} + +// Alternative version that shows more buffer information +void PrintArgumentHandlesDetailed( + const std::vector>& argument_handles, + const std::string& label = "argument_handles") { + std::cout << "\n=== " << label << " (Detailed) ===" << std::endl; + + for (size_t row = 0; row < argument_handles.size(); ++row) { + std::cout << "Device[" << row << "]: "; + for (size_t col = 0; col < argument_handles[row].size(); ++col) { + std::cout << "Arg[" << col << "]="; + if (argument_handles[row][col] != nullptr) { + auto* buffer = argument_handles[row][col]; + std::cout << "{ptr:" << std::hex << buffer << std::dec + << ", device:" << buffer->device()->DebugString() << "}"; + } else { + std::cout << "nullptr"; + } + if (col < argument_handles[row].size() - 1) std::cout << ", "; + } + std::cout << std::endl; + } + std::cout << "=== End " << label << " ===" << std::endl << std::endl; +} +} // namespace + +// wrapped function to handle absl::Span instead of std::vector +absl::Span FilterDevicesByAddressableDevices( + absl::Span devices, + const std::vector& indices) { + static std::vector filtered_devices_; + filtered_devices_.clear(); + filtered_devices_.reserve(indices.size()); + filtered_devices_ = + torch_xla::runtime::util::FilterDevicesByAddressableDevices(devices, + indices); + return absl::MakeConstSpan(filtered_devices_); +} std::vector PjRtComputationClient::ExecuteReplicated( @@ -811,13 +970,42 @@ PjRtComputationClient::ExecuteReplicated( const 
PjRtComputation& pjrt_computation = dynamic_cast(computation); - std::vector> argument_handles( - devices.size(), std::vector(arguments.size())); + absl::Span addressable_devices_; + std::vector addressable_devices_vec; + std::vector> argument_handles; { tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_argument_handle", tsl::profiler::TraceMeLevel::kInfo); + // Step 1: Determine which devices are needed across ALL arguments + std::set all_required_devices; + for (const auto& arg : arguments) { + auto pjrt_data = std::dynamic_pointer_cast(arg); + if (pjrt_data->GetSharding().GetDenormalizedTileAssignment().empty()) { + // Unsharded: needs all devices + all_required_devices.insert(devices.begin(), devices.end()); + } else { + // Sharded: needs filtered devices + absl::Span filtered_devices = + FilterDevicesByAddressableDevices( + devices, + pjrt_data->GetSharding().GetDenormalizedTileAssignment()); + all_required_devices.insert(filtered_devices.begin(), + filtered_devices.end()); + } + } + + // Convert to vector for indexing and assign to addressable_devices_ + addressable_devices_vec.assign(all_required_devices.begin(), + all_required_devices.end()); + addressable_devices_ = absl::MakeSpan(addressable_devices_vec); + + // Step 2: Pre-allocate argument_handles for all required devices + argument_handles.assign( + addressable_devices_.size(), + std::vector(arguments.size(), nullptr)); + absl::BlockingCounter counter(arguments.size()); // Time in nanoseconds that it takes to prepare an argument. Used to tune @@ -829,24 +1017,98 @@ PjRtComputationClient::ExecuteReplicated( for (int32_t i = start; i < end; ++i) { auto pjrt_data = std::dynamic_pointer_cast(arguments[i]); - XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size()) - << "Expected one shard per device"; - for (int32_t d = 0; d < devices.size(); d++) { - std::shared_ptr shard = pjrt_data->shards[d]; + // Determine which devices this specific argument uses + absl::Span arg_devices; + if (pjrt_data->GetSharding() + .GetDenormalizedTileAssignment() + .empty()) { + arg_devices = devices; // Unsharded: replicated to all devices + } else { + arg_devices = FilterDevicesByAddressableDevices( + devices, + pjrt_data->GetSharding().GetDenormalizedTileAssignment()); + } + + TF_VLOG(3) << "Argument " << i + << " uses devices: " << absl::StrJoin(arg_devices, ",") + << "\n"; + + XLA_CHECK_EQ(pjrt_data->shards.size(), arg_devices.size()) + << "Expected one shard per device for argument " << i; + + // Step 3: Map this argument's shards to the current total + // addressable device indices + for (int32_t shard_idx = 0; shard_idx < arg_devices.size(); + shard_idx++) { + const std::string& device_name = arg_devices[shard_idx]; + + // Find the global device index in addressable_devices_ + auto device_ = std::find(addressable_devices_.begin(), + addressable_devices_.end(), device_name); + XLA_CHECK(device_ != addressable_devices_.end()) + << "Device " << device_name + << " not found in addressable_devices_"; + + int32_t device_idx = + std::distance(addressable_devices_.begin(), device_); - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]); + std::shared_ptr shard = pjrt_data->shards[shard_idx]; + TF_VLOG(3) << "Mapping arg " << i << ", shard " << shard_idx + << " to device " << device_idx << " (" << device_name + << "), shard data: " << shard->ToString() << "\n"; + + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device_name); XLA_CHECK_EQ(shard->buffer->device(), pjrt_device); XLA_CHECK(pjrt_device->IsAddressable()) 
<< pjrt_device->DebugString(); - argument_handles[d][i] = shard->buffer.get(); + argument_handles[device_idx][i] = shard->buffer.get(); } counter.DecrementCount(); } }); counter.Wait(); } + // need to reassign the addressable_devices_ and argument_handles according to + // the sub-mesh device_ids used + if (!pjrt_computation.denormalized_tile_assignment_.empty()) { + // Get the filtered devices for the submesh + absl::Span submesh_devices = + FilterDevicesByAddressableDevices( + devices, pjrt_computation.denormalized_tile_assignment_); + // Create a mapping from device name to its index in the ORIGINAL + // argument_handles + std::unordered_map device_to_original_idx; + for (size_t i = 0; i < addressable_devices_.size(); ++i) { + device_to_original_idx[std::string(addressable_devices_[i])] = i; + } + // Create new argument_handles containing only the submesh devices + std::vector> submesh_argument_handles; + submesh_argument_handles.reserve(submesh_devices.size()); + for (const std::string& device_name : submesh_devices) { + auto it = device_to_original_idx.find(device_name); + if (it != device_to_original_idx.end()) { + // Copy the row for this device from the original argument_handles + submesh_argument_handles.push_back(argument_handles[it->second]); + } else { + // This shouldn't happen if the logic is correct, but handle gracefully + TF_VLOG(3) << "WARNING: Device " << device_name + << " not found in original devices!" << std::endl; + submesh_argument_handles.push_back( + std::vector(arguments.size(), nullptr)); + } + } + // Update the addressable devices and argument handles + addressable_devices_vec.assign(submesh_devices.begin(), + submesh_devices.end()); + addressable_devices_ = absl::MakeSpan(addressable_devices_vec); + argument_handles = std::move(submesh_argument_handles); + } else { + addressable_devices_ = devices; + // Resize to match all devices if no specific tile assignment + argument_handles.resize(addressable_devices_.size()); + } xla::ExecuteOptions execute_options; execute_options.untuple_result = options.explode_tuple; @@ -899,13 +1161,14 @@ PjRtComputationClient::ExecuteReplicated( : std::vector({result_shape}); XLA_CHECK_EQ(output_shapes.size(), num_outputs); - const std::vector& output_shardings = - pjrt_computation.output_shardings_.has_value() && num_outputs > 0 + const std::vector& output_shardings = + (pjrt_computation.output_shardings_.has_value() && + !pjrt_computation.output_shardings_.value().empty() && num_outputs > 0) ? *pjrt_computation.output_shardings_ : // Without an explicit sharding annotation, the output is implicitly // replicated, and we mark explicitly replicated here. 
- std::vector(num_outputs); + std::vector(num_outputs); XLA_CHECK_EQ(output_shardings.size(), num_outputs); absl::BlockingCounter counter(num_outputs); @@ -916,12 +1179,13 @@ PjRtComputationClient::ExecuteReplicated( pool_.ParallelFor( num_outputs, result_handle_cost_ns, [&](int64_t start, int64_t end) { for (int32_t i = start; i < end; ++i) { - std::vector> shards(devices.size()); - for (int32_t d = 0; d < devices.size(); d++) { + std::vector> shards( + addressable_devices_.size()); + for (int32_t d = 0; d < addressable_devices_.size(); d++) { std::unique_ptr buffer = std::move(results[d][i]); - shards[d] = - std::make_shared(devices[d], std::move(buffer)); + shards[d] = std::make_shared(addressable_devices_[d], + std::move(buffer)); } data_handles[i] = std::make_shared( diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 3a6b4478f72..2091930772b 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -12,6 +12,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "tsl/platform/env.h" #include "tsl/platform/threadpool.h" #include "xla/hlo/builder/xla_computation.h" @@ -40,7 +41,7 @@ class PjRtComputationClient : public ComputationClient { DataPtr CreateDataPlaceholder( std::string device, xla::Shape shape, - std::optional sharding = std::nullopt) override; + std::optional sharding = std::nullopt) override; static DataPtr CreateData(std::string device, xla::Shape shape, std::shared_ptr pjrt_buffer); @@ -50,9 +51,10 @@ class PjRtComputationClient : public ComputationClient { DataPtr GetDataShard(DataPtr data, size_t index) override; DataPtr WrapDataShards(absl::Span shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) override; + xla::Shape shape, + torch_xla::OpSharding sharding) override; - std::optional GetDataSharding(DataPtr handle) override; + std::optional GetDataSharding(DataPtr handle) override; std::vector TransferToDevice( absl::Span> tensors) override; @@ -63,7 +65,7 @@ class PjRtComputationClient : public ComputationClient { // TODO(yeounoh) replace ReplicateShardedData with this. 
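+  // A minimal resharding sketch (`handles` and the mesh values are
+  // illustrative placeholders):
+  //
+  //   std::vector<torch_xla::OpSharding> new_shardings(
+  //       handles.size(),
+  //       torch_xla::OpSharding(xla::HloSharding::Replicate().ToProto(),
+  //                             std::vector<int64_t>{0, 1, 2, 3}));
+  //   auto resharded = client->ReshardData(handles, new_shardings);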
std::vector ReshardData( absl::Span handles, - absl::Span shardings) override; + absl::Span shardings) override; absl::StatusOr> TransferFromDevice( absl::Span handles) override; @@ -74,7 +76,8 @@ class PjRtComputationClient : public ComputationClient { DataPtr TransferShardsToDevice( absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) override; + std::string device, xla::Shape shape, + torch_xla::OpSharding sharding) override; DataPtr CopyToDevice(DataPtr data, std::string dst) override; @@ -240,10 +243,10 @@ class PjRtComputationClient : public ComputationClient { bool HasSharding() const override { return false; } - xla::OpSharding GetSharding() const override { + torch_xla::OpSharding GetSharding() const override { XLA_CHECK(false) << "GetSharding should not be called on PjRtData, check " "HasSharding first"; - return xla::OpSharding(); + return torch_xla::OpSharding(); } std::string ToString() const override { @@ -267,12 +270,12 @@ class PjRtComputationClient : public ComputationClient { PjRtShardedData(std::string device, xla::Shape shape) = delete; PjRtShardedData(std::string device, xla::Shape shape, - xla::OpSharding sharding) + torch_xla::OpSharding sharding) : Data(std::move(device), std::move(shape)), sharding(sharding) {} PjRtShardedData(std::string device, xla::Shape shape, std::vector> shards, - xla::OpSharding sharding) + torch_xla::OpSharding sharding) : Data(std::move(device), std::move(shape)), shards(shards), sharding(sharding) {} @@ -309,26 +312,52 @@ class PjRtComputationClient : public ComputationClient { ss << " Data Device: " << device() << "\n"; ss << " Data Shape: " << shape().ToString() << "\n"; ss << " OpSharding: " - << xla::HloSharding::FromProto(sharding)->ToString() << "\n"; + << xla::HloSharding::FromProto(sharding.GetXlaOpSharding())->ToString() + << "\n"; + ss << " DenormalizedTileAssignment: " + << sharding.GetDenormalizedTileAssignment() << "\n"; ss << " NumShards: " << shards.size() << "\n"; return ss.str(); } bool HasSharding() const override { return true; } - xla::OpSharding GetSharding() const override { return sharding; } + torch_xla::OpSharding GetSharding() const override { return sharding; } std::vector> shards; - xla::OpSharding sharding; + torch_xla::OpSharding sharding; }; struct PjRtComputation : public Computation { - PjRtComputation(xla::XlaComputation computation, - std::vector devices, - std::unique_ptr executable) + /** + * Constructs a PjRtComputation with the given parameters. 
+ * + * @param computation The XLA computation to wrap + * @param devices List of device strings for execution + * @param executable The compiled PJRT executable + * @param denormalized_tile_assignment Optional tile assignment for sharding + */ + PjRtComputation( + xla::XlaComputation computation, std::vector devices, + std::unique_ptr executable, + std::optional> denormalized_tile_assignment) : Computation(std::move(computation), std::move(devices)), - executable(std::move(executable)) { - output_shardings_ = this->executable->GetOutputShardings(); + executable(std::move(executable)), + denormalized_tile_assignment_(std::move( + denormalized_tile_assignment.value_or(std::vector{}))) { + xla_output_shardings_ = this->executable->GetOutputShardings(); + if (xla_output_shardings_.has_value()) { + output_shardings_ = std::vector{}; + output_shardings_->reserve(xla_output_shardings_.value().size()); + for (const auto& sharding : xla_output_shardings_.value()) { + // convert each into torch_xla::OpSharding object + torch_xla::OpSharding torch_xla_op_sharding( + sharding, denormalized_tile_assignment_); + output_shardings_.value().push_back(torch_xla_op_sharding); + } + } else { + output_shardings_ = std::nullopt; + } } const std::string get_memory_info() const override { @@ -341,7 +370,10 @@ class PjRtComputationClient : public ComputationClient { } std::unique_ptr executable; - std::optional> output_shardings_; + std::optional> xla_output_shardings_; + std::optional> output_shardings_; + std::vector denormalized_tile_assignment_; + ; }; // Use XLA replication to re-assemble the sharded data. diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp index 0fe2b2a70fc..66e40dcc949 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp @@ -127,5 +127,64 @@ TEST_F(PjRtComputationClientTest, Init) { result_literals[0])); } +TEST_F(PjRtComputationClientTest, ThrowsWhenTileAssignmentsAreDifferent) { + // Compose a computation to add two matrices. + xla::Shape out_shape(xla::F32, {2, 2}, + /*dynamic_dimensions=*/{}); + std::vector instances; + + // Create an instance with different tile assignments for different tensors + ComputationClient::CompileInstance instance( + std::move(MakeAddComputation().value()), device_, + client_->GetCompilationDevices(device_, client_->GetLocalDevices()), + &out_shape, + /*parameter_is_tupled_arguments=*/false, + /*is_sharded=*/true, + /*allow_spmd_sharding_propagation_to_output=*/true, + /*denormalized_tile_assignments=*/ + { + {0, 1, 2, 3}, // First tensor assignment + {4, 5, 6, 7} // Second tensor assignment + }); + + instances.push_back(std::move(instance)); + + // Compiling should fail due to mismatched tile assignments + EXPECT_THROW(client_->Compile(std::move(instances)), std::runtime_error); +} + +TEST_F(PjRtComputationClientTest, + ThrowsWhenTileAssignmentDevicesNotAddressable) { + // Compose a computation to add two matrices. 
+ xla::Shape out_shape(xla::F32, {2, 2}, + /*dynamic_dimensions=*/{}); + std::vector instances; + + // Get the number of local devices to create an invalid device ID + std::vector local_devices = client_->GetLocalDevices(); + int64_t invalid_device_id = + local_devices.size() + 100; // Use a device ID that doesn't exist + + // Create an instance with tile assignment containing non-addressable device + // IDs + ComputationClient::CompileInstance instance( + std::move(MakeAddComputation().value()), device_, + client_->GetCompilationDevices(device_, client_->GetLocalDevices()), + &out_shape, + /*parameter_is_tupled_arguments=*/false, + /*is_sharded=*/true, + /*allow_spmd_sharding_propagation_to_output=*/true, + /*denormalized_tile_assignments=*/ + { + {0, 1, invalid_device_id, + invalid_device_id + 1} // Contains non-addressable device IDs + }); + + instances.push_back(std::move(instance)); + + // Compiling should fail due to non-addressable device IDs in tile assignment + EXPECT_THROW(client_->Compile(std::move(instances)), std::runtime_error); +} + } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/util.h b/torch_xla/csrc/runtime/util.h index 112ee44f2c7..3f15a6443b1 100644 --- a/torch_xla/csrc/runtime/util.h +++ b/torch_xla/csrc/runtime/util.h @@ -15,6 +15,7 @@ #include #include "absl/status/statusor.h" +#include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "torch_xla/csrc/runtime/types.h" @@ -177,6 +178,36 @@ RaisePythonValueErrorOnFailure(const Func& func) { throw std::invalid_argument(std::string(result.status().message())); } +/** + * Filters a list of device strings to include only those with IDs matching + * the provided indices. + * + * @param devices List of device strings in format "TYPE:ID" (e.g., "TPU:0") + * @param indices List of device IDs to filter by + * @return Filtered list of device strings + * + * Example: + * devices = ["TPU:0", "TPU:1", "TPU:2", "TPU:3"] + * indices = [1, 3] + * result = ["TPU:1", "TPU:3"] + */ +template +std::vector FilterDevicesByAddressableDevices( + const DeviceContainer& devices, const std::vector& indices) { + std::vector filtered_devices_; + filtered_devices_.reserve(indices.size()); + for (auto& index : indices) { + for (auto& device : devices) { + std::vector device_spec_parts = absl::StrSplit(device, ':'); + if (std::stoi(device_spec_parts[1]) == index) { + filtered_devices_.push_back(device); + break; + } + } + } + return filtered_devices_; +} + } // namespace util } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 0d49e98b67f..aefd12bd759 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -10,6 +10,7 @@ #include #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "torch_xla/csrc/view.h" namespace torch_xla { @@ -262,13 +263,13 @@ class XLATensor : public torch::lazy::LazyTensor { // XLA SPMD sharding spec annoation. The XLA tensor uses this to create // HloSharding for replication, manual and tile shardings. 
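+  // A minimal construction sketch (the shape and mesh values are
+  // illustrative placeholders):
+  //
+  //   torch_xla::OpSharding sharding(xla::HloSharding::Replicate().ToProto(),
+  //                                  std::vector<int64_t>{0, 1, 2, 3});
+  //   ShardingSpec spec(sharding, tensor_shape);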
struct ShardingSpec { - ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape) + ShardingSpec(const torch_xla::OpSharding& sharding, const xla::Shape& shape) : sharding(sharding), shape(shape) {} - ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape, + ShardingSpec(const torch_xla::OpSharding& sharding, const xla::Shape& shape, const bool& minibatch) : sharding(sharding), shape(shape), minibatch(minibatch) {} - xla::OpSharding sharding = xla::HloSharding::Unknown().ToProto(); + torch_xla::OpSharding sharding; // Optional source tensor shape unpartitioned. xla::Shape shape; // Parameter for represent input batch in sharded along batch axes diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index e2cd3a025f5..77befc0c181 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -860,14 +860,26 @@ std::vector CreateTensorsData( std::vector local_devices = runtime::GetComputationClientOrDie()->GetLocalDevices(); + std::vector addressable_devices = std::move(local_devices); + if (shardings[i]) { + const std::vector& denormalized_tile_assignment = + shardings[i]->sharding.GetDenormalizedTileAssignment(); + if ((!denormalized_tile_assignment.empty()) && + (denormalized_tile_assignment.size() != + addressable_devices.size())) { + addressable_devices = + torch_xla::runtime::util::FilterDevicesByAddressableDevices( + addressable_devices, denormalized_tile_assignment); + } + } // Shards the input tensors with padding, to split evenly. // The execution requires consistent shard sizes, and the zero-padded // values should be ignored. - std::vector local_shards = - ShardingUtil::ShardTensor(tensors[i], shardings[i], local_devices, - /*padded=*/true); + std::vector local_shards = ShardingUtil::ShardTensor( + tensors[i], shardings[i], addressable_devices, + /*padded=*/true); new_handles.push_back(ShardingUtil::CreateShardedData( - local_shards, local_devices, shardings[i])); + local_shards, addressable_devices, shardings[i])); } else { source_tensors.push_back(std::make_shared( tensors[i], std::move(shape), devices[i])); diff --git a/torch_xla/csrc/torch_xla_op_sharding.cpp b/torch_xla/csrc/torch_xla_op_sharding.cpp new file mode 100644 index 00000000000..6ad24d32015 --- /dev/null +++ b/torch_xla/csrc/torch_xla_op_sharding.cpp @@ -0,0 +1,121 @@ +#include "torch_xla_op_sharding.h" + +namespace torch_xla { + +OpSharding::OpSharding() {} + +OpSharding::OpSharding( + const xla::OpSharding& op_sharding, + const std::optional>& denormalized_tile_assignment) + : op_sharding_(std::make_unique(op_sharding)), + denormalized_tile_assignment_( + denormalized_tile_assignment.value_or(std::vector{})) {} + +OpSharding::OpSharding(const OpSharding& other) + : denormalized_tile_assignment_(other.denormalized_tile_assignment_) { + if (other.op_sharding_) { + op_sharding_ = std::make_unique(*other.op_sharding_); + } else { + // Fallback to default replicated sharding + op_sharding_ = std::make_unique(); + op_sharding_->set_type(xla::OpSharding::REPLICATED); + } +} + +OpSharding& OpSharding::operator=(const OpSharding& other) { + if (this != &other) { + if (other.op_sharding_) { + op_sharding_ = std::make_unique(*other.op_sharding_); + } else { + // Fallback to default replicated sharding + op_sharding_ = std::make_unique(); + op_sharding_->set_type(xla::OpSharding::REPLICATED); + } + denormalized_tile_assignment_ = other.denormalized_tile_assignment_; + } + return *this; +} + +OpSharding::OpSharding(OpSharding&& other) noexcept + : 
op_sharding_(std::move(other.op_sharding_)), + denormalized_tile_assignment_( + std::move(other.denormalized_tile_assignment_)) { + // other.op_sharding_ is now nullptr, which is safe +} + +OpSharding& OpSharding::operator=(OpSharding&& other) noexcept { + if (this != &other) { + op_sharding_ = std::move(other.op_sharding_); + denormalized_tile_assignment_ = + std::move(other.denormalized_tile_assignment_); + } + return *this; +} + +// Forwarded methods from xla::OpSharding for API compatibility +xla::OpSharding::Type OpSharding::type() const { return op_sharding_->type(); } + +bool OpSharding::replicate_on_last_tile_dim() const { + return op_sharding_->replicate_on_last_tile_dim(); +} + +int OpSharding::tile_assignment_dimensions_size() const { + return op_sharding_->tile_assignment_dimensions_size(); +} + +int OpSharding::tile_assignment_devices_size() const { + return op_sharding_->tile_assignment_devices_size(); +} + +int OpSharding::tile_assignment_dimensions(int index) const { + return op_sharding_->tile_assignment_dimensions(index); +} + +int OpSharding::tile_assignment_devices(int index) const { + return op_sharding_->tile_assignment_devices(index); +} + +std::string OpSharding::DebugString() const { + return op_sharding_->DebugString(); +} + +const ::google::protobuf::RepeatedField& +OpSharding::iota_reshape_dims() const { + return op_sharding_->iota_reshape_dims(); +} + +const ::google::protobuf::RepeatedField& +OpSharding::tile_assignment_dimensions() const { + return op_sharding_->tile_assignment_dimensions(); +} + +const ::google::protobuf::RepeatedField& +OpSharding::tile_assignment_devices() const { + return op_sharding_->tile_assignment_devices(); +} + +const ::google::protobuf::RepeatedField& +OpSharding::iota_transpose_perm() const { + return op_sharding_->iota_transpose_perm(); +} + +const ::google::protobuf::RepeatedField& OpSharding::last_tile_dims() + const { + return op_sharding_->last_tile_dims(); +} + +const xla::ShapeProto& OpSharding::tile_shape() const { + return op_sharding_->tile_shape(); +} + +const xla::OpSharding& OpSharding::GetXlaOpSharding() const { + return *op_sharding_; +} + +xla::OpSharding& OpSharding::GetMutableXlaOpSharding() { return *op_sharding_; } + +const std::vector& OpSharding::GetDenormalizedTileAssignment() const { + return denormalized_tile_assignment_; +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/torch_xla_op_sharding.h b/torch_xla/csrc/torch_xla_op_sharding.h new file mode 100644 index 00000000000..aa126b48f58 --- /dev/null +++ b/torch_xla/csrc/torch_xla_op_sharding.h @@ -0,0 +1,81 @@ +#ifndef XLA_TORCH_XLA_CSRC_TORCH_XLA_OP_SHARDING_H_ +#define XLA_TORCH_XLA_CSRC_TORCH_XLA_OP_SHARDING_H_ + +#include +#include +#include +#include + +#include "google/protobuf/repeated_field.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape.h" + +namespace torch_xla { + +// Wrapper class for xla::OpSharding that provides additional functionality +// and maintains denormalized tile assignment information. +// +// This class serves as a bridge between PyTorch XLA's sharding representation +// and XLA's native OpSharding, allowing for extended functionality while +// maintaining compatibility with the underlying XLA infrastructure. 
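+// A minimal usage sketch (the mesh values below are illustrative only):
+//
+//   torch_xla::OpSharding sharding(xla::HloSharding::Replicate().ToProto(),
+//                                  std::vector<int64_t>{4, 5, 6, 7});
+//   const xla::OpSharding& proto = sharding.GetXlaOpSharding();
+//   const std::vector<int64_t>& mesh = sharding.GetDenormalizedTileAssignment();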
+class OpSharding { + public: + // Default constructor + OpSharding(); + + // Constructs OpSharding from xla::OpSharding with optional denormalized + // tile assignment + explicit OpSharding(const xla::OpSharding& op_sharding, + const std::optional>& + denormalized_tile_assignment = std::nullopt); + + // Copy constructor + OpSharding(const OpSharding& other); + + // Copy assignment operator + OpSharding& operator=(const OpSharding& other); + + // Move constructor + OpSharding(OpSharding&& other) noexcept; + + // Move assignment operator + OpSharding& operator=(OpSharding&& other) noexcept; + + // Destructor (default is sufficient due to unique_ptr) + ~OpSharding() = default; + + // Forwarded methods from xla::OpSharding for API compatibility + xla::OpSharding::Type type() const; + bool replicate_on_last_tile_dim() const; + int tile_assignment_dimensions_size() const; + int tile_assignment_devices_size() const; + int tile_assignment_dimensions(int index) const; + int tile_assignment_devices(int index) const; + std::string DebugString() const; + const ::google::protobuf::RepeatedField& iota_reshape_dims() const; + const ::google::protobuf::RepeatedField& tile_assignment_dimensions() + const; + const ::google::protobuf::RepeatedField& tile_assignment_devices() + const; + const ::google::protobuf::RepeatedField& iota_transpose_perm() const; + const ::google::protobuf::RepeatedField& last_tile_dims() const; + const xla::ShapeProto& tile_shape() const; + + // Access to underlying xla::OpSharding + const xla::OpSharding& GetXlaOpSharding() const; + xla::OpSharding& GetMutableXlaOpSharding(); + + // Access to denormalized tile assignment + const std::vector& GetDenormalizedTileAssignment() const; + + private: + // Underlying XLA OpSharding object + std::unique_ptr op_sharding_; + + // Additional denormalized tile assignment information + std::vector denormalized_tile_assignment_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_TORCH_XLA_OP_SHARDING_H_ diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 65eee78bc02..ed5116a82cc 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1178,18 +1178,26 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( std::vector* tensors, SyncTensorCollection* coll, std::vector parameters_data, std::string device, ComputationCache::TypePtr cached_computation, - const std::vector& tensor_data_vec) { + const std::vector& tensor_data_vec, + std::optional>> + denormalized_tile_assignments) { auto tensors_data = SetTensorData(tensors, coll->config, coll->indices, tensor_data_vec); std::vector sharding_specs(coll->indices.size(), nullptr); + std::optional> denormalized_tile_assignment = + std::nullopt; + if ((denormalized_tile_assignments.has_value()) && + (!denormalized_tile_assignments.value().empty())) { + denormalized_tile_assignment = denormalized_tile_assignments.value()[0]; + } // Extract sharding specs for the results and prepare the sharded data // placeholders if the computation is sharded. 
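+  // Only the first collected assignment is forwarded below: every sharding
+  // reaching a single computation is expected to describe the same mesh,
+  // which PjRtComputationClient::Compile checks and rejects otherwise.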
if (cached_computation->is_sharded) { ShardingUtil::PrepareOutputShardingPropagation( tensors, coll->indices, cached_computation->computation, &tensors_data, - &sharding_specs); + &sharding_specs, denormalized_tile_assignment); DebugUtil::SaveOutputShardingInfo(tensors, coll->indices); } @@ -1250,15 +1258,31 @@ XLAGraphExecutor::TryRunCachedSync( << torch::lazy::Hash(po_data->parameter_sequence); } + std::vector> denormalized_tile_assignments; + for (const auto* node : po_data->post_order) { + const XlaNode* const casted = dynamic_cast(node); + auto shardings = casted->GetShardings(); + if (!shardings.empty()) { + for (auto sharding : shardings) { + std::vector denormalized_tile_assignment = + sharding->GetDenormalizedTileAssignment(); + if (!denormalized_tile_assignment.empty()) { + denormalized_tile_assignments.push_back( + sharding->GetDenormalizedTileAssignment()); + } + } + } + } // don't schedule the execution if the purpose of this SyncTensor is just to // warm up the cache. return std::pair>( - cache_hit, warm_up_cache_only - ? nullptr - : ScheduleSyncTensorsGraph( - tensors, coll, std::move(po_data->parameters_data), - coll->device.toString(), - std::move(cached_computation), tensor_data_vec)); + cache_hit, + warm_up_cache_only + ? nullptr + : ScheduleSyncTensorsGraph( + tensors, coll, std::move(po_data->parameters_data), + coll->device.toString(), std::move(cached_computation), + tensor_data_vec, denormalized_tile_assignments)); } std::vector GetBufferDonorIndexForStepMarker( @@ -1434,13 +1458,16 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( } xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(coll.device.type())); - + std::vector> denormalized_tile_assignments = + lowering_ctx.GetDenormalizedTileAssignments(); std::vector instances; instances.push_back( {std::move(computation), coll.device.toString(), runtime::GetComputationClientOrDie()->GetCompilationDevices( coll.device.toString(), devices), - &shape, should_wrap_parameter, is_sharded}); + &shape, should_wrap_parameter, is_sharded, + /*allow_spmd_sharding_propagation_to_output=*/true, + /*denormalized_tile_assignments=*/denormalized_tile_assignments}); instances.front().eager_mode = UseEagerMode(); if (use_autosharding) { TF_VLOG(5) << "use_auto_spmd_partitioning is set."; @@ -1503,7 +1530,8 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( /*emitted_nodes=*/lowering_ctx.GetEmittedNodeCount(), /*computation=*/computations.front(), /*parameters_data=*/std::move(po_data->parameters_data), - /*is_sharded=*/is_sharded}; + /*is_sharded=*/is_sharded, + /*denormalized_tile_assignments=*/denormalized_tile_assignments}; } std::shared_ptr @@ -1579,7 +1607,7 @@ XLAGraphExecutor::SyncTensorsGraphInternal( return ScheduleSyncTensorsGraph( tensors, &coll, std::move(compile_result.parameters_data), compile_result.device.toString(), std::move(cached_computation), - tensor_data_vec); + tensor_data_vec, compile_result.denormalized_tile_assignments); } } diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 5e1acaa9d47..8467f54a2e4 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -225,6 +225,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { runtime::ComputationClient::ComputationPtr computation; std::vector parameters_data; bool is_sharded = false; + std::vector> denormalized_tile_assignments = {}; }; struct Async : public torch::lazy::LazyGraphExecutor::Async { @@ -359,7 
+360,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { std::vector<XLATensorPtr>* tensors, SyncTensorCollection* coll, std::vector<torch::lazy::BackendDataPtr> parameters_data, std::string device, ComputationCache::TypePtr cached_computation, - const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec); + const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec, + std::optional<std::vector<std::vector<int64_t>>> + denormalized_tile_assignments = std::nullopt); // Override to enable profiler. PostOrderData RunPostOrder(const std::vector<torch::lazy::Value>& ir_values, diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 3ddec53d700..21a42c23ed6 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -19,6 +19,7 @@ #include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/thread_pool.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "torch_xla/csrc/xla_graph_executor.h" #include "tsl/profiler/lib/traceme.h" #include "xla/execution_options_util.h" @@ -176,10 +177,11 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) { const XlaNode* xla_node = dynamic_cast<const XlaNode*>(node); xla::HloInstructionProto* instruction = XlaBuilderFriend::GetInstruction(elem.second); - const std::shared_ptr<xla::OpSharding> sharding = + const std::shared_ptr<torch_xla::OpSharding> sharding = xla_node->GetSharding(elem.first.index); if (sharding != nullptr && sharding->type() != xla::OpSharding::UNKNOWN) { - *instruction->mutable_sharding() = *sharding; + const xla::OpSharding& sharding_obj = sharding->GetMutableXlaOpSharding(); + *instruction->mutable_sharding() = sharding_obj; is_sharded = true; } } @@ -187,7 +189,7 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) { } ShardingUtil::ShardingType ShardingUtil::GetShardingType( - xla::OpSharding& sharding) { + torch_xla::OpSharding& sharding) { switch (sharding.type()) { case xla::OpSharding::REPLICATED: return ShardingType::REPLICATED; @@ -210,15 +212,47 @@ ShardingUtil::ShardingType ShardingUtil::GetShardingType( bool ShardingUtil::EqualShardingSpecs(const XLATensor::ShardingSpec& a, const XLATensor::ShardingSpec& b) { - return xla::protobuf_util::HaveSameSerialization(a.sharding, b.sharding); + return xla::protobuf_util::HaveSameSerialization( + a.sharding.GetXlaOpSharding(), b.sharding.GetXlaOpSharding()); } -bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a, - const xla::OpSharding& b) { - return xla::protobuf_util::HaveSameSerialization(a, b); +bool ShardingUtil::EqualOpShardings(const torch_xla::OpSharding& a, + const torch_xla::OpSharding& b) { + return xla::protobuf_util::HaveSameSerialization(a.GetXlaOpSharding(), + b.GetXlaOpSharding()); } -xla::OpSharding ShardingUtil::CreateOpSharding( +// Normalizes a tile assignment so that device IDs start from 0 +std::vector<int64_t> ShardingUtil::NormalizeTileAssignment( + const std::vector<int64_t>& tile_assignment) { + // Check if the tile_assignment is empty + if (tile_assignment.empty()) { + TF_LOG(WARNING) << "tile_assignment is empty"; + return tile_assignment; + } + + // Find the minimum value in the tile_assignment + int64_t min_value = + *std::min_element(tile_assignment.begin(), tile_assignment.end()); + + // Check that the minimum value in the tile_assignment is non-negative + XLA_CHECK(min_value >= 0) + << "min_value of tile_assignment cannot be negative"; + + // Create a vector to store the normalized tile_assignment + std::vector<int64_t> normalized_tile_assignment; + normalized_tile_assignment.reserve( + tile_assignment.size()); // Reserve space to avoid reallocations + + // Normalize each device ID by subtracting the minimum value + for (const auto&
device : tile_assignment) { + normalized_tile_assignment.push_back(device - min_value); + } + + return normalized_tile_assignment; +} + +torch_xla::OpSharding ShardingUtil::CreateOpSharding( + const py::list& tile_assignment, const py::list& group_assignment, + const py::list& replication_groups, ShardingType sharding_type) { TORCH_LAZY_COUNTER("CreateOpSharding", 1); @@ -281,7 +315,40 @@ xla::OpSharding ShardingUtil::CreateOpSharding( TF_LOG(ERROR) << "Invalid arguments: sharding_type " << sharding_type; } } - return sharding; + + // Create denormalized_tile_assignment. If sharding.tile_assignment_devices() + // is empty (which happens for REPLICATED, MANUAL, UNKNOWN sharding types), + // use the original tile_assignment arg that was passed to this function. + std::vector<int64_t> denormalized_tile_assignment; + if (sharding.tile_assignment_devices().empty() && !tile_assignment.empty()) { + // Convert the Python list tile_assignment to a flattened vector for + // denormalized assignment + xla::Array<int64_t> tile_array = TileListToArray(tile_assignment); + denormalized_tile_assignment.assign(tile_array.begin(), tile_array.end()); + } else { + // Use the tile_assignment_devices from the XLA OpSharding object + denormalized_tile_assignment.assign( + sharding.tile_assignment_devices().begin(), + sharding.tile_assignment_devices().end()); + } + + // Normalize tile_assignment_devices to start from 0 + std::vector<int64_t> normalized_tile_assignment_devices = + NormalizeTileAssignment(denormalized_tile_assignment); + + // Clear the existing tile_assignment_devices on the sharding proto + sharding.clear_tile_assignment_devices(); + + // Add the normalized tile_assignment_devices back to the sharding proto + for (int64_t device : normalized_tile_assignment_devices) { + sharding.add_tile_assignment_devices(device); + } + // Wrap the xla::OpSharding in torch_xla::OpSharding together with the + // denormalized_tile_assignment (the original tile_assignment) for extended + // functionality + torch_xla::OpSharding torch_xla_opsharding(sharding, + denormalized_tile_assignment); + return torch_xla_opsharding; } std::vector<int64_t> ShardingUtil::GetShardShape( @@ -336,7 +403,8 @@ ShardingUtil::GetShardIndicesForMinibatchTensor( std::vector<std::pair<int, std::vector<at::indexing::TensorIndex>>> ShardingUtil::GetShardReplicaAndIndicesForDevices( const std::vector<int64_t>& shard_shape, - const std::vector<int64_t>& tensor_shape, const xla::OpSharding sharding, + const std::vector<int64_t>& tensor_shape, + const torch_xla::OpSharding sharding, const std::vector<std::string>& devices) { using namespace at::indexing; @@ -359,9 +427,8 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( } } else if (sharding.type() == xla::OpSharding::OTHER) { auto device_index = build_index_map(devices); - std::vector<int64_t> tile_assignment_devices( - sharding.tile_assignment_devices().begin(), - sharding.tile_assignment_devices().end()); + std::vector<int64_t> tile_assignment_devices = + sharding.GetDenormalizedTileAssignment(); if (!sharding.iota_reshape_dims().empty()) { auto tileAssignment = xla::TileAssignment( sharding.tile_assignment_dimensions(), sharding.iota_reshape_dims(), @@ -421,7 +488,7 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( std::vector<at::Tensor> ShardingUtil::ShardTensor( const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings, const std::vector<std::string>& devices, bool padded) { - xla::OpSharding sharding; + torch_xla::OpSharding sharding; bool minibatch = false; if (shardings != nullptr) { sharding = shardings->sharding; @@ -480,28 +547,40 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor( std::vector<XLATensor::ShardingSpecPtr> ShardingUtil::GetOutputSharding( const std::vector<xla::Shape>& output_shapes, -
runtime::ComputationClient::ComputationPtr computation) { + runtime::ComputationClient::ComputationPtr computation, + std::optional<std::vector<int64_t>> denormalized_tile_assignment) { const auto& computation_proto = computation->computation().proto(); size_t num_outputs = output_shapes.size(); - std::vector<xla::OpSharding> output_shardings; + std::vector<xla::OpSharding> xla_output_shardings; + std::vector<torch_xla::OpSharding> output_shardings; std::vector<XLATensor::ShardingSpecPtr> sharding_specs(num_outputs); if (computation_proto.has_spmd_output_sharding()) { if (computation_proto.spmd_output_sharding().tuple_shardings().size() > 0) { auto tuple_shardings = computation_proto.spmd_output_sharding().tuple_shardings(); - output_shardings = std::vector<xla::OpSharding>(tuple_shardings.begin(), - tuple_shardings.end()); + xla_output_shardings = std::vector<xla::OpSharding>( + tuple_shardings.begin(), tuple_shardings.end()); } else { - output_shardings = std::vector<xla::OpSharding>{ + xla_output_shardings = std::vector<xla::OpSharding>{ computation_proto.spmd_output_sharding()}; } } // Output parameter sharding annotations, defaults to REPLICATED(0) if // unset. - if (output_shardings.empty()) { + if (xla_output_shardings.empty()) { // Initializes with default sharding type, REPLCIATED. - output_shardings.resize(num_outputs); + xla_output_shardings.resize(num_outputs); + } + + for (const auto& sharding : xla_output_shardings) { + if ((denormalized_tile_assignment.has_value()) && + (!denormalized_tile_assignment.value().empty())) { + output_shardings.emplace_back(sharding, + denormalized_tile_assignment.value()); + } else { + output_shardings.emplace_back(sharding, std::nullopt); + } } for (int i = 0; i < num_outputs; ++i) { @@ -535,7 +614,8 @@ void ShardingUtil::PrepareOutputShardingPropagation( std::vector<XLATensorPtr>* tensors, absl::Span<const size_t> indices, runtime::ComputationClient::ComputationPtr computation, std::vector<torch::lazy::BackendDataPtr>* data_placeholders, - std::vector<XLATensor::ShardingSpecPtr>* sharding_specs) { + std::vector<XLATensor::ShardingSpecPtr>* sharding_specs, + std::optional<std::vector<int64_t>> denormalized_tile_assignment) { // Resizes the containers to `indices.size()`. data_placeholders->resize(indices.size()); sharding_specs->resize(indices.size()); @@ -546,7 +626,8 @@ void ShardingUtil::PrepareOutputShardingPropagation( auto xtensor = (*tensors)[indices[i]]; output_shapes.push_back(xtensor->shape().get()); } - auto new_sharding_specs = GetOutputSharding(output_shapes, computation); + auto new_sharding_specs = GetOutputSharding(output_shapes, computation, + denormalized_tile_assignment); XLA_CHECK(indices.size() == new_sharding_specs.size()) << "Expected size: " << indices.size() << ", actual size: " << new_sharding_specs.size(); @@ -577,31 +658,46 @@ void ShardingUtil::PrepareOutputShardingPropagation( } } -runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( +namespace { +// Helper that builds the torch_xla::OpSharding (and global shape) for the +// given sharding spec and local shards +torch_xla::OpSharding GetTorchXlaOpSharding( + xla::Shape* global_shape, const XLATensor::ShardingSpecPtr& sharding_spec, const std::vector<at::Tensor>& local_shards, - const std::vector<std::string>& devices, - const XLATensor::ShardingSpecPtr& sharding_spec) { - XLA_CHECK(local_shards.size() == devices.size()) - << "A device must be speficied for each shard"; - std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors; - xla::Shape global_shape; - xla::OpSharding sharding; if (sharding_spec == nullptr) { // Unknown type is used to mark implicitly replicated data for // auto-sharding. // TODO(yeounoh) see if we can completely rely on Unknown without inference // performance degradation. - sharding = ShardingUtil::GetAutoSharding() - ?
xla::HloSharding::Unknown().ToProto() - : xla::HloSharding::Replicate().ToProto(); + xla::OpSharding sharding = ShardingUtil::GetAutoSharding() + ? xla::HloSharding::Unknown().ToProto() + : xla::HloSharding::Replicate().ToProto(); // if replicated, global_shape is shape of the tensor. auto first_device = ParseDeviceString(devices[0]); - global_shape = + *global_shape = CreateComputationShapeFromTensor(local_shards[0], &first_device); + return torch_xla::OpSharding(sharding, std::nullopt); } else { - global_shape = sharding_spec->shape; - sharding = sharding_spec->sharding; + *global_shape = sharding_spec->shape; + return sharding_spec->sharding; } +} + +} // namespace + +runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( + const std::vector<at::Tensor>& local_shards, + const std::vector<std::string>& devices, + const XLATensor::ShardingSpecPtr& sharding_spec) { + XLA_CHECK(local_shards.size() == devices.size()) + << "A device must be specified for each shard"; + std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors; + xla::Shape global_shape; + + // Get the torch_xla::OpSharding object and the global shape + torch_xla::OpSharding sharding = GetTorchXlaOpSharding( + &global_shape, sharding_spec, local_shards, devices); + for (int64_t j = 0; j < devices.size(); ++j) { auto shard_device = ParseDeviceString(devices[j]); auto shard_shape = @@ -672,23 +768,24 @@ void ShardingUtil::ReshardParameters( std::vector<torch::lazy::BackendDataPtr>* parameters, std::vector<const torch::lazy::Node*>* nodes) { // Extract input shardings generated from auto-sharding pass. - std::vector<xla::OpSharding> input_shardings; + std::vector<xla::OpSharding> xla_input_shardings; + std::vector<torch_xla::OpSharding> input_shardings; if (module.spmd_parameters_shardings().size() == 1 && module.spmd_parameters_shardings()[0].type() == xla::OpSharding::TUPLE) { auto tuple_shardings = module.spmd_parameters_shardings()[0].tuple_shardings(); - input_shardings = std::vector<xla::OpSharding>(tuple_shardings.begin(), - tuple_shardings.end()); + xla_input_shardings = std::vector<xla::OpSharding>(tuple_shardings.begin(), + tuple_shardings.end()); } else { for (auto sharding : module.spmd_parameters_shardings()) { - input_shardings.push_back(sharding); + xla_input_shardings.push_back(sharding); } } - if (input_shardings.size() == 0) { + if (xla_input_shardings.size() == 0) { TF_VLOG(3) << "ReshardParamters... skip with empty input_shardings."; return; } - XLA_CHECK_EQ(input_shardings.size(), parameters->size()); + XLA_CHECK_EQ(xla_input_shardings.size(), parameters->size()); // Reshard parameters as needed, as with a new sharding spec. std::vector<torch::lazy::BackendDataPtr> data = @@ -696,13 +793,29 @@ void ShardingUtil::ReshardParameters( std::vector<size_t> reshard_indices; std::vector<torch::lazy::BackendDataPtr> data_to_reshard; - std::vector<xla::OpSharding> shardings_to_reshard; + std::vector<torch_xla::OpSharding> shardings_to_reshard; + + std::vector<int64_t> denormalized_tile_assignment; + auto sharding_spec = (*tensors)[0]->sharding_spec(); + if (sharding_spec) { + denormalized_tile_assignment = + sharding_spec->sharding.GetDenormalizedTileAssignment(); + } + for (const auto& sharding : xla_input_shardings) { + if (denormalized_tile_assignment.size() > 0) { + input_shardings.emplace_back(sharding, denormalized_tile_assignment); + } else { + input_shardings.emplace_back(sharding); + } + } + for (int i = 0; i < input_shardings.size(); ++i) { XLA_CHECK(input_shardings[i].type() != xla::OpSharding::UNKNOWN) << "Resharding by UNKNOWN sharding type is not allowed."; // Skip re-sharding if not necessary.
- if (!xla::protobuf_util::HaveSameSerialization(data[i]->GetSharding(), - input_shardings[i])) { + if (!xla::protobuf_util::HaveSameSerialization( + data[i]->GetSharding().GetXlaOpSharding(), + input_shardings[i].GetXlaOpSharding())) { reshard_indices.push_back(i); data_to_reshard.push_back(data[i]); shardings_to_reshard.push_back(input_shardings[i]); @@ -761,7 +874,7 @@ void ShardingUtil::ReshardParameters( } void ShardingUtil::XlaMarkSharding(const at::Tensor& input, - xla::OpSharding sharding) { + torch_xla::OpSharding sharding) { TORCH_LAZY_COUNTER("XlaMarkSharding", 1); XLA_CHECK(UseVirtualDevice()) << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; @@ -839,7 +952,7 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input, } void ShardingUtil::XlaAnnotateCustomSharding(const XLATensorPtr& input, - xla::OpSharding sharding) { + torch_xla::OpSharding sharding) { TORCH_LAZY_COUNTER("XlaAnnotateCustomSharding", 1); XLA_CHECK(UseVirtualDevice()) diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 8b8b98653b2..87f6638990c 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -9,6 +9,7 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/tensor.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/service/hlo.pb.h" @@ -28,8 +29,8 @@ class ShardingUtil { UNKNOWN = 6 // implicit replication }; - // Determine the ShardingType of the given xla::OpSharding. - static ShardingType GetShardingType(xla::OpSharding& sharding); + // Determine the ShardingType of the given torch_xla::OpSharding. + static ShardingType GetShardingType(torch_xla::OpSharding& sharding); // Annotates HLO instructions in the lowered computation and returns true if // the computation needs to be compiled with SPMD partitioning. For this call @@ -42,15 +43,18 @@ class ShardingUtil { const XLATensor::ShardingSpec& b); // Returns true if two OpShardings are the same. - static bool EqualOpShardings(const xla::OpSharding& a, - const xla::OpSharding& b); + static bool EqualOpShardings(const torch_xla::OpSharding& a, + const torch_xla::OpSharding& b); - // Creates an xla::OpSharding. `tile_assignmnent` is required for TILED + // Returns the tile_assignment normalized so that device IDs start from 0 + static std::vector<int64_t> NormalizeTileAssignment( + const std::vector<int64_t>& tile_assignment); + + // Creates a torch_xla::OpSharding. `tile_assignment` is required for TILED // `sharding_type` and `replication_groups` for `PARTIAL`. - static xla::OpSharding CreateOpSharding(const py::list& tile_assignment, - const py::list& group_assignment, - const py::list& replication_groups, - ShardingType sharding_type); + static torch_xla::OpSharding CreateOpSharding( + const py::list& tile_assignment, const py::list& group_assignment, + const py::list& replication_groups, ShardingType sharding_type); // Returns the shape of the resulting shards of `tensor` after applying // `sharding`. This assumes the shards will be padded to ensure they all @@ -67,7 +71,7 @@ class ShardingUtil { static std::vector<std::pair<int, std::vector<at::indexing::TensorIndex>>> GetShardReplicaAndIndicesForDevices(const std::vector<int64_t>& shard_shape, const std::vector<int64_t>& tensor_shape, - const xla::OpSharding sharding, + const torch_xla::OpSharding sharding, const std::vector<std::string>& devices); // Returns the indices for the shards.
Supports `OTHER` sharding types and @@ -94,7 +98,9 @@ class ShardingUtil { // is always on virtual SPMD device. static std::vector<XLATensor::ShardingSpecPtr> GetOutputSharding( const std::vector<xla::Shape>& output_shapes, - runtime::ComputationClient::ComputationPtr computation); + runtime::ComputationClient::ComputationPtr computation, + std::optional<std::vector<int64_t>> denormalized_tile_assignment = + std::nullopt); // Create sharded data placeholders, each corresponding to the individual // sharding spec from the input list @@ -111,7 +117,9 @@ class ShardingUtil { std::vector<XLATensorPtr>* tensors, absl::Span<const size_t> indices, runtime::ComputationClient::ComputationPtr computation, std::vector<torch::lazy::BackendDataPtr>* data_placeholders, - std::vector<XLATensor::ShardingSpecPtr>* sharding_specs); + std::vector<XLATensor::ShardingSpecPtr>* sharding_specs, + std::optional<std::vector<int64_t>> denormalized_tile_assignment = + std::nullopt); // Transfers the individual shards to the devices and returns a DataPtr for // the PjRtShardedData wrapping the shards. @@ -121,14 +129,14 @@ class ShardingUtil { const XLATensor::ShardingSpecPtr& sharding_spec); static void XlaMarkSharding(const at::Tensor& input, - xla::OpSharding sharding); + torch_xla::OpSharding sharding); // Add a custom sharding node IR to an XLATensor. Note that unlike // XlaMarkSharding, this will not explicitly set a sharding spec tied to the // DeviceData node, nor transfer any sharded data to the device. This serves // merely as an XLA custom sharding annotation IR. static void XlaAnnotateCustomSharding(const XLATensorPtr& input, - xla::OpSharding sharding); + torch_xla::OpSharding sharding); //////////////////////////// Auto-Sharding //////////////////////////// // Construct a device mesh for auto-sharding pass. Returns a tuple of mesh diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 5f4d4378e7d..c56b54faf7b 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -78,9 +78,8 @@ def __init__(self, # devices. num_devices = xr.global_runtime_device_count() assert num_devices > 0, "This requires XLA supported device(s)." - assert num_devices == len( - device_ids - ), f"Number of device IDs ({len(device_ids)}) must match the global number of devices ({num_devices})" + assert len(device_ids) <= num_devices, \ + f"Number of device IDs ({len(device_ids)}) must not exceed the global number of devices ({num_devices})" if axis_names is not None: assert len(mesh_shape) == len(axis_names), \ @@ -97,8 +96,10 @@ def __init__(self, self.device_ids = device_ids self.mesh_shape = mesh_shape self.axis_names = axis_names - assert all(d < self.size() for d in device_ids), \ - f"Device IDs must be less than mesh size ({self.size()}), got: {device_ids}" + assert len(self.device_ids) <= self.size(), \ + f"Length of device IDs cannot be greater than mesh size ({self.size()}), got: {device_ids}" + assert set(device_ids).issubset(set(np.arange(xr.addressable_runtime_device_count()))), \ + f"Device IDs must be a subset of the addressable devices; got: {device_ids} and {np.arange(xr.addressable_runtime_device_count())}" def size(self): return np.prod(self.mesh_shape)
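The sketch below summarizes how the pieces in this change are meant to fit together. It is illustrative only and not part of the diff: the device IDs {4, 5, 6, 7}, the Example() wrapper, and the include paths are assumptions made for the example; NormalizeTileAssignment, the torch_xla::OpSharding constructor, and xla::HloSharding::Tile are the APIs exercised by the code and tests above.

#include <cstdint>
#include <vector>

#include "torch_xla/csrc/torch_xla_op_sharding.h"  // path per this change
#include "torch_xla/csrc/xla_sharding_util.h"
#include "xla/array2d.h"
#include "xla/hlo/ir/hlo_sharding.h"  // include path is approximate

void Example() {
  // Suppose the mesh is built on a 4-device slice whose global device IDs do
  // not start at 0 (e.g. the second half of an 8-device pod); the IDs here
  // are purely illustrative.
  std::vector<int64_t> denormalized = {4, 5, 6, 7};

  // The proto handed to the SPMD partitioner must be 0-based, so the IDs are
  // normalized by subtracting the minimum value.
  std::vector<int64_t> normalized =
      torch_xla::ShardingUtil::NormalizeTileAssignment(denormalized);
  // normalized == {0, 1, 2, 3}

  // The wrapper keeps both views: the normalized assignment lives in the
  // xla::OpSharding proto used for compilation, while the original IDs are
  // retained for addressing the physical devices when shards are placed.
  xla::Array2D<int64_t> mesh({{0, 1}, {2, 3}});
  torch_xla::OpSharding sharding(xla::HloSharding::Tile(mesh).ToProto(),
                                 denormalized);
  // sharding.tile_assignment_devices()       -> {0, 1, 2, 3} (from the proto)
  // sharding.GetDenormalizedTileAssignment() -> {4, 5, 6, 7}
}

This appears to be the intent of carrying denormalized_tile_assignments through Compile/ScheduleSyncTensorsGraph: compilation always sees a 0-based assignment, while shard transfer and resharding can still recover the original device IDs.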