-
Notifications
You must be signed in to change notification settings - Fork 559
feat: abstraction of xla::OpSharding proto using wrapper class #9467
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
feat: abstraction of xla::OpSharding proto using wrapper class #9467
Conversation
d0502ab
to
7fc15ea
Compare
7c4a3cd
to
1d55ae9
Compare
1ddbb1b
to
2756c1a
Compare
…_assignment() is empty
2756c1a
to
4556faa
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left comments. Overall, the PR looks very good.
const std::vector<std::shared_ptr<torch_xla::OpSharding>> GetShardings() | ||
const { | ||
return output_shardings_; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would keep GetShardings
consistend with SetSharding
and have its implementation be in ir.cpp
std::vector<int64_t> denormalized_tile_assignment = | ||
sharding->GetDenormalizedTileAssignment(); | ||
if (!denormalized_tile_assignment.empty()) { | ||
denormalized_tile_assignments_.push_back( | ||
sharding->GetDenormalizedTileAssignment()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: Mixing declaration types here is a bit confusing to read. I would consider:
- In line 247, using:
denormalized_tile_assignments_.push_back(denormalized_tile_assignment);
OR
2) Remove denormalized_tile_assignment
, and just check !sharding->GetDenormalizedTileAssignment().empty()
std::vector<std::vector<int64_t>> denormalized_tile_assignments; | ||
for (const auto* node : po_data->post_order) { | ||
const XlaNode* const casted = dynamic_cast<const XlaNode*>(node); | ||
auto shardings = casted->GetShardings(); | ||
if (!shardings.empty()) { | ||
for (auto sharding : shardings) { | ||
std::vector<int64_t> denormalized_tile_assignment = | ||
sharding->GetDenormalizedTileAssignment(); | ||
if (!denormalized_tile_assignment.empty()) { | ||
denormalized_tile_assignments.push_back( | ||
sharding->GetDenormalizedTileAssignment()); | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For posterity, a quick comment to outline what this is doing might be useful.
if (sharding.tile_assignment_devices().empty() && !tile_assignment.empty()) { | ||
// Convert the Python list tile_assignment to a flattened vector for | ||
// denormalized assignment | ||
xla::Array<int64_t> tile_array = TileListToArray(tile_assignment); | ||
denormalized_tile_assignment.assign(tile_array.begin(), tile_array.end()); | ||
} else { | ||
// Use the tile_assignment_devices from the XLA OpSharding object | ||
denormalized_tile_assignment.assign( | ||
sharding.tile_assignment_devices().begin(), | ||
sharding.tile_assignment_devices().end()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT suggestion: we can remove if branching here by doing something like:
// Use the tile_assignment_devices from the XLA OpSharding object by default
denormalized_tile_assignment.assign(
sharding.tile_assignment_devices().begin(),
sharding.tile_assignment_devices().end());
if (sharding.tile_assignment_devices().empty() && !tile_assignment.empty()) {
// Convert the Python list tile_assignment to a flattened vector for
// denormalized assignment
xla::Array<int64_t> tile_array = TileListToArray(tile_assignment);
denormalized_tile_assignment.assign(tile_array.begin(), tile_array.end());
}
} | ||
} | ||
if (input_shardings.size() == 0) { | ||
if (xla_input_shardings.size() == 0) { | ||
TF_VLOG(3) << "ReshardParamters... skip with empty input_shardings."; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: "ReshardParamters... skip with empty xla_input_shardings."
sharding_spec->sharding.GetDenormalizedTileAssignment(); | ||
} | ||
for (const auto& sharding : xla_input_shardings) { | ||
if (denormalized_tile_assignment.size() > 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this generates an odd use case. I am not sure if it would ever get triggered, but if (*tensors)[0]->sharding_spec()
does not exist, denormalized_tile_assignment
never gets created, which might generate a error here. The way to ensure this is to move the for loop into the if (sharding_spec)
check.
auto xla_sharding = xla::HloSharding::Tile({ | ||
{0, 1, 2, 3}, | ||
{4, 5, 6, 7}, | ||
}) | ||
.ToProto(); | ||
std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7}; | ||
torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment); | ||
XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape); | ||
xla_sharding = | ||
xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(); | ||
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); | ||
XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape); | ||
xla_sharding = xla::HloSharding::Replicate().ToProto(); | ||
sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment); | ||
XLATensor::ShardingSpec replicated(sharding, tensor_shape); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bit confusing. Specifically, xla_sharding
is used and redefined multiple times through the test. I would perhaps separate these into separate instances. It is less efficient, but I would favor readability for the test
IfrtComputation( | ||
xla::XlaComputation computation, std::vector<std::string> devices, | ||
std::shared_ptr<xla::ifrt::LoadedExecutable> executable, | ||
std::optional<std::vector<int64_t>> denormalized_tile_assignment) | ||
: Computation(std::move(computation), std::move(devices)), | ||
executable(std::move(executable)) { | ||
output_shardings_ = this->executable->GetOutputShardings(); | ||
executable(std::move(executable)), | ||
denormalized_tile_assignment_(std::move( | ||
denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) { | ||
xla_output_shardings_ = this->executable->GetOutputShardings(); | ||
if (xla_output_shardings_.has_value()) { | ||
output_shardings_ = std::vector<torch_xla::OpSharding>{}; | ||
output_shardings_->reserve(xla_output_shardings_.value().size()); | ||
for (const auto& sharding : xla_output_shardings_.value()) { | ||
// convert each into torch_xla::OpSharding object | ||
torch_xla::OpSharding torch_xla_op_sharding( | ||
sharding, denormalized_tile_assignment_); | ||
output_shardings_.value().push_back(torch_xla_op_sharding); | ||
} | ||
} else { | ||
output_shardings_ = std::nullopt; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to move the constructor to ifrt_computation_client.cpp. If you can do here, great. If you don't want to do it, please create an issue and link a TODO from here for posterity.
if (xla_output_shardings_.has_value()) { | ||
output_shardings_ = std::vector<torch_xla::OpSharding>{}; | ||
output_shardings_->reserve(xla_output_shardings_.value().size()); | ||
for (const auto& sharding : xla_output_shardings_.value()) { | ||
// convert each into torch_xla::OpSharding object | ||
torch_xla::OpSharding torch_xla_op_sharding( | ||
sharding, denormalized_tile_assignment_); | ||
output_shardings_.value().push_back(torch_xla_op_sharding); | ||
} | ||
} else { | ||
output_shardings_ = std::nullopt; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can simplify this if statement tree by setting output_shardings_ = std::nullopt;
as default. Ex.:
output_shardings_ = std::nullopt;
if (xla_output_shardings_.has_value()) {
output_shardings_ = std::vector<torch_xla::OpSharding>{};
output_shardings_->reserve(xla_output_shardings_.value().size());
for (const auto& sharding : xla_output_shardings_.value()) {
// convert each into torch_xla::OpSharding object
torch_xla::OpSharding torch_xla_op_sharding(
sharding, denormalized_tile_assignment_);
output_shardings_.value().push_back(torch_xla_op_sharding);
}
}
* @param executable The compiled PJRT executable | ||
* @param denormalized_tile_assignment Optional tile assignment for sharding | ||
*/ | ||
PjRtComputation( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same notes as ifrt_computation_client.cpp
This PR includes the changes related to abstracting
xla::OpSharidng
proto object into atorch_xla::OpSharding
wrapper class.This new class object will not have the requirements of xla::OpSharding (however, it will be an extension xla::OpSharding proto defined over here).
We have defined the wrapper class in torch/xla which will construct an xla::OpSharding object with additional fields such as global_device_ids/global_tile_assignment and will have forwarded/proxy functions to xla::OpSharding . These forwarded functions will help user still make use of the same
xla::OpSharding
APIs as they normally would. We can also define torch_xla specific functions in this wrapper class to further use the extra fields that were stored during the initialization of the OpSharding object. This approach also allows the flexibility of converting the torch_xla::OpSharding object back to xla::OpSharding while lowering into HLO, thus, giving user the flexibility to use the abstracted class (and other additional fields stored) anywhere in the code base as needed, this is particularly useful since the XLA's HLOs are 0th indexed, hence we need to use the normalized_device_ids (starting from index 0) when lowering the program into the HLO, whereas we can still use the denormalized/global_device_ids in other places such as inside pjrt client to set the device_assignment using the user specified device_ids.Component diagram for reference -

Ref issue - #9390