feat: abstraction of xla::OpSharding proto using wrapper class #9467


Open · wants to merge 8 commits into master from kvshbg-aws/local-spmd-abstraction

Conversation

kvshbg-aws

This PR abstracts the xla::OpSharding proto object behind a torch_xla::OpSharding wrapper class.

The new class does not carry the requirements of xla::OpSharding; rather, it is an extension of the xla::OpSharding proto (defined here).

The wrapper class is defined in torch/xla. It constructs an xla::OpSharding object, stores additional fields such as global_device_ids/global_tile_assignment, and exposes forwarding/proxy functions to xla::OpSharding, so users can keep calling the same xla::OpSharding APIs they normally would. We can also define torch_xla-specific functions on the wrapper that make use of the extra fields captured when the OpSharding object is initialized.

The wrapper can be converted back to xla::OpSharding while lowering into HLO, so the abstracted class (and the additional fields it stores) can be used anywhere in the code base as needed. This matters because XLA's HLOs are 0-indexed: the normalized device ids (starting from index 0) are required when lowering the program into HLO, while the denormalized/global_device_ids can still be used elsewhere, for example inside the PJRT client to set the device_assignment from the user-specified device ids.
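
For orientation, here is a minimal sketch of what such a wrapper could look like. It is illustrative only: the forwarding methods shown are a small subset, and names such as GetXlaOpSharding and the accessor names are assumptions based on the snippets reviewed below, not the PR's actual interface.

    // Sketch only; the real class lives in this PR and may differ.
    // #include "xla/xla_data.pb.h"  // xla::OpSharding proto (include path assumed)
    #include <cstdint>
    #include <utility>
    #include <vector>

    namespace torch_xla {

    class OpSharding {
     public:
      explicit OpSharding(xla::OpSharding proto,
                          std::vector<int64_t> denormalized_tile_assignment = {})
          : proto_(std::move(proto)),
            denormalized_tile_assignment_(
                std::move(denormalized_tile_assignment)) {}

      // Forwarding/proxy accessors so callers keep using the familiar
      // xla::OpSharding API.
      xla::OpSharding::Type type() const { return proto_.type(); }
      const auto& tile_assignment_devices() const {
        return proto_.tile_assignment_devices();
      }

      // torch_xla-specific extension: the global (denormalized) device ids.
      const std::vector<int64_t>& GetDenormalizedTileAssignment() const {
        return denormalized_tile_assignment_;
      }

      // Convert back to the raw proto when lowering to HLO, where device ids
      // must be 0-indexed (normalized).
      const xla::OpSharding& GetXlaOpSharding() const { return proto_; }

     private:
      xla::OpSharding proto_;
      std::vector<int64_t> denormalized_tile_assignment_;
    };

    }  // namespace torch_xla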

Component diagram for reference: [component diagram image]

Ref issue - #9390

@kvshbg-aws force-pushed the kvshbg-aws/local-spmd-abstraction branch from d0502ab to 7fc15ea on July 10, 2025 18:34
@qihqi requested review from rpsilva-aws and pgmoka on July 11, 2025 04:20
@kvshbg-aws force-pushed the kvshbg-aws/local-spmd-abstraction branch 4 times, most recently from 7c4a3cd to 1d55ae9 on July 16, 2025 23:57
@kvshbg-aws force-pushed the kvshbg-aws/local-spmd-abstraction branch 2 times, most recently from 1ddbb1b to 2756c1a on July 25, 2025 23:53
@kvshbg-aws force-pushed the kvshbg-aws/local-spmd-abstraction branch from 2756c1a to 4556faa on July 30, 2025 04:14
@pgmoka (Collaborator) left a comment


Left comments. Overall, the PR looks very good.

Comment on lines +144 to +147
const std::vector<std::shared_ptr<torch_xla::OpSharding>> GetShardings()
    const {
  return output_shardings_;
}

I would keep GetShardings consistent with SetSharding and have its implementation live in ir.cpp (see the sketch below).
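
A rough sketch of that split (the header/source file names and the XlaNode:: qualification are assumed):

    // ir.h: declaration only
    const std::vector<std::shared_ptr<torch_xla::OpSharding>> GetShardings() const;

    // ir.cpp: definition, mirroring how SetSharding is implemented
    const std::vector<std::shared_ptr<torch_xla::OpSharding>>
    XlaNode::GetShardings() const {
      return output_shardings_;
    }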

Comment on lines +244 to +249
std::vector<int64_t> denormalized_tile_assignment =
    sharding->GetDenormalizedTileAssignment();
if (!denormalized_tile_assignment.empty()) {
  denormalized_tile_assignments_.push_back(
      sharding->GetDenormalizedTileAssignment());
}

NIT: Mixing declaration types here is a bit confusing to read. I would consider either:

1. On line 247, using:
   denormalized_tile_assignments_.push_back(denormalized_tile_assignment);

   OR

2. Removing denormalized_tile_assignment and just checking !sharding->GetDenormalizedTileAssignment().empty() (sketched below).
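
For option 2, a minimal sketch using the same identifiers as the quoted diff:

    if (!sharding->GetDenormalizedTileAssignment().empty()) {
      denormalized_tile_assignments_.push_back(
          sharding->GetDenormalizedTileAssignment());
    }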

Comment on lines +1269 to +1283
std::vector<std::vector<int64_t>> denormalized_tile_assignments;
for (const auto* node : po_data->post_order) {
  const XlaNode* const casted = dynamic_cast<const XlaNode*>(node);
  auto shardings = casted->GetShardings();
  if (!shardings.empty()) {
    for (auto sharding : shardings) {
      std::vector<int64_t> denormalized_tile_assignment =
          sharding->GetDenormalizedTileAssignment();
      if (!denormalized_tile_assignment.empty()) {
        denormalized_tile_assignments.push_back(
            sharding->GetDenormalizedTileAssignment());
      }
    }
  }
}

For posterity, a quick comment to outline what this is doing might be useful.
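
For example, a comment along these lines (wording is only a suggestion, grounded in the PR description) could sit above the loop:

    // Walk the post order and collect the denormalized (global) tile assignment
    // attached to each sharded XlaNode; these are threaded through compilation
    // so the runtime can later build the device assignment from user-specified
    // device ids rather than the 0-indexed ids used in the HLO.
    std::vector<std::vector<int64_t>> denormalized_tile_assignments;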

Comment on lines +293 to +303
if (sharding.tile_assignment_devices().empty() && !tile_assignment.empty()) {
  // Convert the Python list tile_assignment to a flattened vector for
  // denormalized assignment
  xla::Array<int64_t> tile_array = TileListToArray(tile_assignment);
  denormalized_tile_assignment.assign(tile_array.begin(), tile_array.end());
} else {
  // Use the tile_assignment_devices from the XLA OpSharding object
  denormalized_tile_assignment.assign(
      sharding.tile_assignment_devices().begin(),
      sharding.tile_assignment_devices().end());
}

NIT suggestion: we can remove the if/else branching here by doing something like:

// Use the tile_assignment_devices from the XLA OpSharding object by default
denormalized_tile_assignment.assign(
    sharding.tile_assignment_devices().begin(),
    sharding.tile_assignment_devices().end());
if (sharding.tile_assignment_devices().empty() && !tile_assignment.empty()) {
  // Convert the Python list tile_assignment to a flattened vector for
  // denormalized assignment
  xla::Array<int64_t> tile_array = TileListToArray(tile_assignment);
  denormalized_tile_assignment.assign(tile_array.begin(), tile_array.end());
}

}
}
if (input_shardings.size() == 0) {
if (xla_input_shardings.size() == 0) {
TF_VLOG(3) << "ReshardParamters... skip with empty input_shardings.";

NIT: "ReshardParamters... skip with empty xla_input_shardings."

sharding_spec->sharding.GetDenormalizedTileAssignment();
}
for (const auto& sharding : xla_input_shardings) {
if (denormalized_tile_assignment.size() > 0) {

I think this creates an odd edge case. I am not sure if it would ever get triggered, but if (*tensors)[0]->sharding_spec() does not exist, denormalized_tile_assignment never gets created, which could cause an error here. The way to guard against this is to move the for loop into the if (sharding_spec) check (see the sketch below).
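
A sketch of the suggested restructuring (the surrounding code is paraphrased from this hunk, so treat it as illustrative only):

    if (sharding_spec) {
      std::vector<int64_t> denormalized_tile_assignment =
          sharding_spec->sharding.GetDenormalizedTileAssignment();
      for (const auto& sharding : xla_input_shardings) {
        if (denormalized_tile_assignment.size() > 0) {
          // ... existing per-sharding handling ...
        }
      }
    }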

Comment on lines +318 to +332
auto xla_sharding = xla::HloSharding::Tile({
                                               {0, 1, 2, 3},
                                               {4, 5, 6, 7},
                                           })
                        .ToProto();
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape);
xla_sharding =
    xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape);
xla_sharding = xla::HloSharding::Replicate().ToProto();
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
XLATensor::ShardingSpec replicated(sharding, tensor_shape);

This is a bit confusing: xla_sharding is reused and reassigned multiple times throughout the test. I would separate these into distinct instances (see the sketch below). It is slightly less efficient, but I would favor readability in the test.
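
One possible shape for the test, with each case built from its own variables (the variable names are assumptions):

    // 2D tiled case
    xla::OpSharding tiled_2d_proto = xla::HloSharding::Tile({
                                                                {0, 1, 2, 3},
                                                                {4, 5, 6, 7},
                                                            })
                                         .ToProto();
    torch_xla::OpSharding tiled_2d_sharding(tiled_2d_proto,
                                            denormalized_tile_assignment);
    XLATensor::ShardingSpec tiled_2d(tiled_2d_sharding, tensor_shape);

    // Replicated case, built independently of the tiled ones
    xla::OpSharding replicated_proto = xla::HloSharding::Replicate().ToProto();
    torch_xla::OpSharding replicated_sharding(replicated_proto,
                                              denormalized_tile_assignment);
    XLATensor::ShardingSpec replicated(replicated_sharding, tensor_shape);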

Comment on lines +278 to +298
IfrtComputation(
xla::XlaComputation computation, std::vector<std::string> devices,
std::shared_ptr<xla::ifrt::LoadedExecutable> executable,
std::optional<std::vector<int64_t>> denormalized_tile_assignment)
: Computation(std::move(computation), std::move(devices)),
executable(std::move(executable)) {
output_shardings_ = this->executable->GetOutputShardings();
executable(std::move(executable)),
denormalized_tile_assignment_(std::move(
denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
xla_output_shardings_ = this->executable->GetOutputShardings();
if (xla_output_shardings_.has_value()) {
output_shardings_ = std::vector<torch_xla::OpSharding>{};
output_shardings_->reserve(xla_output_shardings_.value().size());
for (const auto& sharding : xla_output_shardings_.value()) {
// convert each into torch_xla::OpSharding object
torch_xla::OpSharding torch_xla_op_sharding(
sharding, denormalized_tile_assignment_);
output_shardings_.value().push_back(torch_xla_op_sharding);
}
} else {
output_shardings_ = std::nullopt;
}

It would be good to move the constructor to ifrt_computation_client.cpp. If you can do it here, great. If not, please create an issue and link a TODO from here for posterity. A rough sketch of the move is below.
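
Sketch of the suggested move (the qualified IfrtComputationClient::IfrtComputation name and the exact file layout are assumptions):

    // ifrt_computation_client.h: keep only the declaration
    IfrtComputation(
        xla::XlaComputation computation, std::vector<std::string> devices,
        std::shared_ptr<xla::ifrt::LoadedExecutable> executable,
        std::optional<std::vector<int64_t>> denormalized_tile_assignment);

    // ifrt_computation_client.cpp: definition carrying the body from the diff above
    IfrtComputationClient::IfrtComputation::IfrtComputation(
        xla::XlaComputation computation, std::vector<std::string> devices,
        std::shared_ptr<xla::ifrt::LoadedExecutable> executable,
        std::optional<std::vector<int64_t>> denormalized_tile_assignment)
        : Computation(std::move(computation), std::move(devices)),
          executable(std::move(executable)),
          denormalized_tile_assignment_(
              denormalized_tile_assignment.value_or(std::vector<int64_t>{})) {
      // ... populate output_shardings_ from GetOutputShardings(), as in the diff ...
    }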

Comment on lines +287 to +298
if (xla_output_shardings_.has_value()) {
  output_shardings_ = std::vector<torch_xla::OpSharding>{};
  output_shardings_->reserve(xla_output_shardings_.value().size());
  for (const auto& sharding : xla_output_shardings_.value()) {
    // convert each into torch_xla::OpSharding object
    torch_xla::OpSharding torch_xla_op_sharding(
        sharding, denormalized_tile_assignment_);
    output_shardings_.value().push_back(torch_xla_op_sharding);
  }
} else {
  output_shardings_ = std::nullopt;
}

We can simplify this if/else by setting output_shardings_ = std::nullopt; as the default. For example:

      output_shardings_ = std::nullopt;
      if (xla_output_shardings_.has_value()) {
        output_shardings_ = std::vector<torch_xla::OpSharding>{};
        output_shardings_->reserve(xla_output_shardings_.value().size());
        for (const auto& sharding : xla_output_shardings_.value()) {
          // convert each into torch_xla::OpSharding object
          torch_xla::OpSharding torch_xla_op_sharding(
              sharding, denormalized_tile_assignment_);
          output_shardings_.value().push_back(torch_xla_op_sharding);
        }
      } 

* @param executable The compiled PJRT executable
* @param denormalized_tile_assignment Optional tile assignment for sharding
*/
PjRtComputation(

Same notes as for ifrt_computation_client.cpp.
