Skip to content

Commit 7cb36bb

Browse files
author
kvshbg-aws
committed
bug fix: set value of denormalized_tile_assignment when sharding.tile_assignment_devices() is empty
1 parent 1db6d33 commit 7cb36bb

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
432432
XLA_CHECK_NE(sharding.type(), xla::OpSharding::UNKNOWN)
433433
<< "Resharding by UNKNOWN sharding type is not allowed.";
434434

435-
hlo_shardings.push_back(
436-
GetValueOrThrow(xla::HloSharding::FromProto(sharding.GetXlaOpSharding())));
435+
hlo_shardings.push_back(GetValueOrThrow(
436+
xla::HloSharding::FromProto(sharding.GetXlaOpSharding())));
437437

438438
xla::OpSharding fallback_sharding;
439439
fallback_sharding.set_type(xla::OpSharding::REPLICATED);

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,21 @@ torch_xla::OpSharding ShardingUtil::CreateOpSharding(
286286
}
287287
}
288288

289-
// create a copy of the original_tile_assignment_devices from xla::OpSharding
290-
// object
291-
std::vector<int64_t> denormalized_tile_assignment(
292-
sharding.tile_assignment_devices().begin(),
293-
sharding.tile_assignment_devices().end());
289+
// Create denormalized_tile_assignment. If sharding.tile_assignment_devices()
290+
// is empty (which happens for REPLICATED, MANUAL, UNKNOWN sharding types) and
291+
// the tile_assignment arg passed to this function is non-empty, use that arg.
292+
std::vector<int64_t> denormalized_tile_assignment;
293+
if (sharding.tile_assignment_devices().empty() && !tile_assignment.empty()) {
294+
// Convert the Python list tile_assignment to a flattened vector for
295+
// denormalized assignment
296+
xla::Array<int64_t> tile_array = TileListToArray(tile_assignment);
297+
denormalized_tile_assignment.assign(tile_array.begin(), tile_array.end());
298+
} else {
299+
// Use the tile_assignment_devices from the XLA OpSharding object
300+
denormalized_tile_assignment.assign(
301+
sharding.tile_assignment_devices().begin(),
302+
sharding.tile_assignment_devices().end());
303+
}
294304

295305
// Use the xla::OpSharding object in the wrapper torch_xla::OpSharding along
296306
// with denormalized_tile_assignment (original tile_assignment) for extended

0 commit comments

Comments
 (0)