1 file changed: +9 -8 lines changed

@@ -756,17 +756,18 @@ void ShardingUtil::ReshardParameters(
   std::vector<runtime::ComputationClient::DataPtr> data_to_reshard;
   std::vector<torch_xla::OpSharding> shardings_to_reshard;
 
+  std::vector<int64_t> denormalized_tile_assignment;
+  auto sharding_spec = (*tensors)[0]->sharding_spec();
+  if (sharding_spec) {
+    denormalized_tile_assignment = sharding_spec->sharding.GetDenormalizedTileAssignment();
+  }
   for (const auto& sharding : xla_input_shardings) {
-    for (const auto& data : *parameters) {
-      runtime::ComputationClient::DataPtr handle =
-          std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data);
-      auto computation_client_ptr = runtime::GetComputationClient();
-      torch_xla::OpSharding torch_xla_opsharding =
-          (*computation_client_ptr)->GetDataSharding(handle).value();
-      std::vector<int64_t> denormalized_tile_assignment =
-          torch_xla_opsharding.GetDenormalizedTileAssignment();
+    if (denormalized_tile_assignment.size() > 0) {
       input_shardings.emplace_back(sharding, denormalized_tile_assignment);
     }
+    else {
+      input_shardings.emplace_back(sharding);
+    }
   }
 
   for (int i = 0; i < input_shardings.size(); ++i) {
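In essence, the change hoists the tile-assignment lookup out of the loop. The old nested loop queried GetDataSharding() on every parameter handle and emplaced one entry per (sharding, parameter) pair; the new code resolves the denormalized tile assignment once, from the first tensor's sharding spec, and then appends exactly one entry per input sharding, falling back to the sharding alone when no spec is available. A minimal sketch of the resulting pattern, using simplified stand-in types (InputSharding and CollectInputShardings are hypothetical names for illustration, not torch_xla definitions):

// A minimal sketch of the fallback pattern introduced by the diff.
// InputSharding is a hypothetical stand-in, not the torch_xla type.
#include <cstdint>
#include <vector>

struct OpSharding {};  // stand-in for torch_xla::OpSharding

struct InputSharding {
  OpSharding sharding;
  std::vector<int64_t> denormalized_tile_assignment;  // empty when absent
};

std::vector<InputSharding> CollectInputShardings(
    const std::vector<OpSharding>& xla_input_shardings,
    const std::vector<int64_t>& denormalized_tile_assignment) {
  std::vector<InputSharding> input_shardings;
  input_shardings.reserve(xla_input_shardings.size());
  for (const auto& sharding : xla_input_shardings) {
    if (!denormalized_tile_assignment.empty()) {
      // The tile assignment was resolved once, before the loop; attach it.
      input_shardings.push_back({sharding, denormalized_tile_assignment});
    } else {
      // No sharding spec was available; record the sharding alone.
      input_shardings.push_back({sharding, {}});
    }
  }
  return input_shardings;
}

Note that the sketch spells the emptiness check as !denormalized_tile_assignment.empty(), the more idiomatic form of the diff's denormalized_tile_assignment.size() > 0; the behavior is identical.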