
Commit 0e229b5

Authored and committed by root

feat: infer addressable devices for submeshing
1 parent 3c07da6 commit 0e229b5

13 files changed: +458, -98 lines

test/cpp/test_xla_sharding.cpp

Lines changed: 51 additions & 16 deletions
@@ -42,6 +42,34 @@ class XLAShardingTest : public AtenXlaTensorTestBase {
   }
 };
 
+TEST_F(XLAShardingTest, NormalizeTileAssignment) {
+  // Test with an empty tile assignment
+  std::vector<int64_t> empty_tile_assignment = {};
+  auto normalized =
+      ShardingUtil::NormalizeTileAssignment(empty_tile_assignment);
+  EXPECT_TRUE(normalized.empty());
+
+  // Test with positive values
+  std::vector<int64_t> positive_tile_assignment = {3, 1, 4, 2};
+  normalized = ShardingUtil::NormalizeTileAssignment(positive_tile_assignment);
+  EXPECT_EQ(normalized, std::vector<int64_t>({2, 0, 3, 1}));
+
+  // Test with all identical values
+  std::vector<int64_t> identical_tile_assignment = {5, 5, 5, 5};
+  normalized = ShardingUtil::NormalizeTileAssignment(identical_tile_assignment);
+  EXPECT_EQ(normalized, std::vector<int64_t>({0, 0, 0, 0}));
+
+  // Test with negative values
+  std::vector<int64_t> negative_tile_assignment = {-3, -1, -4, -2};
+  EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(negative_tile_assignment),
+               std::runtime_error);
+
+  // Test with mixed positive and negative values
+  std::vector<int64_t> mixed_tile_assignment = {3, -1, 4, 2};
+  EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(mixed_tile_assignment),
+               std::runtime_error);
+}
+
 TEST_F(XLAShardingTest, GetShardShape) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
   xla::Shape tensor_shape =
@@ -51,7 +79,8 @@ TEST_F(XLAShardingTest, GetShardShape) {
       {2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
 
@@ -60,7 +89,7 @@ TEST_F(XLAShardingTest, GetShardShape) {
   EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));
 
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For replicated sharding, each dimension should be preserved
@@ -78,7 +107,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       {2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
@@ -108,7 +138,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
     }
   }
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
@@ -126,6 +156,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
 TEST_F(XLAShardingTest, ShardTensor) {
   std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3",
                                       "TPU:4", "TPU:5", "TPU:6", "TPU:7"};
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
 
   // 1D tiled
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
@@ -136,7 +167,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
       devices.size())
          .ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -155,7 +186,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
       {4, 5, 6, 7},
   });
   xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -168,7 +199,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // size should be smaller in dim=1 because it's not evenly divisible.
   xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
   xla_sharding = xla::HloSharding::Tile(cube).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -178,7 +209,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
 
   // Replicated, all shards should be identical.
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -194,7 +225,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
   xla_sharding = xla::HloSharding::Tile(tesseract).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -219,7 +250,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
   xla_sharding = xla::HloSharding::Tile(hypercube).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -248,7 +279,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {6, 7, 2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {4, 5, 0, 1, 6, 7, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   // For devices at the start of the mesh, all shards should have the same
@@ -266,7 +298,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {2, 3, 6, 7},
   });
   xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  denormalized_tile_assignment = {0, 1, 4, 5, 2, 3, 6, 7};
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -295,7 +328,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
   });
 
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
       sharding, global_shape, /*minibatch=*/true);
   auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
@@ -314,14 +348,15 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
       {4, 5, 6, 7},
   })
       .ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape);
   xla_sharding =
       xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape);
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec replicated(sharding, tensor_shape);
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
   EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
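
Note: the expectations in the NormalizeTileAssignment test above are consistent with a rank-based normalization, where each device ID is replaced by its rank among the distinct IDs and negative IDs are rejected. The sketch below only illustrates that behavior; it is not the implementation added by this commit (which lives in ShardingUtil), and the name NormalizeTileAssignmentSketch is made up for illustration.

// Sketch of a rank-based normalization consistent with the test expectations:
// {3, 1, 4, 2} -> {2, 0, 3, 1}, {5, 5, 5, 5} -> {0, 0, 0, 0}, negatives throw.
#include <algorithm>
#include <cstdint>
#include <set>
#include <stdexcept>
#include <vector>

std::vector<int64_t> NormalizeTileAssignmentSketch(
    const std::vector<int64_t>& tile_assignment) {
  // Negative device IDs are invalid; the tests expect a std::runtime_error.
  for (int64_t id : tile_assignment) {
    if (id < 0) {
      throw std::runtime_error("Tile assignment contains negative device IDs");
    }
  }
  // Map each ID to its rank among the distinct IDs, yielding a dense
  // assignment over [0, num_distinct_ids).
  std::set<int64_t> distinct(tile_assignment.begin(), tile_assignment.end());
  std::vector<int64_t> sorted_ids(distinct.begin(), distinct.end());
  std::vector<int64_t> normalized;
  normalized.reserve(tile_assignment.size());
  for (int64_t id : tile_assignment) {
    auto it = std::lower_bound(sorted_ids.begin(), sorted_ids.end(), id);
    normalized.push_back(it - sorted_ids.begin());
  }
  return normalized;
}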

test/spmd/test_xla_sharding.py

Lines changed: 5 additions & 3 deletions
@@ -761,13 +761,15 @@ def test_hybrid_mesh_shape(self):
 
   @unittest.skipIf(xr.device_type() == 'TPU' and tpu.version() < 3,
                    "Crash on TPU v2")
+  @patch('torch_xla.runtime.addressable_runtime_device_count')
   @patch('torch_xla.runtime.global_runtime_device_attributes')
   @patch('torch_xla.core.xla_model.xla_device_hw')
   @patch('torch_xla.runtime.global_runtime_device_count')
   def test_hybrid_mesh(self, device_count_mock, xla_device_mock,
-                       device_attributes_mock):
+                       device_attributes_mock, addressable_device_count_mock):
     # mock device attributes for 2 slices of v4-8
     num_slices = 2
+    addressable_device_count_mock.return_value = 8
     device_count_mock.return_value = 8
     xla_device_mock.return_value = "TPU"
     device_attributes_mock.return_value = [{
@@ -1621,7 +1623,7 @@ def test_device_ids_out_of_bounds(self):
     invalid_ids = np.arange(self.n_devices + 1, self.n_devices * 2 + 1)
 
     with self.assertRaisesRegex(AssertionError,
-                                "Device IDs must be less than mesh size"):
+                                "Device IDs has to be subset of addressable_devices; got:*."):
       xs.Mesh(invalid_ids, mesh_shape)
 
   def test_mesh_size(self):
@@ -1645,7 +1647,7 @@ def test_mismatch_global_devices(self):
     mesh_shape = (1, partial_num_devices)
     with self.assertRaisesRegex(
         AssertionError,
-        "Number of device IDs .* must match the global number of devices"):
+        "Number of device IDs .* must be less than the global number of devices"):
       xs.Mesh(device_ids, mesh_shape)
 
   @unittest.skipIf(xr.global_runtime_device_count() == 1,

torch_xla/csrc/ir.h

Lines changed: 5 additions & 0 deletions
@@ -141,6 +141,11 @@ class XlaNode : public torch::lazy::Node {
     return output_shardings_[index];
   }
 
+  const std::vector<std::shared_ptr<torch_xla::OpSharding>> GetShardings()
+      const {
+    return output_shardings_;
+  }
+
   void SetSharding(const torch_xla::OpSharding& sharding, size_t index);
 
   void ClearSharding() {

torch_xla/csrc/lowering_context.cpp

Lines changed: 18 additions & 0 deletions
@@ -238,13 +238,31 @@ xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) {
   return it->second;
 }
 
+void LoweringContext::ExtractShardingAndSetDenormalizedTileAssignments(
+    std::vector<std::shared_ptr<torch_xla::OpSharding>> shardings) {
+  for (auto sharding : shardings) {
+    std::vector<int64_t> denormalized_tile_assignment =
+        sharding->GetDenormalizedTileAssignment();
+    if (!denormalized_tile_assignment.empty()) {
+      denormalized_tile_assignments_.push_back(
+          sharding->GetDenormalizedTileAssignment());
+    }
+  }
+}
+
 XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node& node) {
   XlaOpVector result_ops;
   try {
     const HloMetadataSetter meta_setter(*this, node);
     const XlaNode* const casted = dynamic_cast<const XlaNode*>(&node);
 
     result_ops = casted->Lower(this);
+    auto shardings = casted->GetShardings();
+    if (!shardings.empty()) {
+      ExtractShardingAndSetDenormalizedTileAssignments(shardings);
+    }
+    // save the denormalized_tile_assignment from all nodes and then use it
+    // during Compile
     if (!casted->dynamic_dims().empty()) {
       const xla::internal::XlaBuilderFriend builder_friend;
       auto* const inst = builder_friend.GetInstruction(result_ops[0]);

torch_xla/csrc/lowering_context.h

Lines changed: 9 additions & 0 deletions
@@ -117,6 +117,14 @@ class LoweringContext : public torch::lazy::LoweringContext {
   int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source,
                                 int64_t parent_id);
 
+  void ExtractShardingAndSetDenormalizedTileAssignments(
+      std::vector<std::shared_ptr<torch_xla::OpSharding>>);
+
+  const std::vector<std::vector<int64_t>>& GetDenormalizedTileAssignments()
+      const {
+    return denormalized_tile_assignments_;
+  }
+
  private:
   struct Parameter {
     xla::XlaOp param;
@@ -135,6 +143,7 @@ class LoweringContext : public torch::lazy::LoweringContext {
   std::string name_;
 
   std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
+  std::vector<std::vector<int64_t>> denormalized_tile_assignments_;
 };  // namespace torch_xla
 
 }  // namespace torch_xla
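
Tying these pieces together: per the commit message, the intent is to infer addressable devices for submeshing, so each lowered node's OpSharding can carry a denormalized tile assignment, which the LoweringContext accumulates and exposes through GetDenormalizedTileAssignments() for use at compile time. The sketch below shows one hypothetical way a compile path could consume that data; CollectAddressableDevices and the usage comment are illustrative assumptions, not APIs added by this commit.

// Hypothetical consumer: derive the set of device IDs actually referenced by
// the shardings recorded in a LoweringContext. Only
// GetDenormalizedTileAssignments() comes from this commit; the rest is assumed.
#include <cstdint>
#include <set>
#include <vector>

std::vector<int64_t> CollectAddressableDevices(
    const std::vector<std::vector<int64_t>>& denormalized_tile_assignments) {
  // Union of every device ID mentioned by any node's tile assignment; a
  // compile step could use this as the submesh's device list instead of
  // assuming all global devices participate.
  std::set<int64_t> device_ids;
  for (const auto& assignment : denormalized_tile_assignments) {
    device_ids.insert(assignment.begin(), assignment.end());
  }
  return std::vector<int64_t>(device_ids.begin(), device_ids.end());
}

// Assumed usage after lowering:
//   auto devices = CollectAddressableDevices(
//       lowering_ctx.GetDenormalizedTileAssignments());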
