
Commit 4752558

fix: use lowering_cntxt to save/get denormalized_tile_assignment

1 parent: 4556faa

9 files changed: +162 additions, −59 deletions

test/cpp/test_xla_sharding.cpp

Lines changed: 23 additions & 16 deletions

```diff
@@ -51,7 +51,8 @@ TEST_F(XLAShardingTest, GetShardShape) {
       {2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
@@ -60,7 +61,7 @@ TEST_F(XLAShardingTest, GetShardShape) {
   EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));
 
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For replicated sharding, each dimension should be preserved
@@ -78,7 +79,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       {2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
@@ -108,7 +110,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
     }
   }
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
@@ -126,6 +128,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
 TEST_F(XLAShardingTest, ShardTensor) {
   std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3",
                                       "TPU:4", "TPU:5", "TPU:6", "TPU:7"};
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
 
   // 1D tiled
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
@@ -136,7 +139,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
           CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
           devices.size())
           .ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -155,7 +158,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
       {4, 5, 6, 7},
   });
   xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -168,7 +171,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // size should be smaller in dim=1 because it's not evenly divisible.
   xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
   xla_sharding = xla::HloSharding::Tile(cube).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -178,7 +181,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
 
   // Replicated, all shards should be identical.
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -194,7 +197,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
   xla_sharding = xla::HloSharding::Tile(tesseract).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -219,7 +222,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
   xla_sharding = xla::HloSharding::Tile(hypercube).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -248,7 +251,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {6, 7, 2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {4, 5, 0, 1, 6, 7, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   // For devices at the start of the mesh, all shards should have the same
@@ -266,7 +270,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {2, 3, 6, 7},
   });
   xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  denormalized_tile_assignment = {0, 1, 4, 5, 2, 3, 6, 7};
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -295,7 +300,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
   });
 
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
       sharding, global_shape, /*minibatch=*/true);
   auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
@@ -314,14 +320,15 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
       {4, 5, 6, 7},
   })
       .ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape);
   xla_sharding =
       xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape);
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec replicated(sharding, tensor_shape);
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
   EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
```
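Note: the test updates make the meaning of the new constructor argument concrete. The denormalized tile assignment lists, for each logical tile slot, the physical device ID that owns that shard; the multi-host cases use permuted orders such as {4, 5, 0, 1, 6, 7, 2, 3}. A minimal standalone sketch of that mapping (plain C++, independent of the torch_xla types):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Illustration only: for each logical tile slot i, assignment[i] is the
// physical device ID holding shard i. Values mirror the first
// ShardTensorMultiHost case above.
int main() {
  const std::vector<int64_t> denormalized_tile_assignment = {4, 5, 0, 1,
                                                             6, 7, 2, 3};
  for (size_t slot = 0; slot < denormalized_tile_assignment.size(); ++slot) {
    std::cout << "logical shard " << slot << " -> TPU:"
              << denormalized_tile_assignment[slot] << "\n";
  }
}
```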

torch_xla/csrc/ir.h

Lines changed: 5 additions & 0 deletions

```diff
@@ -141,6 +141,11 @@ class XlaNode : public torch::lazy::Node {
     return output_shardings_[index];
   }
 
+  const std::vector<std::shared_ptr<torch_xla::OpSharding>> GetShardings()
+      const {
+    return output_shardings_;
+  }
+
   void SetSharding(const torch_xla::OpSharding& sharding, size_t index);
 
   void ClearSharding() {
```

torch_xla/csrc/lowering_context.cpp

Lines changed: 18 additions & 0 deletions

```diff
@@ -238,13 +238,31 @@ xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) {
   return it->second;
 }
 
+void LoweringContext::ExtractShardingAndSetDenormalizedTileAssignments(
+    std::vector<std::shared_ptr<torch_xla::OpSharding>> shardings) {
+  for (auto sharding : shardings) {
+    std::vector<int64_t> denormalized_tile_assignment =
+        sharding->GetDenormalizedTileAssignment();
+    if (!denormalized_tile_assignment.empty()) {
+      denormalized_tile_assignments_.push_back(
+          sharding->GetDenormalizedTileAssignment());
+    }
+  }
+}
+
 XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node& node) {
   XlaOpVector result_ops;
   try {
     const HloMetadataSetter meta_setter(*this, node);
     const XlaNode* const casted = dynamic_cast<const XlaNode*>(&node);
 
     result_ops = casted->Lower(this);
+    // save the denormalized_tile_assignment from all nodes and then use it
+    // during Compile
+    auto shardings = casted->GetShardings();
+    if (!shardings.empty()) {
+      ExtractShardingAndSetDenormalizedTileAssignments(shardings);
+    }
     if (!casted->dynamic_dims().empty()) {
       const xla::internal::XlaBuilderFriend builder_friend;
       auto* const inst = builder_friend.GetInstruction(result_ops[0]);
```
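Note: the extraction step records only non-empty assignments, one entry per sharded node output, in lowering order. A rough standalone model of that filtering behavior (`StubOpSharding` is a stand-in for torch_xla::OpSharding, reduced to the single accessor used here):

```cpp
#include <cstdint>
#include <memory>
#include <vector>

// Stand-in for torch_xla::OpSharding; illustration only.
struct StubOpSharding {
  std::vector<int64_t> assignment;
  const std::vector<int64_t>& GetDenormalizedTileAssignment() const {
    return assignment;
  }
};

// Mirrors ExtractShardingAndSetDenormalizedTileAssignments: skip shardings
// with no denormalized assignment, keep the rest in lowering order.
std::vector<std::vector<int64_t>> ExtractAssignments(
    const std::vector<std::shared_ptr<StubOpSharding>>& shardings) {
  std::vector<std::vector<int64_t>> saved;
  for (const auto& sharding : shardings) {
    if (!sharding->GetDenormalizedTileAssignment().empty()) {
      saved.push_back(sharding->GetDenormalizedTileAssignment());
    }
  }
  return saved;
}

int main() {
  std::vector<std::shared_ptr<StubOpSharding>> shardings = {
      std::make_shared<StubOpSharding>(StubOpSharding{{0, 1, 2, 3}}),
      std::make_shared<StubOpSharding>(StubOpSharding{{}}),  // dropped
  };
  return ExtractAssignments(shardings).size() == 1 ? 0 : 1;
}
```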

torch_xla/csrc/lowering_context.h

Lines changed: 9 additions & 0 deletions

```diff
@@ -117,6 +117,14 @@ class LoweringContext : public torch::lazy::LoweringContext {
   int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source,
                                 int64_t parent_id);
 
+  void ExtractShardingAndSetDenormalizedTileAssignments(
+      std::vector<std::shared_ptr<torch_xla::OpSharding>>);
+
+  const std::vector<std::vector<int64_t>>& GetDenormalizedTileAssignments()
+      const {
+    return denormalized_tile_assignments_;
+  }
+
  private:
   struct Parameter {
     xla::XlaOp param;
@@ -135,6 +143,7 @@ class LoweringContext : public torch::lazy::LoweringContext {
   std::string name_;
 
   std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
+  std::vector<std::vector<int64_t>> denormalized_tile_assignments_;
 };  // namespace torch_xla
 
 }  // namespace torch_xla
```
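Note: the header change establishes the lifecycle: assignments accumulate in `denormalized_tile_assignments_` while `LowerNode` runs and are read back through the const accessor once lowering is done (per the comment in the .cpp change, during Compile). A minimal sketch of that store-then-read pattern, using a stripped-down context class with a hypothetical `Save` helper standing in for the real extraction call:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Stripped-down model of the new LoweringContext surface; illustration only.
class StubLoweringContext {
 public:
  // Hypothetical recording hook, standing in for the extraction that
  // LowerNode performs for every sharded node.
  void Save(std::vector<int64_t> assignment) {
    denormalized_tile_assignments_.push_back(std::move(assignment));
  }
  const std::vector<std::vector<int64_t>>& GetDenormalizedTileAssignments()
      const {
    return denormalized_tile_assignments_;
  }

 private:
  std::vector<std::vector<int64_t>> denormalized_tile_assignments_;
};

int main() {
  StubLoweringContext ctx;
  ctx.Save({0, 1, 2, 3});  // recorded while lowering a node
  assert(ctx.GetDenormalizedTileAssignments().size() == 1);  // read later
}
```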

torch_xla/csrc/tensor_util.cpp

Lines changed: 48 additions & 4 deletions

```diff
@@ -838,6 +838,39 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
       runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors));
 }
 
+namespace {
+
+/**
+ * Filters a list of device strings to include only those with IDs matching
+ * the provided indices.
+ *
+ * @param devices List of device strings in format "TYPE:ID" (e.g., "TPU:0")
+ * @param indices List of device IDs to filter by
+ * @return Filtered list of device strings, or error status if parsing fails
+ *
+ * Example:
+ *   devices = ["TPU:0", "TPU:1", "TPU:2", "TPU:3"]
+ *   indices = [1, 3]
+ *   result = ["TPU:1", "TPU:3"]
+ */
+std::vector<std::string> FilterDevicesByAddressableDevices(
+    std::vector<std::string> devices, const std::vector<int64_t>& indices) {
+  std::vector<std::string> filtered_devices_;
+  filtered_devices_.reserve(indices.size());
+  for (auto& index : indices) {
+    for (auto& device : devices) {
+      std::vector<std::string> device_spec_parts = absl::StrSplit(device, ':');
+      if (std::stoi(device_spec_parts[1]) == index) {
+        filtered_devices_.push_back(device);
+        break;
+      }
+    }
+  }
+  return filtered_devices_;
+}
+
+}  // namespace
+
 std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
     const std::vector<at::Tensor>& tensors,
     const std::vector<XLATensor::ShardingSpecPtr>& shardings,
@@ -860,14 +893,25 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
 
       std::vector<std::string> local_devices =
           runtime::GetComputationClientOrDie()->GetLocalDevices();
+      std::vector<std::string> addressable_devices = std::move(local_devices);
+      if (shardings[i]) {
+        const std::vector<int64_t>& denormalized_tile_assignment =
+            shardings[i]->sharding.GetDenormalizedTileAssignment();
+        if ((!denormalized_tile_assignment.empty()) &&
+            (denormalized_tile_assignment.size() !=
+             addressable_devices.size())) {
+          addressable_devices = FilterDevicesByAddressableDevices(
+              addressable_devices, denormalized_tile_assignment);
+        }
+      }
       // Shards the input tensors with padding, to split evenly.
       // The execution requires consistent shard sizes, and the zero-padded
       // values should be ignored.
-      std::vector<at::Tensor> local_shards =
-          ShardingUtil::ShardTensor(tensors[i], shardings[i], local_devices,
-                                    /*padded=*/true);
+      std::vector<at::Tensor> local_shards = ShardingUtil::ShardTensor(
+          tensors[i], shardings[i], addressable_devices,
+          /*padded=*/true);
       new_handles.push_back(ShardingUtil::CreateShardedData(
-          local_shards, local_devices, shardings[i]));
+          local_shards, addressable_devices, shardings[i]));
     } else {
       source_tensors.push_back(std::make_shared<runtime::AtenSource>(
           tensors[i], std::move(shape), devices[i]));
```
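Note: to sanity-check the helper's documented example, here is a self-contained re-implementation (std::string parsing in place of absl::StrSplit; illustration only). The output order follows `indices`, not the input device order:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Standalone re-implementation of FilterDevicesByAddressableDevices for
// illustration; parses the "TYPE:ID" suffix without absl::StrSplit.
std::vector<std::string> FilterDevices(const std::vector<std::string>& devices,
                                       const std::vector<int64_t>& indices) {
  std::vector<std::string> filtered;
  filtered.reserve(indices.size());
  for (const int64_t index : indices) {
    for (const std::string& device : devices) {
      const int64_t id = std::stoll(device.substr(device.find(':') + 1));
      if (id == index) {
        filtered.push_back(device);
        break;
      }
    }
  }
  return filtered;
}

int main() {
  const std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2",
                                            "TPU:3"};
  const std::vector<std::string> expected = {"TPU:1", "TPU:3"};
  assert(FilterDevices(devices, {1, 3}) == expected);
}
```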
