Skip to content

Commit b390dd5

Browse files
author
kvshbg-aws
committed
use tensors instead of parameters_data to get denormalized_tile_assignment
1 parent 2e34d5a commit b390dd5

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -756,17 +756,18 @@ void ShardingUtil::ReshardParameters(
756756
std::vector<runtime::ComputationClient::DataPtr> data_to_reshard;
757757
std::vector<torch_xla::OpSharding> shardings_to_reshard;
758758

759+
std::vector<int64_t> denormalized_tile_assignment;
760+
auto sharding_spec = (*tensors)[0]->sharding_spec();
761+
if (sharding_spec) {
762+
denormalized_tile_assignment = sharding_spec->sharding.GetDenormalizedTileAssignment();
763+
}
759764
for (const auto& sharding : xla_input_shardings) {
760-
for (const auto& data : *parameters) {
761-
runtime::ComputationClient::DataPtr handle =
762-
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data);
763-
auto computation_client_ptr = runtime::GetComputationClient();
764-
torch_xla::OpSharding torch_xla_opsharding =
765-
(*computation_client_ptr)->GetDataSharding(handle).value();
766-
std::vector<int64_t> denormalized_tile_assignment =
767-
torch_xla_opsharding.GetDenormalizedTileAssignment();
765+
if (denormalized_tile_assignment.size() > 0){
768766
input_shardings.emplace_back(sharding, denormalized_tile_assignment);
769767
}
768+
else{
769+
input_shardings.emplace_back(sharding);
770+
}
770771
}
771772

772773
for (int i = 0; i < input_shardings.size(); ++i) {

0 commit comments

Comments
 (0)