diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index b179c6e523c..74c7d16f38b 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -42,7 +42,7 @@ class XLAShardingTest : public AtenXlaTensorTestBase { } }; -TEST_F(XLAShardingTest, GetShardShape) { +TEST_F(XLAShardingTest, GetShardShapeTiled) { auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); @@ -50,21 +50,34 @@ TEST_F(XLAShardingTest, GetShardShape) { {0, 1}, {2, 3}, }); - auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + std::vector denormalized_tile_assignment = {0, 1, 2, 3}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); auto sharding_spec = std::make_shared(sharding, tensor_shape); auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); // For tiled sharding, each dimension should be halved EXPECT_EQ(shard_shape, std::vector({4, 4})); +} + +TEST_F(XLAShardingTest, GetShardShapeReplicated) { + auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + std::vector denormalized_tile_assignment = {0, 1, 2, 3}; - sharding_spec->sharding = xla::HloSharding::Replicate().ToProto(); - shard_shape = ShardingUtil::GetShardShape(sharding_spec); + auto xla_sharding = xla::HloSharding::Replicate().ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + + auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); // For replicated sharding, each dimension should be preserved EXPECT_EQ(shard_shape, std::vector({8, 7})); } -TEST_F(XLAShardingTest, GetShardIndicesForDevices) { +TEST_F(XLAShardingTest, GetShardIndicesForDevicesTiled) { std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"}; auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); @@ -74,7 +87,9 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { {0, 1}, {2, 3}, }); - auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + std::vector denormalized_tile_assignment = {0, 1, 2, 3}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); auto sharding_spec = std::make_shared(sharding, tensor_shape); auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); @@ -103,10 +118,22 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { EXPECT_EQ(slice.step(), 1); } } - sharding = xla::HloSharding::Replicate().ToProto(); - sharding_spec->sharding = sharding; - shard_shape = ShardingUtil::GetShardShape(sharding_spec); - replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices( +} + +TEST_F(XLAShardingTest, GetShardIndicesForDevicesReplicated) { + std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3"}; + + auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + std::vector denormalized_tile_assignment = {0, 1, 2, 3}; + + auto xla_sharding = xla::HloSharding::Replicate().ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + auto shard_shape = ShardingUtil::GetShardShape(sharding_spec); + 
auto replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices( shard_shape, tensor.sizes().vec(), sharding, devices); EXPECT_EQ(replica_and_indices.size(), devices.size()); for (int i = 0; i < devices.size(); ++i) { @@ -118,19 +145,21 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) { } } -TEST_F(XLAShardingTest, ShardTensor) { +TEST_F(XLAShardingTest, ShardTensor1D) { std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; // 1D tiled at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat)); xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); - xla::OpSharding sharding = + xla::OpSharding xla_sharding = xla::HloSharding::Tile1D( CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()), devices.size()) .ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); auto sharding_spec = std::make_shared(sharding, tensor_shape); auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, @@ -138,95 +167,171 @@ TEST_F(XLAShardingTest, ShardTensor) { EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1})); EXPECT_EQ(shards[1].sizes(), c10::ArrayRef({1})); +} + +TEST_F(XLAShardingTest, ShardTensor2D) { + std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", + "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; // 2D tiled, The first dim is halved and the last replicated. The last shard // size should be smaller in dim=1 because it's not evenly divisible. - tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); - tensor_shape = + at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); xla::Array2D mesh({ {0, 1, 2, 3}, {4, 5, 6, 7}, }); - sharding = xla::HloSharding::Tile(mesh).ToProto(); - sharding_spec = + xla::OpSharding xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + auto sharding_spec = std::make_shared(sharding, tensor_shape); - shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, - /*padded=*/false); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({4, 2, 4})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({4, 1, 4})); +} + +TEST_F(XLAShardingTest, ShardTensor3D) { + std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", + "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; // 3D tiled, the first dim is replicated and the last halved. The last shard // size should be smaller in dim=1 because it's not evenly divisible. 
+ at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); xla::Array3D cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}); - sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto(); - shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, - /*padded=*/false); + xla::OpSharding xla_sharding = xla::HloSharding::Tile(cube).ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({8, 2, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 1, 2})); +} + +TEST_F(XLAShardingTest, ShardTensorReplicated) { + std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", + "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; // Replicated, all shards should be identical. - sharding_spec->sharding = xla::HloSharding::Replicate().ToProto(); - shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, - /*padded=*/false); + at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + xla::OpSharding xla_sharding = xla::HloSharding::Replicate().ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({8, 7, 4})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 7, 4})); +} + +TEST_F(XLAShardingTest, ShardTensor4D) { + std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", + "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; // 4D tiled, the first and second dims are replicated and the last halved. The // last shard size should be smaller in dim=2 because it's not evenly // divisible. - tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat)); - tensor_shape = + at::Tensor tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); xla::Array4D tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}}); - sharding = xla::HloSharding::Tile(tesseract).ToProto(); - sharding_spec = + xla::OpSharding xla_sharding = xla::HloSharding::Tile(tesseract).ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + auto sharding_spec = std::make_shared(sharding, tensor_shape); - shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, - /*padded=*/false); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1, 8, 2, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({1, 8, 1, 2})); +} + +TEST_F(XLAShardingTest, ShardTensor4DPadded) { + std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", + "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; - // 4D tiled and padded, all shard sizes should be idential. 
- shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, - /*padded=*/true); + // 4D tiled and padded, all shard sizes should be identical. + at::Tensor tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + xla::Array4D tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}}); + xla::OpSharding xla_sharding = xla::HloSharding::Tile(tesseract).ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/true); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1, 8, 2, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({1, 8, 2, 2})); +} + +TEST_F(XLAShardingTest, ShardTensor5D) { + std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", + "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; // 5D tiled, the first and second dims are replicated and the last halved. The // last shard size should be smaller in dim=2 because it's not evenly // divisible. - tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat)); - tensor_shape = + at::Tensor tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); xla::Array hypercube(std::vector{1, 1, 2, 2, 2}); hypercube.FillIota(0); - sharding = xla::HloSharding::Tile(hypercube).ToProto(); - sharding_spec = + xla::OpSharding xla_sharding = xla::HloSharding::Tile(hypercube).ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + auto sharding_spec = std::make_shared(sharding, tensor_shape); - shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, - /*padded=*/false); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({10, 1, 4, 3, 2})); +} + +TEST_F(XLAShardingTest, ShardTensor5DPadded) { + std::vector devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3", + "TPU:4", "TPU:5", "TPU:6", "TPU:7"}; + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; // 5D tiled and padded, all shard sizes should be identical. - shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, - /*padded=*/true); + at::Tensor tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + xla::Array hypercube(std::vector{1, 1, 2, 2, 2}); + hypercube.FillIota(0); + xla::OpSharding xla_sharding = xla::HloSharding::Tile(hypercube).ToProto(); + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/true); EXPECT_EQ(shards.size(), 8); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); } -TEST_F(XLAShardingTest, ShardTensorMultiHost) { +TEST_F(XLAShardingTest, ShardTensorMultiHostStartOfMesh) { std::vector devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"}; // 2D tiled, The first dim is halved and the last replicated. 
+ // For devices at the start of the mesh, all shards should have the same + // unpadded shape. at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); @@ -234,26 +339,37 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) { {4, 5, 0, 1}, {6, 7, 2, 3}, }); - auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + std::vector denormalized_tile_assignment = {4, 5, 0, 1, 6, 7, 2, 3}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); auto sharding_spec = std::make_shared(sharding, tensor_shape); - // For devices at the start of the mesh, all shards should have the same - // unpadded shape. auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, /*padded=*/false); EXPECT_EQ(shards.size(), 4); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({4, 2, 4})); EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({4, 2, 4})); +} + +TEST_F(XLAShardingTest, ShardTensorMultiHostEndOfMesh) { + std::vector devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"}; // When this host's devices are at the end of the mesh, the last shard should // be smaller in dim=2 because it's not evenly divisible. - mesh = xla::Array2D({ + at::Tensor tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + xla::Array2D mesh({ {0, 1, 4, 5}, {2, 3, 6, 7}, }); - sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto(); - shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, - /*padded=*/false); + auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + std::vector denormalized_tile_assignment = {0, 1, 4, 5, 2, 3, 6, 7}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); EXPECT_EQ(shards.size(), 4); EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({4, 2, 4})); EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({4, 1, 4})); @@ -278,7 +394,9 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) { {{7}}, }); - auto sharding = xla::HloSharding::Tile(mesh).ToProto(); + auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto(); + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); auto sharding_spec = std::make_shared( sharding, global_shape, /*minibatch=*/true); auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec, @@ -288,24 +406,87 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) { EXPECT_EQ(shards[3].sizes(), c10::ArrayRef({2, 7, 4})); } -TEST_F(XLAShardingTest, EqualShardingSpecs) { +TEST_F(XLAShardingTest, EqualShardingSpecsSameSpecs) { auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); xla::Shape tensor_shape = CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); - XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({ - {0, 1, 2, 3}, - {4, 5, 6, 7}, - }) - .ToProto(), - tensor_shape); - XLATensor::ShardingSpec tiled_3d( - xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(), - tensor_shape); - XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(), - tensor_shape); + auto xla_sharding = xla::HloSharding::Tile({ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + }) + .ToProto(); + 
std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape); + + // Test that identical sharding specs are equal EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d)); +} + +TEST_F(XLAShardingTest, EqualShardingSpecsDifferentTiledSpecs) { + auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + + // Create 2D tiled sharding + auto xla_sharding_2d = xla::HloSharding::Tile({ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + }) + .ToProto(); + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; + torch_xla::OpSharding sharding_2d(xla_sharding_2d, + denormalized_tile_assignment); + XLATensor::ShardingSpec tiled_2d(sharding_2d, tensor_shape); + + // Create 3D tiled sharding + auto xla_sharding_3d = + xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(); + torch_xla::OpSharding sharding_3d(xla_sharding_3d, + denormalized_tile_assignment); + XLATensor::ShardingSpec tiled_3d(sharding_3d, tensor_shape); + + // Test that different tiled sharding specs are not equal EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d)); +} + +TEST_F(XLAShardingTest, EqualShardingSpecsReplicatedSpecs) { + auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + + auto xla_sharding = xla::HloSharding::Replicate().ToProto(); + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; + torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); + XLATensor::ShardingSpec replicated(sharding, tensor_shape); + + // Test that identical replicated sharding specs are equal EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated)); +} + +TEST_F(XLAShardingTest, EqualShardingSpecsTiledVsReplicated) { + auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + std::vector denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; + + // Create tiled sharding + auto xla_sharding_tiled = xla::HloSharding::Tile({ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + }) + .ToProto(); + torch_xla::OpSharding sharding_tiled(xla_sharding_tiled, + denormalized_tile_assignment); + XLATensor::ShardingSpec tiled_2d(sharding_tiled, tensor_shape); + + // Create replicated sharding + auto xla_sharding_replicated = xla::HloSharding::Replicate().ToProto(); + torch_xla::OpSharding sharding_replicated(xla_sharding_replicated, + denormalized_tile_assignment); + XLATensor::ShardingSpec replicated(sharding_replicated, tensor_shape); + + // Test that tiled and replicated sharding specs are not equal EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, replicated)); } @@ -323,12 +504,17 @@ TEST_F(XLAShardingTest, CreateTensorsData) { std::vector devices(3); std::fill_n(devices.begin(), devices.size(), bridge::GetDefaultDevice()->toString()); + auto replicate_xla_sharding = xla::HloSharding::Replicate().ToProto(); + auto unknown_xla_sharding = xla::HloSharding::Unknown().ToProto(); + torch_xla::OpSharding replicate_sharding(replicate_xla_sharding, + std::nullopt); + torch_xla::OpSharding unknown_sharding(unknown_xla_sharding, std::nullopt); std::vector shardings = { nullptr, - std::make_shared( - 
xla::HloSharding::Replicate().ToProto(), tensor_shape), - std::make_shared( - xla::HloSharding::Unknown().ToProto(), tensor_shape)}; + std::make_shared(replicate_sharding, + tensor_shape), + std::make_shared(unknown_sharding, + tensor_shape)}; std::vector tensors_data = CreateTensorsData(tensors, shardings, devices); @@ -387,13 +573,29 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { auto y = xla::Add(x, xla::ConstantR0(&b, 3)); xla::XlaComputation xla_computation = GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false)); + + std::vector tensors{XLATensor::Create( + torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder( + bridge::GetDefaultDevice()->toString(), std::move(shape)))}; + std::vector> denormalized_tile_assignments; + for (auto tensor : tensors) { + auto sharding_spec = tensor->sharding_spec(); + if (sharding_spec) { + denormalized_tile_assignments.push_back( + sharding_spec->sharding.GetDenormalizedTileAssignment()); + } + } + std::vector instances; - instances.push_back({std::move(xla_computation), - bridge::GetDefaultDevice()->toString(), - {bridge::GetDefaultDevice()->toString()}, - &shape, - /*should_wrap_parameter=*/false, - /*is_sharded=*/true}); + instances.push_back( + {std::move(xla_computation), + bridge::GetDefaultDevice()->toString(), + {bridge::GetDefaultDevice()->toString()}, + &shape, + /*should_wrap_parameter=*/false, + /*is_sharded=*/true, + /*allow_spmd_sharding_propagation_to_output=*/true, + /*denormalized_tile_assignments=*/denormalized_tile_assignments}); std::vector< std::shared_ptr> @@ -404,9 +606,6 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { "add", std::move(computations[0]->move_computation())); // Prepare output sharding propagation, expect a sharded output placeholder. - std::vector tensors{XLATensor::Create( - torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder( - bridge::GetDefaultDevice()->toString(), std::move(shape)))}; std::vector data_placeholders; std::vector sharding_specs; ShardingUtil::PrepareOutputShardingPropagation( @@ -417,11 +616,12 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { if (n_devices > 1) { // Tiled sharding requires multiple devices. EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization( - tiled, sharding_specs[0]->sharding)); + tiled, sharding_specs[0]->sharding.GetXlaOpSharding())); } else { // Sincle device execution defaults to replication sharding. EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization( - xla::HloSharding::Replicate().ToProto(), sharding_specs[0]->sharding)); + xla::HloSharding::Replicate().ToProto(), + sharding_specs[0]->sharding.GetXlaOpSharding())); } // Check if the placeholder is on a SPMD device (sharded) with no real values. 
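Note: the `torch_xla::OpSharding` wrapper that these tests now construct is introduced elsewhere in this PR (`torch_xla/csrc/torch_xla_op_sharding.{h,cpp}`, registered in the BUILD change below), and its definition is not part of this excerpt. A minimal sketch of what the call sites here assume — a constructor taking the raw `xla::OpSharding` proto plus an optional denormalized tile assignment, and pass-through accessors — could look like the following; the include path, constructor defaults, and member names are guesses, not the actual header:

```cpp
// NOT part of the patch. Minimal sketch of the new wrapper, inferred from the
// call sites in this PR; the real torch_xla/csrc/torch_xla_op_sharding.h may
// differ in includes, constructors, and member names.
#pragma once

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

#include "xla/xla_data.pb.h"  // xla::OpSharding proto (path is a guess)

namespace torch_xla {

// Wraps the xla::OpSharding proto and additionally carries the denormalized
// tile assignment (global device ids), which multi-host SPMD code needs.
class OpSharding {
 public:
  // A default-constructed proto has type REPLICATED (the enum's zero value),
  // which callers that build std::vector<OpSharding>(n) rely on.
  OpSharding() = default;

  OpSharding(xla::OpSharding xla_op_sharding,
             std::optional<std::vector<int64_t>> denormalized_tile_assignment =
                 std::nullopt)
      : xla_op_sharding_(std::move(xla_op_sharding)),
        denormalized_tile_assignment_(
            denormalized_tile_assignment.value_or(std::vector<int64_t>{})) {}

  // Raw proto, e.g. for xla::HloSharding::FromProto(...) at existing call
  // sites.
  const xla::OpSharding& GetXlaOpSharding() const { return xla_op_sharding_; }

  // Global device ids backing the tile assignment; empty when not provided.
  const std::vector<int64_t>& GetDenormalizedTileAssignment() const {
    return denormalized_tile_assignment_;
  }

  // Pass-through mirroring xla::OpSharding::type().
  xla::OpSharding::Type type() const { return xla_op_sharding_.type(); }

 private:
  xla::OpSharding xla_op_sharding_;
  std::vector<int64_t> denormalized_tile_assignment_;
};

}  // namespace torch_xla
```

Keeping the proto reachable through `GetXlaOpSharding()` is what lets the existing call sites below (for example `xla::HloSharding::FromProto(...)` in `debug_util.cpp`) migrate with a one-line change.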
diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index f99dca0a74e..017788433fa 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -128,6 +128,7 @@ ptxla_cc_library( ":shape_builder", ":shape_helper", ":status", + ":torch_xla_op_sharding", ":version", "//torch_xla/csrc:hash_util", "//torch_xla/csrc:thread_pool", @@ -316,6 +317,7 @@ ptxla_cc_library( ":shape_helper", ":status", ":unwrap_data", + ":torch_xla_op_sharding", "//torch_xla/csrc/runtime:cache", "//torch_xla/csrc/runtime:computation_client", "@com_google_absl//absl/log:absl_check", @@ -385,3 +387,13 @@ cc_library( "@com_google_absl//absl/status:statusor", ], ) + +cc_library( + name = "torch_xla_op_sharding", + srcs = ["torch_xla_op_sharding.cpp"], + hdrs = ["torch_xla_op_sharding.h"], + deps = [ + "//torch_xla/csrc/runtime:debug_macros", + "@xla//xla/hlo/builder:xla_builder", + ], +) diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 812f7122efa..49e56e9164a 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -21,6 +21,7 @@ #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/status.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "torch_xla/csrc/xla_graph_executor.h" namespace torch_xla { @@ -218,7 +219,8 @@ void DebugUtil::SaveOutputShardingInfo(std::vector* tensors, auto xtensor = (*tensors)[indices[i]]; ss << xtensor->shape().get().ToString() << " "; if (xtensor->sharding_spec()) { - ss << xla::HloSharding::FromProto(xtensor->sharding_spec()->sharding) + ss << xla::HloSharding::FromProto( + xtensor->sharding_spec()->sharding.GetXlaOpSharding()) ->ToString(); } else { ss << xla::HloSharding::FromProto(xla::HloSharding::Unknown().ToProto()) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a2099f7d4ec..c9cb8d275c2 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -72,6 +72,7 @@ #include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "torch_xla/csrc/version.h" #include "torch_xla/csrc/xla_backend_impl.h" #include "torch_xla/csrc/xla_graph_executor.h" @@ -740,7 +741,8 @@ std::string GetTensorsHloGraph(const std::vector& tensors, std::string GetXLAShardingSpec(const XLATensorPtr xtensor) { auto sharding_spec = xtensor->sharding_spec(); if (sharding_spec != nullptr) { - auto hlo_sharding = xla::HloSharding::FromProto(sharding_spec->sharding); + auto hlo_sharding = + xla::HloSharding::FromProto(sharding_spec->sharding.GetXlaOpSharding()); return hlo_sharding->ToString(); } return std::string(); @@ -1540,7 +1542,7 @@ void InitXlaModuleBindings(py::module m) { runtime::ComputationClient::ComputationPtr>(m, "XlaComputation"); // Define the _XLAC.OpSharding class. 
- PythonScope>(m, "OpSharding") + PythonScope>(m, "OpSharding") .def_init([](const py::list& tile_assignment, const py::list& group_assignment, const py::list& replication_groups, int sharding_type) { @@ -2559,16 +2561,16 @@ void InitXlaModuleBindings(py::module m) { } }) .def("_xla_mark_sharding", - [](const at::Tensor& input, xla::OpSharding sharding) { + [](const at::Tensor& input, torch_xla::OpSharding sharding) { ShardingUtil::XlaMarkSharding(input, sharding); }) .def("_xla_annotate_custom_sharding", - [](const at::Tensor& input, xla::OpSharding sharding) { + [](const at::Tensor& input, torch_xla::OpSharding sharding) { XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); ShardingUtil::XlaAnnotateCustomSharding(xtensor, sharding); }) .def("_mark_manual_sharding", - [](const at::Tensor& input, xla::OpSharding sharding) { + [](const at::Tensor& input, torch_xla::OpSharding sharding) { XLA_CHECK(IsNonDeviceDataIR(input)) << "Marking any data tensors as manual is not supported"; ShardingUtil::XlaMarkSharding(input, sharding); @@ -2588,13 +2590,14 @@ void InitXlaModuleBindings(py::module m) { xtensor->CreateFrom(torch_xla::MakeNode( xtensor->GetIrValue(), shard_shape, CustomSharding::Type::kSPMDFullToShardShape)); - output->SetShardingSpec(XLATensor::ShardingSpec( - xla::HloSharding::Manual().ToProto(), shard_shape)); + torch_xla::OpSharding sharding(xla::HloSharding::Manual().ToProto(), + sharding_spec->sharding.GetDenormalizedTileAssignment()); + output->SetShardingSpec(XLATensor::ShardingSpec(sharding, shard_shape)); return bridge::AtenFromXlaTensor(output); }) .def( "_spmd_shard_to_full_shape", - [](const at::Tensor& input, const xla::OpSharding& sharding, + [](const at::Tensor& input, const torch_xla::OpSharding& sharding, const std::vector& output_shape, const py::object& output_dtype) -> at::Tensor { XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); @@ -2628,7 +2631,7 @@ void InitXlaModuleBindings(py::module m) { return GetXLAShardingSpec(xtensor); }) .def("_get_xla_op_sharding", - [](const at::Tensor& input) -> std::optional { + [](const at::Tensor& input) -> std::optional { XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); XLATensor::ShardingSpecPtr sharding_spec = xtensor ? xtensor->sharding_spec() : nullptr; @@ -2668,7 +2671,7 @@ void InitXlaModuleBindings(py::module m) { // `torch_xla.runtime.local_runtime_devices()`. 
"_global_tensor_from_cpu_shards", [](const std::vector& shards, - const xla::OpSharding& sharding, + const torch_xla::OpSharding& sharding, std::optional>& global_shape) -> at::Tensor { XLA_CHECK(UseVirtualDevice()) << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 2e4338f50b7..8fa0ff2fc72 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -13,6 +13,7 @@ #include "torch_xla/csrc/runtime/cache.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" namespace torch_xla { namespace { @@ -167,12 +168,12 @@ torch::lazy::hash_t XlaNode::GetOpHash(torch::lazy::OpKind op, return torch::lazy::HashCombine(h, hash_seed); } -void XlaNode::SetSharding(const xla::OpSharding& sharding, size_t index) { +void XlaNode::SetSharding(const torch_xla::OpSharding& sharding, size_t index) { if (output_shardings_.size() == 0) { - output_shardings_ = - std::vector>(num_outputs(), nullptr); + output_shardings_ = std::vector>( + num_outputs(), nullptr); } - output_shardings_[index] = std::make_shared(sharding); + output_shardings_[index] = std::make_shared(sharding); // TODO(JackCaoG): fix this hashing UpdateShardingHash(); } @@ -207,7 +208,10 @@ void XlaNode::UpdateShardingHash() { for (size_t i = 0; i < output_shardings_.size(); i++) { // keep the index as part of the hash sharding_hash_ = torch::lazy::HashCombine(sharding_hash_, (uint32_t)i); - std::shared_ptr sharding = output_shardings_[i]; + std::shared_ptr sharding = + std::make_shared( + output_shardings_[i]->GetXlaOpSharding(), + output_shardings_[i]->GetDenormalizedTileAssignment()); // skip the hash compute for empty sharding if (!sharding) { continue; diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index c2595a433a9..d5c96dc1057 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -21,6 +21,7 @@ #include "absl/types/span.h" #include "torch_xla/csrc/dynamic_shape_detector.h" #include "torch_xla/csrc/runtime/types.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "xla/hlo/builder/xla_builder.h" namespace torch_xla { @@ -133,14 +134,19 @@ class XlaNode : public torch::lazy::Node { torch::lazy::hash_t shardingHash() const { return sharding_hash_; } // The node's outputs get assigned the same HLO sharding - const std::shared_ptr GetSharding(size_t index) const { + const std::shared_ptr GetSharding(size_t index) const { if (output_shardings_.size() == 0) { return nullptr; } return output_shardings_[index]; } - void SetSharding(const xla::OpSharding& sharding, size_t index); + const std::vector> GetShardings() + const { + return output_shardings_; + } + + void SetSharding(const torch_xla::OpSharding& sharding, size_t index); void ClearSharding() { output_shardings_.clear(); @@ -180,7 +186,7 @@ class XlaNode : public torch::lazy::Node { torch::lazy::hash_t sharding_hash_ = 0; // Experimental sharding annotations attached to the IR node. 
- std::vector> output_shardings_; + std::vector> output_shardings_; }; inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) { diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 65f438643f1..772a656f0c1 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -19,6 +19,7 @@ #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/stack_frame_index_builder.h" #include "torch_xla/csrc/status.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" namespace torch_xla { @@ -133,9 +134,9 @@ xla::XlaOp LoweringContext::GetParameter( const std::string param_name = absl::StrCat("p", param_index); xla::XlaOp param; if (data->HasSharding()) { - const xla::OpSharding sharding = data->GetSharding(); - const xla::XlaScopedShardingAssignment scoped_sharding(builder(), - sharding); + const torch_xla::OpSharding sharding = data->GetSharding(); + const xla::XlaScopedShardingAssignment scoped_sharding( + builder(), sharding.GetXlaOpSharding()); param = xla::Parameter(builder(), param_index, shape, param_name); } else { param = xla::Parameter(builder(), param_index, shape, param_name); @@ -237,6 +238,17 @@ xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) { return it->second; } +void LoweringContext::ExtractShardingAndSetDenormalizedTileAssignments( + std::vector> shardings) { + for (auto sharding : shardings) { + std::vector denormalized_tile_assignment = + sharding->GetDenormalizedTileAssignment(); + if (!denormalized_tile_assignment.empty()) { + denormalized_tile_assignments_.push_back(denormalized_tile_assignment); + } + } +} + XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node& node) { XlaOpVector result_ops; try { @@ -244,6 +256,12 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node& node) { const XlaNode* const casted = dynamic_cast(&node); result_ops = casted->Lower(this); + // save the denormalized_tile_assignment from all nodes and then use it + // during Compile + auto shardings = casted->GetShardings(); + if (!shardings.empty()) { + ExtractShardingAndSetDenormalizedTileAssignments(shardings); + } if (!casted->dynamic_dims().empty()) { const xla::internal::XlaBuilderFriend builder_friend; auto* const inst = builder_friend.GetInstruction(result_ops[0]); diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index f0545534155..136af6c599a 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -117,6 +117,14 @@ class LoweringContext : public torch::lazy::LoweringContext { int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source, int64_t parent_id); + void ExtractShardingAndSetDenormalizedTileAssignments( + std::vector>); + + const std::vector>& GetDenormalizedTileAssignments() + const { + return denormalized_tile_assignments_; + } + private: struct Parameter { xla::XlaOp param; @@ -135,6 +143,7 @@ class LoweringContext : public torch::lazy::LoweringContext { std::string name_; std::shared_ptr stack_frame_index_builder_; + std::vector> denormalized_tile_assignments_; }; // namespace torch_xla } // namespace torch_xla diff --git a/torch_xla/csrc/ops/device_data.cpp b/torch_xla/csrc/ops/device_data.cpp index a5f5536b5b6..93fb4901920 100644 --- a/torch_xla/csrc/ops/device_data.cpp +++ b/torch_xla/csrc/ops/device_data.cpp @@ -6,6 +6,7 @@ #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/tensor_util.h" +#include 
"torch_xla/csrc/torch_xla_op_sharding.h" namespace torch_xla { @@ -16,7 +17,7 @@ DeviceData::DeviceData(std::shared_ptr data) /*num_outputs=*/1, /*hash_seed=*/(uint32_t)101), data_(std::move(data)) { - std::optional op_sharding = + std::optional op_sharding = torch_xla::runtime::GetComputationClientOrDie()->GetDataSharding( std::dynamic_pointer_cast(data_)); if (op_sharding.has_value()) { diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index b381d3feff7..6a99d2fabce 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -56,6 +56,7 @@ cc_library( "//torch_xla/csrc:device", "//torch_xla/csrc:dtype", "//torch_xla/csrc:status", + "//torch_xla/csrc:torch_xla_op_sharding", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -123,6 +124,7 @@ cc_library( ":tf_logging", ":xla_coordinator", "//torch_xla/csrc:status", + "//torch_xla/csrc:torch_xla_op_sharding", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 05478dc6cb4..fef86513540 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -26,6 +26,7 @@ #include "torch_xla/csrc/runtime/types.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/status.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal_util.h" @@ -80,7 +81,7 @@ class ComputationClient { virtual bool HasSharding() const = 0; - virtual xla::OpSharding GetSharding() const = 0; + virtual torch_xla::OpSharding GetSharding() const = 0; private: std::string xla_device_; @@ -228,6 +229,7 @@ class ComputationClient { std::vector devices, const xla::Shape* output_shape, bool parameter_is_tupled_arguments = false, bool is_sharded = false, bool allow_spmd_sharding_propagation_to_output = true, + std::vector> denormalized_tile_assignments = {}, bool use_auto_spmd_partitioning = false, std::vector auto_spmd_mesh_shape = {}, std::vector auto_spmd_mesh_ids = {}, bool eager_mode = false) @@ -239,6 +241,7 @@ class ComputationClient { is_sharded(is_sharded), allow_spmd_sharding_propagation_to_output( allow_spmd_sharding_propagation_to_output), + denormalized_tile_assignments(denormalized_tile_assignments), use_auto_spmd_partitioning(use_auto_spmd_partitioning), auto_spmd_mesh_shape(auto_spmd_mesh_shape), auto_spmd_mesh_ids(auto_spmd_mesh_ids), @@ -248,6 +251,7 @@ class ComputationClient { std::string compilation_device; std::vector devices; const xla::Shape* output_shape = nullptr; + std::vector> denormalized_tile_assignments; bool parameter_is_tupled_arguments; bool is_sharded; bool allow_spmd_sharding_propagation_to_output; @@ -273,7 +277,7 @@ class ComputationClient { // will be populated in an asynchrounous fashion. virtual DataPtr CreateDataPlaceholder( std::string device, xla::Shape shape, - std::optional sharding = std::nullopt) = 0; + std::optional sharding = std::nullopt) = 0; // Returns data shards. We expect this to be called on PjRtShardedData to // retrieve the shards. If other data type is passed, it returns the input @@ -286,11 +290,12 @@ class ComputationClient { // Returns wrapped data shards as PjRtShardedData. 
virtual DataPtr WrapDataShards(absl::Span shards, std::string device, xla::Shape shape, - xla::OpSharding sharding) = 0; + torch_xla::OpSharding sharding) = 0; // Returns OpSharding attached to PjRtShardedData. The returned optional // structure will be empty if there is no sharding, like with PjRtData. - virtual std::optional GetDataSharding(DataPtr handle) = 0; + virtual std::optional GetDataSharding( + DataPtr handle) = 0; virtual std::string PjRtDeviceToString( xla::PjRtDevice* const device) const = 0; @@ -303,13 +308,13 @@ class ComputationClient { // input sharding spec is identical to the target `sharding` sharding spec. virtual std::vector ReshardData( absl::Span handles, - absl::Span shardings) = 0; + absl::Span shardings) = 0; // Transfers local sharded tensor values to the TPU devices and returns a // `PjRtShardedData`. virtual DataPtr TransferShardsToDevice( absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) = 0; + std::string device, xla::Shape shape, torch_xla::OpSharding sharding) = 0; // Copies `data->buffer` to `dst` device buffer. virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp index d6337503508..619a54130c0 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp @@ -195,14 +195,14 @@ void IfrtComputationClient::IfrtData::Assign( } } -xla::OpSharding IfrtComputationClient::IfrtData::GetSharding() const { +torch_xla::OpSharding IfrtComputationClient::IfrtData::GetSharding() const { XLA_CHECK(HasSharding()) << "Check HasSharding first"; return *sharding_; } ComputationClient::DataPtr IfrtComputationClient::CreateDataPlaceholder( std::string device, xla::Shape shape, - std::optional sharding) { + std::optional sharding) { return std::make_shared(std::move(device), std::move(shape), tsl::RCReference(), std::move(sharding)); @@ -241,7 +241,7 @@ ComputationClient::DataPtr IfrtComputationClient::GetDataShard( ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( absl::Span shards, std::string device, xla::Shape shape, - xla::OpSharding sharding) { + torch_xla::OpSharding sharding) { XLA_CHECK_EQ(shards.size(), client_->addressable_device_count()); std::vector> arrays; std::vector shard_shapes; @@ -277,7 +277,7 @@ ComputationClient::DataPtr IfrtComputationClient::WrapDataShards( return std::make_shared(device, shape, sharded_array, sharding); } -std::optional IfrtComputationClient::GetDataSharding( +std::optional IfrtComputationClient::GetDataSharding( DataPtr handle) { auto ifrt_data = std::dynamic_pointer_cast(handle); return ifrt_data->sharding_; @@ -324,7 +324,7 @@ std::vector IfrtComputationClient::TransferToDevice( ComputationClient::DataPtr IfrtComputationClient::TransferShardsToDevice( absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) { + std::string device, xla::Shape shape, torch_xla::OpSharding sharding) { tsl::profiler::TraceMe activity( "IfrtComputationClient::TransferShardsToDevice", tsl::profiler::TraceMeLevel::kInfo); @@ -383,7 +383,7 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( // TODO: handle replicated data xla::XlaBuilder builder("ReplicateShardedData"); xla::Shape shape = handle->shape(); - builder.SetSharding(handle->GetSharding()); + builder.SetSharding(handle->GetSharding().GetXlaOpSharding()); // perform a simple identity calculation to 
reassemble the input as // replicated output. @@ -482,6 +482,11 @@ std::vector IfrtComputationClient::Compile( client_->addressable_devices().end()}); for (auto& instance : instances) { + std::vector denormalized_tile_assignment; + if (!instance.denormalized_tile_assignments.empty()) { + denormalized_tile_assignment = instance.denormalized_tile_assignments[0]; + } + xla::CompileOptions compile_options; if (instance.is_sharded) { // TODO(yeounoh) multi-host, multi-slice configurations @@ -528,7 +533,8 @@ std::vector IfrtComputationClient::Compile( std::shared_ptr ifrt_computation = std::make_shared( std::move(xla::XlaComputation(hlo_modules[0]->ToProto())), - instance.devices, std::move(executable)); + instance.devices, std::move(executable), + denormalized_tile_assignment); computations.push_back(ifrt_computation); @@ -607,11 +613,13 @@ IfrtComputationClient::ExecuteReplicated( auto outputs = result.outputs; - const std::vector& output_shardings = - ifrt_computation.output_shardings_ + const std::vector& output_shardings = + ifrt_computation.output_shardings_.has_value() ? *ifrt_computation.output_shardings_ : std::vector(outputs.size(), - xla::HloSharding::Replicate().ToProto()); + torch_xla::OpSharding( + xla::HloSharding::Replicate().ToProto(), + ifrt_computation.denormalized_tile_assignment_)); ABSL_CHECK_EQ(output_shardings.size(), outputs.size()); std::vector data_handles(outputs.size()); diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index e1bcc751bbf..1126c73bb1c 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -12,6 +12,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" @@ -42,23 +43,24 @@ class IfrtComputationClient : public ComputationClient { DataPtr CreateDataPlaceholder( std::string device, xla::Shape shape, - std::optional sharding = std::nullopt) override; + std::optional sharding = std::nullopt) override; std::vector GetDataShards(DataPtr data) override; DataPtr GetDataShard(DataPtr data, size_t index) override; DataPtr WrapDataShards(absl::Span shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) override; + xla::Shape shape, + torch_xla::OpSharding sharding) override; - std::optional GetDataSharding(DataPtr handle) override; + std::optional GetDataSharding(DataPtr handle) override; std::vector TransferToDevice( absl::Span> tensors) override; std::vector ReshardData( absl::Span handles, - absl::Span shardings) override { + absl::Span shardings) override { XLA_ERROR() << __FUNCTION__ << " not implemented"; } @@ -71,7 +73,8 @@ class IfrtComputationClient : public ComputationClient { DataPtr TransferShardsToDevice( absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) override; + std::string device, xla::Shape shape, + torch_xla::OpSharding sharding) override; DataPtr CopyToDevice(DataPtr data, std::string dst) override; @@ -209,13 +212,13 @@ class IfrtComputationClient : public ComputationClient { IfrtData(std::string device, xla::Shape device_shape, tsl::RCReference buffer, - std::optional sharding = std::nullopt) + std::optional sharding = std::nullopt) : Data(std::move(device), std::move(device_shape)), buffer(buffer), 
sharding_(sharding) {} IfrtData(std::string device, tsl::RCReference buffer, - std::optional sharding = std::nullopt) + std::optional sharding = std::nullopt) : Data(std::move(device), xla::ShapeUtil::MakeShape( xla::ifrt::ToPrimitiveType(buffer->dtype()).value(), @@ -238,7 +241,7 @@ class IfrtComputationClient : public ComputationClient { bool HasSharding() const override { return sharding_.has_value(); } - xla::OpSharding GetSharding() const override; + torch_xla::OpSharding GetSharding() const override; std::string ToString() const override { std::stringstream ss; @@ -248,7 +251,9 @@ class IfrtComputationClient : public ComputationClient { ss << " Data Device: " << device() << "\n"; ss << " Data Shape: " << shape().ToString() << "\n"; ss << " OpSharding: " - << xla::HloSharding::FromProto(*sharding_)->ToString() << "\n"; + << xla::HloSharding::FromProto(sharding_.value().GetXlaOpSharding()) + ->ToString() + << "\n"; ss << " NumShards: " << buffer->sharding().devices()->size() << "\n"; } else { ss << "XLAData: \n"; @@ -264,24 +269,42 @@ class IfrtComputationClient : public ComputationClient { return ss.str(); } - std::optional sharding_; + std::optional sharding_; tsl::RCReference buffer; }; tsl::RCReference ReplicateShardedData( const std::shared_ptr handle); + // TODO - move the below constructor to ifrt_computation_client.cpp + // issue link - https://github.com/pytorch/xla/issues/9572 struct IfrtComputation : public Computation { - IfrtComputation(xla::XlaComputation computation, - std::vector devices, - std::shared_ptr executable) + IfrtComputation( + xla::XlaComputation computation, std::vector devices, + std::shared_ptr executable, + std::optional> denormalized_tile_assignment) : Computation(std::move(computation), std::move(devices)), - executable(std::move(executable)) { - output_shardings_ = this->executable->GetOutputShardings(); + executable(std::move(executable)), + denormalized_tile_assignment_(std::move( + denormalized_tile_assignment.value_or(std::vector{}))) { + xla_output_shardings_ = this->executable->GetOutputShardings(); + output_shardings_ = std::nullopt; + if (xla_output_shardings_.has_value()) { + output_shardings_ = std::vector{}; + output_shardings_->reserve(xla_output_shardings_.value().size()); + for (const auto& sharding : xla_output_shardings_.value()) { + // convert each into torch_xla::OpSharding object + torch_xla::OpSharding torch_xla_op_sharding( + sharding, denormalized_tile_assignment_); + output_shardings_.value().push_back(torch_xla_op_sharding); + } + } } std::shared_ptr executable; - std::optional> output_shardings_; + std::optional> xla_output_shardings_; + std::optional> output_shardings_; + std::vector denormalized_tile_assignment_; }; }; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index 98ce8520da3..7d169116660 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -188,7 +188,7 @@ void PjRtComputationClient::PjRtData::Assign( ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder( std::string device, xla::Shape shape, - std::optional sharding) { + std::optional sharding) { if (sharding.has_value()) { return std::make_shared( std::move(device), std::move(shape), std::move(*sharding)); @@ -240,7 +240,7 @@ ComputationClient::DataPtr PjRtComputationClient::GetDataShard( ComputationClient::DataPtr PjRtComputationClient::WrapDataShards( absl::Span shards, std::string device, xla::Shape shape, 
- xla::OpSharding sharding) { + torch_xla::OpSharding sharding) { XLA_CHECK_EQ(shards.size(), client_->addressable_devices().size()); std::vector> pjrt_data_shards; pjrt_data_shards.reserve(shards.size()); @@ -254,12 +254,12 @@ ComputationClient::DataPtr PjRtComputationClient::WrapDataShards( sharding); } -std::optional PjRtComputationClient::GetDataSharding( +std::optional PjRtComputationClient::GetDataSharding( DataPtr handle) { if (auto sharded_data = dynamic_cast(handle.get())) { return sharded_data->GetSharding(); } - return std::optional(); + return std::optional(); } std::vector PjRtComputationClient::TransferToDevice( @@ -299,7 +299,7 @@ std::vector PjRtComputationClient::TransferToDevice( ComputationClient::DataPtr PjRtComputationClient::TransferShardsToDevice( absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) { + std::string device, xla::Shape shape, torch_xla::OpSharding sharding) { tsl::profiler::TraceMe activity( "PjRtComputationClient::TransferShardsToDevice", tsl::profiler::TraceMeLevel::kInfo); @@ -355,7 +355,7 @@ PjRtComputationClient::ReplicateShardedData( } xla::XlaBuilder builder("ReplicateShardedData"); xla::Shape shape = sharded_data->shape(); - builder.SetSharding(sharded_data->GetSharding()); + builder.SetSharding(sharded_data->GetSharding().GetXlaOpSharding()); // perform a simple identity calculation to reassemble the input as // replicated output. @@ -404,7 +404,7 @@ PjRtComputationClient::ReplicateShardedData( std::vector PjRtComputationClient::ReshardData( absl::Span handles, - absl::Span shardings) { + absl::Span shardings) { tsl::profiler::TraceMe activity("ReshardData", tsl::profiler::TraceMeLevel::kInfo); XLA_COUNTER("ReshardData", 1); @@ -429,19 +429,20 @@ std::vector PjRtComputationClient::ReshardData( << "current device: " << handles[i]->device(); shapes.push_back(sharded_data->shape()); - const xla::OpSharding& sharding = shardings[i]; + const torch_xla::OpSharding& sharding = shardings[i]; XLA_CHECK_NE(sharding.type(), xla::OpSharding::UNKNOWN) << "Resharding by UNKNOWN sharding type is not allowed."; - hlo_shardings.push_back( - GetValueOrThrow(xla::HloSharding::FromProto(sharding))); + hlo_shardings.push_back(GetValueOrThrow( + xla::HloSharding::FromProto(sharding.GetXlaOpSharding()))); xla::OpSharding fallback_sharding; fallback_sharding.set_type(xla::OpSharding::REPLICATED); xla::XlaScopedShardingAssignment assign( - &builder, sharded_data->GetSharding().type() == xla::OpSharding::UNKNOWN + &builder, sharded_data->GetSharding().GetXlaOpSharding().type() == + xla::OpSharding::UNKNOWN ? 
fallback_sharding - : sharded_data->GetSharding()); + : sharded_data->GetSharding().GetXlaOpSharding()); param_ops.push_back( xla::Parameter(&builder, i, shapes[i], absl::StrCat("p.", i))); } @@ -553,6 +554,10 @@ std::vector PjRtComputationClient::Compile( runtime::sys_util::GetEnvBool("ENABLE_COLLECTIVE_MATMUL_IN_MP", false); for (auto& instance : instances) { + std::vector denormalized_tile_assignment; + if (!instance.denormalized_tile_assignments.empty()) { + denormalized_tile_assignment = instance.denormalized_tile_assignments[0]; + } xla::CompileOptions compile_options; if (enable_cm_in_mp) { compile_options.executable_build_options.set_use_spmd_partitioning(true); @@ -665,7 +670,8 @@ std::vector PjRtComputationClient::Compile( std::shared_ptr pjrt_computation = std::make_shared( std::move(xla::XlaComputation(hlo_modules[0]->ToProto())), - instance.devices, std::move(executable)); + instance.devices, std::move(executable), + denormalized_tile_assignment); computations.push_back(pjrt_computation); @@ -712,7 +718,8 @@ ComputationClient::ComputationPtr PjRtComputationClient::DeserializeComputation( std::vector devices = {UseVirtualDevice() ? spmd_device_str : GetDefaultDevice()}; return std::make_shared(std::move(computation), devices, - std::move(loaded_executable)); + std::move(loaded_executable), + std::nullopt); } torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() { @@ -901,13 +908,14 @@ PjRtComputationClient::ExecuteReplicated( : std::vector({result_shape}); ABSL_CHECK_EQ(output_shapes.size(), num_outputs); - const std::vector& output_shardings = - pjrt_computation.output_shardings_.has_value() && num_outputs > 0 + const std::vector& output_shardings = + (pjrt_computation.output_shardings_.has_value() && + !pjrt_computation.output_shardings_.value().empty() && num_outputs > 0) ? *pjrt_computation.output_shardings_ : // Without an explicit sharding annotation, the output is implicitly // replicated, and we mark explicitly replicated here. 
- std::vector(num_outputs); + std::vector(num_outputs); ABSL_CHECK_EQ(output_shardings.size(), num_outputs); absl::BlockingCounter counter(num_outputs); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 3c13d3489ca..02ce8683356 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -12,6 +12,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/operation_manager.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "tsl/platform/env.h" #include "tsl/platform/threadpool.h" #include "xla/hlo/builder/xla_computation.h" @@ -40,7 +41,7 @@ class PjRtComputationClient : public ComputationClient { DataPtr CreateDataPlaceholder( std::string device, xla::Shape shape, - std::optional sharding = std::nullopt) override; + std::optional sharding = std::nullopt) override; static DataPtr CreateData(std::string device, xla::Shape shape, std::shared_ptr pjrt_buffer); @@ -50,9 +51,10 @@ class PjRtComputationClient : public ComputationClient { DataPtr GetDataShard(DataPtr data, size_t index) override; DataPtr WrapDataShards(absl::Span shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) override; + xla::Shape shape, + torch_xla::OpSharding sharding) override; - std::optional GetDataSharding(DataPtr handle) override; + std::optional GetDataSharding(DataPtr handle) override; std::vector TransferToDevice( absl::Span> tensors) override; @@ -63,7 +65,7 @@ class PjRtComputationClient : public ComputationClient { // TODO(yeounoh) replace ReplicateShardedData with this. std::vector ReshardData( absl::Span handles, - absl::Span shardings) override; + absl::Span shardings) override; absl::StatusOr> TransferFromDevice( absl::Span handles) override; @@ -74,7 +76,8 @@ class PjRtComputationClient : public ComputationClient { DataPtr TransferShardsToDevice( absl::Span> tensor_shards, - std::string device, xla::Shape shape, xla::OpSharding sharding) override; + std::string device, xla::Shape shape, + torch_xla::OpSharding sharding) override; DataPtr CopyToDevice(DataPtr data, std::string dst) override; @@ -242,10 +245,10 @@ class PjRtComputationClient : public ComputationClient { bool HasSharding() const override { return false; } - xla::OpSharding GetSharding() const override { + torch_xla::OpSharding GetSharding() const override { XLA_CHECK(false) << "GetSharding should not be called on PjRtData, check " "HasSharding first"; - return xla::OpSharding(); + return torch_xla::OpSharding(); } std::string ToString() const override { @@ -269,12 +272,12 @@ class PjRtComputationClient : public ComputationClient { PjRtShardedData(std::string device, xla::Shape shape) = delete; PjRtShardedData(std::string device, xla::Shape shape, - xla::OpSharding sharding) + torch_xla::OpSharding sharding) : Data(std::move(device), std::move(shape)), sharding(sharding) {} PjRtShardedData(std::string device, xla::Shape shape, std::vector> shards, - xla::OpSharding sharding) + torch_xla::OpSharding sharding) : Data(std::move(device), std::move(shape)), shards(shards), sharding(sharding) {} @@ -311,26 +314,51 @@ class PjRtComputationClient : public ComputationClient { ss << " Data Device: " << device() << "\n"; ss << " Data Shape: " << shape().ToString() << "\n"; ss << " OpSharding: " - << xla::HloSharding::FromProto(sharding)->ToString() << "\n"; + << 
xla::HloSharding::FromProto(sharding.GetXlaOpSharding())->ToString() + << "\n"; ss << " NumShards: " << shards.size() << "\n"; return ss.str(); } bool HasSharding() const override { return true; } - xla::OpSharding GetSharding() const override { return sharding; } + torch_xla::OpSharding GetSharding() const override { return sharding; } std::vector> shards; - xla::OpSharding sharding; + torch_xla::OpSharding sharding; }; + // TODO - move the below constructor to pjrt_computation_client.cpp + // issue link - https://github.com/pytorch/xla/issues/9572 struct PjRtComputation : public Computation { - PjRtComputation(xla::XlaComputation computation, - std::vector devices, - std::unique_ptr executable) + /** + * Constructs a PjRtComputation with the given parameters. + * + * @param computation The XLA computation to wrap + * @param devices List of device strings for execution + * @param executable The compiled PJRT executable + * @param denormalized_tile_assignment Optional tile assignment for sharding + */ + PjRtComputation( + xla::XlaComputation computation, std::vector devices, + std::unique_ptr executable, + std::optional> denormalized_tile_assignment) : Computation(std::move(computation), std::move(devices)), - executable(std::move(executable)) { - output_shardings_ = this->executable->GetOutputShardings(); + executable(std::move(executable)), + denormalized_tile_assignment_(std::move( + denormalized_tile_assignment.value_or(std::vector{}))) { + xla_output_shardings_ = this->executable->GetOutputShardings(); + output_shardings_ = std::nullopt; + if (xla_output_shardings_.has_value()) { + output_shardings_ = std::vector{}; + output_shardings_->reserve(xla_output_shardings_.value().size()); + for (const auto& sharding : xla_output_shardings_.value()) { + // convert each into torch_xla::OpSharding object + torch_xla::OpSharding torch_xla_op_sharding( + sharding, denormalized_tile_assignment_); + output_shardings_.value().push_back(torch_xla_op_sharding); + } + } } const std::string get_memory_info() const override { @@ -343,7 +371,10 @@ class PjRtComputationClient : public ComputationClient { } std::unique_ptr executable; - std::optional> output_shardings_; + std::optional> xla_output_shardings_; + std::optional> output_shardings_; + std::vector denormalized_tile_assignment_; + ; }; // Use XLA replication to re-assemble the sharded data. diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 69bd8aa3a56..db9352a1bf7 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -11,6 +11,7 @@ #include "absl/base/nullability.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "torch_xla/csrc/view.h" namespace torch_xla { @@ -263,13 +264,13 @@ class XLATensor : public torch::lazy::LazyTensor { // XLA SPMD sharding spec annoation. The XLA tensor uses this to create // HloSharding for replication, manual and tile shardings. struct ShardingSpec { - ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape) + ShardingSpec(const torch_xla::OpSharding& sharding, const xla::Shape& shape) : sharding(sharding), shape(shape) {} - ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape, + ShardingSpec(const torch_xla::OpSharding& sharding, const xla::Shape& shape, const bool& minibatch) : sharding(sharding), shape(shape), minibatch(minibatch) {} - xla::OpSharding sharding = xla::HloSharding::Unknown().ToProto(); + torch_xla::OpSharding sharding; // Optional source tensor shape unpartitioned. 
xla::Shape shape; // Parameter for represent input batch in sharded along batch axes diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index bf5f7966f8f..0326fe0848a 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -838,6 +838,39 @@ std::vector CreateTensorsData( runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors)); } +namespace { + +/** + * Filters a list of device strings to include only those with IDs matching + * the provided indices. + * + * @param devices List of device strings in format "TYPE:ID" (e.g., "TPU:0") + * @param indices List of device IDs to filter by + * @return Filtered list of device strings, or error status if parsing fails + * + * Example: + * devices = ["TPU:0", "TPU:1", "TPU:2", "TPU:3"] + * indices = [1, 3] + * result = ["TPU:1", "TPU:3"] + */ +std::vector FilterDevicesByAddressableDevices( + std::vector devices, const std::vector& indices) { + std::vector filtered_devices_; + filtered_devices_.reserve(indices.size()); + for (auto& index : indices) { + for (auto& device : devices) { + std::vector device_spec_parts = absl::StrSplit(device, ':'); + if (std::stoi(device_spec_parts[1]) == index) { + filtered_devices_.push_back(device); + break; + } + } + } + return filtered_devices_; +} + +} // namespace + std::vector CreateTensorsData( const std::vector& tensors, const std::vector& shardings, @@ -860,14 +893,25 @@ std::vector CreateTensorsData( std::vector local_devices = runtime::GetComputationClientOrDie()->GetLocalDevices(); + std::vector addressable_devices = std::move(local_devices); + if (shardings[i]) { + const std::vector& denormalized_tile_assignment = + shardings[i]->sharding.GetDenormalizedTileAssignment(); + if ((!denormalized_tile_assignment.empty()) && + (denormalized_tile_assignment.size() != + addressable_devices.size())) { + addressable_devices = FilterDevicesByAddressableDevices( + addressable_devices, denormalized_tile_assignment); + } + } // Shards the input tensors with padding, to split evenly. // The execution requires consistent shard sizes, and the zero-padded // values should be ignored. 
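+      // As a hypothetical illustration (shapes not taken from this change): a
+      // {6, 5} tensor tiled across a 2x2 device mesh is cut into {3, 3}
+      // shards, so the shards covering the last column hold {3, 2} of real
+      // data plus one column of zero padding.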
- std::vector local_shards = - ShardingUtil::ShardTensor(tensors[i], shardings[i], local_devices, - /*padded=*/true); + std::vector local_shards = ShardingUtil::ShardTensor( + tensors[i], shardings[i], addressable_devices, + /*padded=*/true); new_handles.push_back(ShardingUtil::CreateShardedData( - local_shards, local_devices, shardings[i])); + local_shards, addressable_devices, shardings[i])); } else { source_tensors.push_back(std::make_shared( tensors[i], std::move(shape), devices[i])); diff --git a/torch_xla/csrc/torch_xla_op_sharding.cpp b/torch_xla/csrc/torch_xla_op_sharding.cpp new file mode 100644 index 00000000000..6ad24d32015 --- /dev/null +++ b/torch_xla/csrc/torch_xla_op_sharding.cpp @@ -0,0 +1,121 @@ +#include "torch_xla_op_sharding.h" + +namespace torch_xla { + +OpSharding::OpSharding() {} + +OpSharding::OpSharding( + const xla::OpSharding& op_sharding, + const std::optional>& denormalized_tile_assignment) + : op_sharding_(std::make_unique(op_sharding)), + denormalized_tile_assignment_( + denormalized_tile_assignment.value_or(std::vector{})) {} + +OpSharding::OpSharding(const OpSharding& other) + : denormalized_tile_assignment_(other.denormalized_tile_assignment_) { + if (other.op_sharding_) { + op_sharding_ = std::make_unique(*other.op_sharding_); + } else { + // Fallback to default replicated sharding + op_sharding_ = std::make_unique(); + op_sharding_->set_type(xla::OpSharding::REPLICATED); + } +} + +OpSharding& OpSharding::operator=(const OpSharding& other) { + if (this != &other) { + if (other.op_sharding_) { + op_sharding_ = std::make_unique(*other.op_sharding_); + } else { + // Fallback to default replicated sharding + op_sharding_ = std::make_unique(); + op_sharding_->set_type(xla::OpSharding::REPLICATED); + } + denormalized_tile_assignment_ = other.denormalized_tile_assignment_; + } + return *this; +} + +OpSharding::OpSharding(OpSharding&& other) noexcept + : op_sharding_(std::move(other.op_sharding_)), + denormalized_tile_assignment_( + std::move(other.denormalized_tile_assignment_)) { + // other.op_sharding_ is now nullptr, which is safe +} + +OpSharding& OpSharding::operator=(OpSharding&& other) noexcept { + if (this != &other) { + op_sharding_ = std::move(other.op_sharding_); + denormalized_tile_assignment_ = + std::move(other.denormalized_tile_assignment_); + } + return *this; +} + +// Forwarded methods from xla::OpSharding for API compatibility +xla::OpSharding::Type OpSharding::type() const { return op_sharding_->type(); } + +bool OpSharding::replicate_on_last_tile_dim() const { + return op_sharding_->replicate_on_last_tile_dim(); +} + +int OpSharding::tile_assignment_dimensions_size() const { + return op_sharding_->tile_assignment_dimensions_size(); +} + +int OpSharding::tile_assignment_devices_size() const { + return op_sharding_->tile_assignment_devices_size(); +} + +int OpSharding::tile_assignment_dimensions(int index) const { + return op_sharding_->tile_assignment_dimensions(index); +} + +int OpSharding::tile_assignment_devices(int index) const { + return op_sharding_->tile_assignment_devices(index); +} + +std::string OpSharding::DebugString() const { + return op_sharding_->DebugString(); +} + +const ::google::protobuf::RepeatedField& +OpSharding::iota_reshape_dims() const { + return op_sharding_->iota_reshape_dims(); +} + +const ::google::protobuf::RepeatedField& +OpSharding::tile_assignment_dimensions() const { + return op_sharding_->tile_assignment_dimensions(); +} + +const ::google::protobuf::RepeatedField& +OpSharding::tile_assignment_devices() const 
{ + return op_sharding_->tile_assignment_devices(); +} + +const ::google::protobuf::RepeatedField& +OpSharding::iota_transpose_perm() const { + return op_sharding_->iota_transpose_perm(); +} + +const ::google::protobuf::RepeatedField& OpSharding::last_tile_dims() + const { + return op_sharding_->last_tile_dims(); +} + +const xla::ShapeProto& OpSharding::tile_shape() const { + return op_sharding_->tile_shape(); +} + +const xla::OpSharding& OpSharding::GetXlaOpSharding() const { + return *op_sharding_; +} + +xla::OpSharding& OpSharding::GetMutableXlaOpSharding() { return *op_sharding_; } + +const std::vector& OpSharding::GetDenormalizedTileAssignment() const { + return denormalized_tile_assignment_; +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/torch_xla_op_sharding.h b/torch_xla/csrc/torch_xla_op_sharding.h new file mode 100644 index 00000000000..aa126b48f58 --- /dev/null +++ b/torch_xla/csrc/torch_xla_op_sharding.h @@ -0,0 +1,81 @@ +#ifndef XLA_TORCH_XLA_CSRC_TORCH_XLA_OP_SHARDING_H_ +#define XLA_TORCH_XLA_CSRC_TORCH_XLA_OP_SHARDING_H_ + +#include +#include +#include +#include + +#include "google/protobuf/repeated_field.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape.h" + +namespace torch_xla { + +// Wrapper class for xla::OpSharding that provides additional functionality +// and maintains denormalized tile assignment information. +// +// This class serves as a bridge between PyTorch XLA's sharding representation +// and XLA's native OpSharding, allowing for extended functionality while +// maintaining compatibility with the underlying XLA infrastructure. +class OpSharding { + public: + // Default constructor + OpSharding(); + + // Constructs OpSharding from xla::OpSharding with optional denormalized + // tile assignment + explicit OpSharding(const xla::OpSharding& op_sharding, + const std::optional>& + denormalized_tile_assignment = std::nullopt); + + // Copy constructor + OpSharding(const OpSharding& other); + + // Copy assignment operator + OpSharding& operator=(const OpSharding& other); + + // Move constructor + OpSharding(OpSharding&& other) noexcept; + + // Move assignment operator + OpSharding& operator=(OpSharding&& other) noexcept; + + // Destructor (default is sufficient due to unique_ptr) + ~OpSharding() = default; + + // Forwarded methods from xla::OpSharding for API compatibility + xla::OpSharding::Type type() const; + bool replicate_on_last_tile_dim() const; + int tile_assignment_dimensions_size() const; + int tile_assignment_devices_size() const; + int tile_assignment_dimensions(int index) const; + int tile_assignment_devices(int index) const; + std::string DebugString() const; + const ::google::protobuf::RepeatedField& iota_reshape_dims() const; + const ::google::protobuf::RepeatedField& tile_assignment_dimensions() + const; + const ::google::protobuf::RepeatedField& tile_assignment_devices() + const; + const ::google::protobuf::RepeatedField& iota_transpose_perm() const; + const ::google::protobuf::RepeatedField& last_tile_dims() const; + const xla::ShapeProto& tile_shape() const; + + // Access to underlying xla::OpSharding + const xla::OpSharding& GetXlaOpSharding() const; + xla::OpSharding& GetMutableXlaOpSharding(); + + // Access to denormalized tile assignment + const std::vector& GetDenormalizedTileAssignment() const; + + private: + // Underlying XLA OpSharding object + std::unique_ptr op_sharding_; + + // Additional denormalized tile assignment information + std::vector denormalized_tile_assignment_; +}; + +} // namespace torch_xla + 
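+// Illustrative usage sketch (hypothetical 2x2 mesh and device ids, shown only
+// to document the wrapper's surface):
+//
+//   xla::Array2D<int64_t> mesh({{0, 1}, {2, 3}});
+//   torch_xla::OpSharding sharding(xla::HloSharding::Tile(mesh).ToProto(),
+//                                  std::vector<int64_t>{0, 1, 2, 3});
+//   // Forwarded accessors read through to the wrapped proto,
+//   bool tiled = (sharding.type() == xla::OpSharding::OTHER);
+//   // while the denormalized tile assignment travels alongside it.
+//   const std::vector<int64_t>& assignment =
+//       sharding.GetDenormalizedTileAssignment();  // {0, 1, 2, 3}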
+#endif // XLA_TORCH_XLA_CSRC_TORCH_XLA_OP_SHARDING_H_ diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 15f38cae233..378f41934a5 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1186,18 +1186,26 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( std::vector* tensors, SyncTensorCollection* coll, std::vector parameters_data, std::string device, ComputationCache::TypePtr cached_computation, - const std::vector& tensor_data_vec) { + const std::vector& tensor_data_vec, + std::optional>> + denormalized_tile_assignments) { auto tensors_data = SetTensorData(tensors, coll->config, coll->indices, tensor_data_vec); std::vector sharding_specs(coll->indices.size(), nullptr); + std::optional> denormalized_tile_assignment = + std::nullopt; + if ((denormalized_tile_assignments.has_value()) && + (!denormalized_tile_assignments.value().empty())) { + denormalized_tile_assignment = denormalized_tile_assignments.value()[0]; + } // Extract sharding specs for the results and prepare the sharded data // placeholders if the computation is sharded. if (cached_computation->is_sharded) { ShardingUtil::PrepareOutputShardingPropagation( tensors, coll->indices, cached_computation->computation, &tensors_data, - &sharding_specs); + &sharding_specs, denormalized_tile_assignment); DebugUtil::SaveOutputShardingInfo(tensors, coll->indices); } @@ -1258,15 +1266,37 @@ XLAGraphExecutor::TryRunCachedSync( << torch::lazy::Hash(po_data->parameter_sequence); } + std::vector> denormalized_tile_assignments; + // Extract denormalized tile assignments from all nodes in the post-order + // traversal. This iterates through each node in the computation graph and + // collects sharding information that will be used during compilation. The + // denormalized tile assignments represent how tensors are distributed across + // devices in localized SPMD/submesh execution mode. + for (const auto* node : po_data->post_order) { + const XlaNode* const casted = dynamic_cast(node); + auto shardings = casted->GetShardings(); + if (!shardings.empty()) { + // For each sharding specification on this node, extract the denormalized + // tile assignment which describes the physical device placements + for (auto sharding : shardings) { + std::vector denormalized_tile_assignment = + sharding->GetDenormalizedTileAssignment(); + if (!denormalized_tile_assignment.empty()) { + denormalized_tile_assignments.push_back(denormalized_tile_assignment); + } + } + } + } // don't schedule the execution if the purpose of this SyncTensor is just to // warm up the cache. return std::pair>( - cache_hit, warm_up_cache_only - ? nullptr - : ScheduleSyncTensorsGraph( - tensors, coll, std::move(po_data->parameters_data), - coll->device.toString(), - std::move(cached_computation), tensor_data_vec)); + cache_hit, + warm_up_cache_only + ? 
nullptr + : ScheduleSyncTensorsGraph( + tensors, coll, std::move(po_data->parameters_data), + coll->device.toString(), std::move(cached_computation), + tensor_data_vec, denormalized_tile_assignments)); } std::vector GetBufferDonorIndexForStepMarker( @@ -1442,13 +1472,16 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( } xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(coll.device.type())); - + std::vector> denormalized_tile_assignments = + lowering_ctx.GetDenormalizedTileAssignments(); std::vector instances; instances.push_back( {std::move(computation), coll.device.toString(), runtime::GetComputationClientOrDie()->GetCompilationDevices( coll.device.toString(), devices), - &shape, should_wrap_parameter, is_sharded}); + &shape, should_wrap_parameter, is_sharded, + /*allow_spmd_sharding_propagation_to_output=*/true, + /*denormalized_tile_assignments=*/denormalized_tile_assignments}); instances.front().eager_mode = UseEagerMode(); if (use_autosharding) { TF_VLOG(5) << "use_auto_spmd_partitioning is set."; @@ -1511,7 +1544,8 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( /*emitted_nodes=*/lowering_ctx.GetEmittedNodeCount(), /*computation=*/computations.front(), /*parameters_data=*/std::move(po_data->parameters_data), - /*is_sharded=*/is_sharded}; + /*is_sharded=*/is_sharded, + /*denormalized_tile_assignments=*/denormalized_tile_assignments}; } std::shared_ptr @@ -1587,7 +1621,7 @@ XLAGraphExecutor::SyncTensorsGraphInternal( return ScheduleSyncTensorsGraph( tensors, &coll, std::move(compile_result.parameters_data), compile_result.device.toString(), std::move(cached_computation), - tensor_data_vec); + tensor_data_vec, compile_result.denormalized_tile_assignments); } } diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 5e1acaa9d47..8467f54a2e4 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -225,6 +225,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { runtime::ComputationClient::ComputationPtr computation; std::vector parameters_data; bool is_sharded = false; + std::vector> denormalized_tile_assignments = {}; }; struct Async : public torch::lazy::LazyGraphExecutor::Async { @@ -359,7 +360,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { std::vector* tensors, SyncTensorCollection* coll, std::vector parameters_data, std::string device, ComputationCache::TypePtr cached_computation, - const std::vector& tensor_data_vec); + const std::vector& tensor_data_vec, + std::optional>> + denormalized_tile_assignments = std::nullopt); // Override to enable profiler. 
PostOrderData RunPostOrder(const std::vector& ir_values, diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index abc18206d42..4696ac30862 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -19,6 +19,7 @@ #include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/thread_pool.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "torch_xla/csrc/xla_graph_executor.h" #include "tsl/profiler/lib/traceme.h" #include "xla/execution_options_util.h" @@ -176,10 +177,11 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) { const XlaNode* xla_node = dynamic_cast(node); xla::HloInstructionProto* instruction = XlaBuilderFriend::GetInstruction(elem.second); - const std::shared_ptr sharding = + const std::shared_ptr sharding = xla_node->GetSharding(elem.first.index); if (sharding != nullptr && sharding->type() != xla::OpSharding::UNKNOWN) { - *instruction->mutable_sharding() = *sharding; + const xla::OpSharding& sharding_obj = sharding->GetMutableXlaOpSharding(); + *instruction->mutable_sharding() = sharding_obj; is_sharded = true; } } @@ -187,7 +189,7 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) { } ShardingUtil::ShardingType ShardingUtil::GetShardingType( - xla::OpSharding& sharding) { + torch_xla::OpSharding& sharding) { switch (sharding.type()) { case xla::OpSharding::REPLICATED: return ShardingType::REPLICATED; @@ -210,15 +212,17 @@ ShardingUtil::ShardingType ShardingUtil::GetShardingType( bool ShardingUtil::EqualShardingSpecs(const XLATensor::ShardingSpec& a, const XLATensor::ShardingSpec& b) { - return xla::protobuf_util::HaveSameSerialization(a.sharding, b.sharding); + return xla::protobuf_util::HaveSameSerialization( + a.sharding.GetXlaOpSharding(), b.sharding.GetXlaOpSharding()); } -bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a, - const xla::OpSharding& b) { - return xla::protobuf_util::HaveSameSerialization(a, b); +bool ShardingUtil::EqualOpShardings(const torch_xla::OpSharding& a, + const torch_xla::OpSharding& b) { + return xla::protobuf_util::HaveSameSerialization(a.GetXlaOpSharding(), + b.GetXlaOpSharding()); } -xla::OpSharding ShardingUtil::CreateOpSharding( +torch_xla::OpSharding ShardingUtil::CreateOpSharding( const py::list& tile_assignment, const py::list& group_assignment, const py::list& replication_groups, ShardingType sharding_type) { TORCH_LAZY_COUNTER("CreateOpSharding", 1); @@ -281,7 +285,29 @@ xla::OpSharding ShardingUtil::CreateOpSharding( TF_LOG(ERROR) << "Invalid arguments: sharding_type " << sharding_type; } } - return sharding; + + // Create denormalized_tile_assignment. If sharding.tile_assignment_devices() + // is empty (which happens for REPLICATED, MANUAL, UNKNOWN sharding types), + // use the original tile_assignment arg that was passed to this function. 
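+  // For example (hypothetical inputs): a REPLICATED sharding built from a
+  // Python tile_assignment of [[0, 1], [2, 3]] carries no
+  // tile_assignment_devices, so the denormalized assignment falls back to the
+  // flattened list {0, 1, 2, 3}.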
+ std::vector denormalized_tile_assignment; + if (sharding.tile_assignment_devices().empty() && !tile_assignment.empty()) { + // Convert the Python list tile_assignment to a flattened vector for + // denormalized assignment + xla::Array tile_array = TileListToArray(tile_assignment); + denormalized_tile_assignment.assign(tile_array.begin(), tile_array.end()); + } else { + // Use the tile_assignment_devices from the XLA OpSharding object + denormalized_tile_assignment.assign( + sharding.tile_assignment_devices().begin(), + sharding.tile_assignment_devices().end()); + } + + // Use the xla::OpSharding object in the wrapper torch_xla::OpSharding along + // with denormalized_tile_assignment (original tile_assignment) for extended + // functionality + torch_xla::OpSharding torch_xla_opsharding(sharding, + denormalized_tile_assignment); + return torch_xla_opsharding; } std::vector ShardingUtil::GetShardShape( @@ -336,7 +362,8 @@ ShardingUtil::GetShardIndicesForMinibatchTensor( std::vector>> ShardingUtil::GetShardReplicaAndIndicesForDevices( const std::vector& shard_shape, - const std::vector& tensor_shape, const xla::OpSharding sharding, + const std::vector& tensor_shape, + const torch_xla::OpSharding sharding, const std::vector& devices) { using namespace at::indexing; @@ -359,9 +386,8 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( } } else if (sharding.type() == xla::OpSharding::OTHER) { auto device_index = build_index_map(devices); - std::vector tile_assignment_devices( - sharding.tile_assignment_devices().begin(), - sharding.tile_assignment_devices().end()); + std::vector tile_assignment_devices = + sharding.GetDenormalizedTileAssignment(); if (!sharding.iota_reshape_dims().empty()) { auto tileAssignment = xla::TileAssignment( sharding.tile_assignment_dimensions(), sharding.iota_reshape_dims(), @@ -421,7 +447,7 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( std::vector ShardingUtil::ShardTensor( const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings, const std::vector& devices, bool padded) { - xla::OpSharding sharding; + torch_xla::OpSharding sharding; bool minibatch = false; if (shardings != nullptr) { sharding = shardings->sharding; @@ -480,28 +506,40 @@ std::vector ShardingUtil::ShardTensor( std::vector ShardingUtil::GetOutputSharding( const std::vector& output_shapes, - runtime::ComputationClient::ComputationPtr computation) { + runtime::ComputationClient::ComputationPtr computation, + std::optional> denormalized_tile_assignment) { const auto& computation_proto = computation->computation().proto(); size_t num_outputs = output_shapes.size(); - std::vector output_shardings; + std::vector xla_output_shardings; std::vector sharding_specs(num_outputs); if (computation_proto.has_spmd_output_sharding()) { if (computation_proto.spmd_output_sharding().tuple_shardings().size() > 0) { auto tuple_shardings = computation_proto.spmd_output_sharding().tuple_shardings(); - output_shardings = std::vector(tuple_shardings.begin(), - tuple_shardings.end()); + xla_output_shardings = std::vector( + tuple_shardings.begin(), tuple_shardings.end()); } else { - output_shardings = std::vector{ + xla_output_shardings = std::vector{ computation_proto.spmd_output_sharding()}; } } // Output parameter sharding annotations, defaults to REPLICATED(0) if // unset. - if (output_shardings.empty()) { + if (xla_output_shardings.empty()) { // Initializes with default sharding type, REPLCIATED. 
- output_shardings.resize(num_outputs); + xla_output_shardings.resize(num_outputs); + } + + std::vector output_shardings; + for (const auto& sharding : xla_output_shardings) { + if ((denormalized_tile_assignment.has_value()) && + (!denormalized_tile_assignment.value().empty())) { + output_shardings.emplace_back(sharding, + denormalized_tile_assignment.value()); + } else { + output_shardings.emplace_back(sharding, std::nullopt); + } } for (int i = 0; i < num_outputs; ++i) { @@ -535,7 +573,8 @@ void ShardingUtil::PrepareOutputShardingPropagation( std::vector* tensors, absl::Span indices, runtime::ComputationClient::ComputationPtr computation, std::vector* data_placeholders, - std::vector* sharding_specs) { + std::vector* sharding_specs, + std::optional> denormalized_tile_assignment) { // Resizes the containers to `indices.size()`. data_placeholders->resize(indices.size()); sharding_specs->resize(indices.size()); @@ -546,7 +585,8 @@ void ShardingUtil::PrepareOutputShardingPropagation( auto xtensor = (*tensors)[indices[i]]; output_shapes.push_back(xtensor->shape().get()); } - auto new_sharding_specs = GetOutputSharding(output_shapes, computation); + auto new_sharding_specs = GetOutputSharding(output_shapes, computation, + denormalized_tile_assignment); XLA_CHECK(indices.size() == new_sharding_specs.size()) << "Expected size: " << indices.size() << ", actual size: " << new_sharding_specs.size(); @@ -577,31 +617,46 @@ void ShardingUtil::PrepareOutputShardingPropagation( } } -runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( +namespace { +// helper function to get torch_xla::OpSharding +torch_xla::OpSharding GetTorchXlaOpSharding( + xla::Shape* global_shape, const XLATensor::ShardingSpecPtr& sharding_spec, const std::vector& local_shards, - const std::vector& devices, - const XLATensor::ShardingSpecPtr& sharding_spec) { - XLA_CHECK(local_shards.size() == devices.size()) - << "A device must be speficied for each shard"; - std::vector> source_tensors; - xla::Shape global_shape; - xla::OpSharding sharding; + const std::vector& devices) { if (sharding_spec == nullptr) { // Unknown type is used to mark implicitly replicated data for // auto-sharding. // TODO(yeounoh) see if we can completely rely on Unknown without inference // performance degradation. - sharding = ShardingUtil::GetAutoSharding() - ? xla::HloSharding::Unknown().ToProto() - : xla::HloSharding::Replicate().ToProto(); + xla::OpSharding sharding = ShardingUtil::GetAutoSharding() + ? xla::HloSharding::Unknown().ToProto() + : xla::HloSharding::Replicate().ToProto(); // if replicated, global_shape is shape of the tensor. 
auto first_device = ParseDeviceString(devices[0]); - global_shape = + *global_shape = CreateComputationShapeFromTensor(local_shards[0], &first_device); + return torch_xla::OpSharding(sharding, std::nullopt); } else { - global_shape = sharding_spec->shape; - sharding = sharding_spec->sharding; + *global_shape = sharding_spec->shape; + return sharding_spec->sharding; } +} + +} // namespace + +runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( + const std::vector& local_shards, + const std::vector& devices, + const XLATensor::ShardingSpecPtr& sharding_spec) { + XLA_CHECK(local_shards.size() == devices.size()) + << "A device must be speficied for each shard"; + std::vector> source_tensors; + xla::Shape global_shape; + + // get torch_xla::OpSharding object + torch_xla::OpSharding sharding = GetTorchXlaOpSharding( + &global_shape, sharding_spec, local_shards, devices); + for (int64_t j = 0; j < devices.size(); ++j) { auto shard_device = ParseDeviceString(devices[j]); auto shard_shape = @@ -672,23 +727,23 @@ void ShardingUtil::ReshardParameters( std::vector* parameters, std::vector* nodes) { // Extract input shardings generated from auto-sharding pass. - std::vector input_shardings; + std::vector xla_input_shardings; if (module.spmd_parameters_shardings().size() == 1 && module.spmd_parameters_shardings()[0].type() == xla::OpSharding::TUPLE) { auto tuple_shardings = module.spmd_parameters_shardings()[0].tuple_shardings(); - input_shardings = std::vector(tuple_shardings.begin(), - tuple_shardings.end()); + xla_input_shardings = std::vector(tuple_shardings.begin(), + tuple_shardings.end()); } else { for (auto sharding : module.spmd_parameters_shardings()) { - input_shardings.push_back(sharding); + xla_input_shardings.push_back(sharding); } } - if (input_shardings.size() == 0) { - TF_VLOG(3) << "ReshardParamters... skip with empty input_shardings."; + if (xla_input_shardings.size() == 0) { + TF_VLOG(3) << "ReshardParamters... skip with empty xla_input_shardings."; return; } - XLA_CHECK_EQ(input_shardings.size(), parameters->size()); + XLA_CHECK_EQ(xla_input_shardings.size(), parameters->size()); // Reshard parameters as needed, as with a new sharding spec. std::vector data = @@ -696,13 +751,30 @@ void ShardingUtil::ReshardParameters( std::vector reshard_indices; std::vector data_to_reshard; - std::vector shardings_to_reshard; + std::vector shardings_to_reshard; + + std::vector denormalized_tile_assignment; + std::vector input_shardings; + auto sharding_spec = (*tensors)[0]->sharding_spec(); + if (sharding_spec) { + denormalized_tile_assignment = + sharding_spec->sharding.GetDenormalizedTileAssignment(); + for (const auto& sharding : xla_input_shardings) { + if (denormalized_tile_assignment.size() > 0) { + input_shardings.emplace_back(sharding, denormalized_tile_assignment); + } else { + input_shardings.emplace_back(sharding); + } + } + } + for (int i = 0; i < input_shardings.size(); ++i) { XLA_CHECK(input_shardings[i].type() != xla::OpSharding::UNKNOWN) << "Resharding by UNKNOWN sharding type is not allowed."; // Skip re-sharding if not necessary. 
- if (!xla::protobuf_util::HaveSameSerialization(data[i]->GetSharding(), - input_shardings[i])) { + if (!xla::protobuf_util::HaveSameSerialization( + data[i]->GetSharding().GetXlaOpSharding(), + input_shardings[i].GetXlaOpSharding())) { reshard_indices.push_back(i); data_to_reshard.push_back(data[i]); shardings_to_reshard.push_back(input_shardings[i]); @@ -761,7 +833,7 @@ void ShardingUtil::ReshardParameters( } void ShardingUtil::XlaMarkSharding(const at::Tensor& input, - xla::OpSharding sharding) { + torch_xla::OpSharding sharding) { TORCH_LAZY_COUNTER("XlaMarkSharding", 1); XLA_CHECK(UseVirtualDevice()) << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; @@ -839,7 +911,7 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input, } void ShardingUtil::XlaAnnotateCustomSharding(const XLATensorPtr& input, - xla::OpSharding sharding) { + torch_xla::OpSharding sharding) { TORCH_LAZY_COUNTER("XlaAnnotateCustomSharding", 1); XLA_CHECK(UseVirtualDevice()) diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 8b8b98653b2..1f143ece700 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -9,6 +9,7 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/tensor.h" +#include "torch_xla/csrc/torch_xla_op_sharding.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/service/hlo.pb.h" @@ -28,8 +29,8 @@ class ShardingUtil { UNKNOWN = 6 // implicit replication }; - // Determine the ShardingType of the given xla::OpSharding. - static ShardingType GetShardingType(xla::OpSharding& sharding); + // Determine the ShardingType of the given torch_xla::OpSharding. + static ShardingType GetShardingType(torch_xla::OpSharding& sharding); // Annotates HLO instructions in the lowered computation and returns true if // the computation needs to be compiled with SPMD partitioning. For this call @@ -42,15 +43,14 @@ class ShardingUtil { const XLATensor::ShardingSpec& b); // Returns true if two OpShardings are the same. - static bool EqualOpShardings(const xla::OpSharding& a, - const xla::OpSharding& b); + static bool EqualOpShardings(const torch_xla::OpSharding& a, + const torch_xla::OpSharding& b); - // Creates an xla::OpSharding. `tile_assignmnent` is required for TILED + // Creates an torch_xla::OpSharding. `tile_assignmnent` is required for TILED // `sharding_type` and `replication_groups` for `PARTIAL`. - static xla::OpSharding CreateOpSharding(const py::list& tile_assignment, - const py::list& group_assignment, - const py::list& replication_groups, - ShardingType sharding_type); + static torch_xla::OpSharding CreateOpSharding( + const py::list& tile_assignment, const py::list& group_assignment, + const py::list& replication_groups, ShardingType sharding_type); // Returns the shape of the resulting shards of `tensor` after applying // `sharding`. This assumes the shards will be padded to ensure they all @@ -67,7 +67,7 @@ class ShardingUtil { static std::vector>> GetShardReplicaAndIndicesForDevices(const std::vector& shard_shape, const std::vector& tensor_shape, - const xla::OpSharding sharding, + const torch_xla::OpSharding sharding, const std::vector& devices); // Returns the indices for the shards. Supports `OTHER` sharding types and @@ -94,7 +94,9 @@ class ShardingUtil { // is always on virtual SPMD device. 
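+  // When a non-empty denormalized_tile_assignment is supplied (for example a
+  // hypothetical submesh assignment {0, 1, 2, 3}), it is attached to every
+  // returned torch_xla::OpSharding; otherwise the shardings are created
+  // without one.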
static std::vector GetOutputSharding( const std::vector& output_shapes, - runtime::ComputationClient::ComputationPtr computation); + runtime::ComputationClient::ComputationPtr computation, + std::optional> denormalized_tile_assignment = + std::nullopt); // Create sharded data placeholders, each corresponding to the individual // sharding spec from the input list @@ -111,7 +113,9 @@ class ShardingUtil { std::vector* tensors, absl::Span indices, runtime::ComputationClient::ComputationPtr computation, std::vector* data_placeholders, - std::vector* sharding_specs); + std::vector* sharding_specs, + std::optional> denormalized_tile_assignment = + std::nullopt); // Transfers the individual shards to the devices and returns a DataPtr for // the PjRtShardedData wrapping the shards. @@ -121,14 +125,14 @@ class ShardingUtil { const XLATensor::ShardingSpecPtr& sharding_spec); static void XlaMarkSharding(const at::Tensor& input, - xla::OpSharding sharding); + torch_xla::OpSharding sharding); // Add a custom sharding node IR to an XLATensor. Note that unlike // XlaMarkSharding, this will not explicitly set a sharding spec tied to the // DeviceData node, nor transfer any sharded data to the device. This serves // merely as an XLA custom sharding annotation IR. static void XlaAnnotateCustomSharding(const XLATensorPtr& input, - xla::OpSharding sharding); + torch_xla::OpSharding sharding); //////////////////////////// Auto-Sharding //////////////////////////// // Construct a device mesh for auto-sharding pass. Returns a tuple of mesh