
feat: abstraction of xla::OpSharding proto using wrapper class #9467


Open · wants to merge 9 commits into base: master
354 changes: 277 additions & 77 deletions test/cpp/test_xla_sharding.cpp

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions torch_xla/csrc/BUILD
@@ -128,6 +128,7 @@ ptxla_cc_library(
":shape_builder",
":shape_helper",
":status",
":torch_xla_op_sharding",
":version",
"//torch_xla/csrc:hash_util",
"//torch_xla/csrc:thread_pool",
@@ -316,6 +317,7 @@ ptxla_cc_library(
":shape_helper",
":status",
":unwrap_data",
":torch_xla_op_sharding",
"//torch_xla/csrc/runtime:cache",
"//torch_xla/csrc/runtime:computation_client",
"@com_google_absl//absl/log:absl_check",
@@ -385,3 +387,13 @@ cc_library(
"@com_google_absl//absl/status:statusor",
],
)

cc_library(
name = "torch_xla_op_sharding",
srcs = ["torch_xla_op_sharding.cpp"],
hdrs = ["torch_xla_op_sharding.h"],
deps = [
"//torch_xla/csrc/runtime:debug_macros",
"@xla//xla/hlo/builder:xla_builder",
],
)
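
Note for readers skimming the diff: a minimal sketch of the wrapper that the new `torch_xla_op_sharding` target introduces, inferred from the call sites in this PR (the proto-plus-tile-assignment constructor and the `GetXlaOpSharding()` / `GetDenormalizedTileAssignment()` accessors). The actual header also exposes other constructors (for example the one bound in `init_python_bindings.cpp`) and may differ in detail.

```cpp
// Sketch only -- inferred from call sites in this PR, not the actual header.
#include <cstdint>
#include <utility>
#include <vector>

#include "xla/xla_data.pb.h"  // defines the xla::OpSharding proto

namespace torch_xla {

// Wraps the xla::OpSharding proto so PyTorch/XLA can carry extra sharding
// metadata (here, a denormalized tile assignment) alongside it.
class OpSharding {
 public:
  explicit OpSharding(xla::OpSharding proto,
                      std::vector<int64_t> denormalized_tile_assignment = {})
      : proto_(std::move(proto)),
        denormalized_tile_assignment_(
            std::move(denormalized_tile_assignment)) {}

  // The underlying proto, e.g. for xla::HloSharding::FromProto().
  const xla::OpSharding& GetXlaOpSharding() const { return proto_; }

  // Extra metadata carried through lowering and compilation.
  const std::vector<int64_t>& GetDenormalizedTileAssignment() const {
    return denormalized_tile_assignment_;
  }

 private:
  xla::OpSharding proto_;
  std::vector<int64_t> denormalized_tile_assignment_;
};

}  // namespace torch_xla
```

The rest of the diff follows from this shape: anywhere that used to pass `xla::OpSharding` now passes `torch_xla::OpSharding`, unwrapping via `GetXlaOpSharding()` only at the boundary to XLA APIs.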
4 changes: 3 additions & 1 deletion torch_xla/csrc/debug_util.cpp
@@ -21,6 +21,7 @@
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/status.h"
#include "torch_xla/csrc/torch_xla_op_sharding.h"
#include "torch_xla/csrc/xla_graph_executor.h"

namespace torch_xla {
@@ -218,7 +219,8 @@ void DebugUtil::SaveOutputShardingInfo(std::vector<XLATensorPtr>* tensors,
auto xtensor = (*tensors)[indices[i]];
ss << xtensor->shape().get().ToString() << " ";
if (xtensor->sharding_spec()) {
ss << xla::HloSharding::FromProto(xtensor->sharding_spec()->sharding)
ss << xla::HloSharding::FromProto(
xtensor->sharding_spec()->sharding.GetXlaOpSharding())
->ToString();
} else {
ss << xla::HloSharding::FromProto(xla::HloSharding::Unknown().ToProto())
23 changes: 13 additions & 10 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -72,6 +72,7 @@
#include "torch_xla/csrc/tensor_methods.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"
#include "torch_xla/csrc/torch_xla_op_sharding.h"
#include "torch_xla/csrc/version.h"
#include "torch_xla/csrc/xla_backend_impl.h"
#include "torch_xla/csrc/xla_graph_executor.h"
@@ -740,7 +741,8 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
std::string GetXLAShardingSpec(const XLATensorPtr xtensor) {
auto sharding_spec = xtensor->sharding_spec();
if (sharding_spec != nullptr) {
auto hlo_sharding = xla::HloSharding::FromProto(sharding_spec->sharding);
auto hlo_sharding =
xla::HloSharding::FromProto(sharding_spec->sharding.GetXlaOpSharding());
return hlo_sharding->ToString();
}
return std::string();
@@ -1540,7 +1542,7 @@ void InitXlaModuleBindings(py::module m) {
runtime::ComputationClient::ComputationPtr>(m, "XlaComputation");

// Define the _XLAC.OpSharding class.
PythonScope<py::class_<xla::OpSharding>>(m, "OpSharding")
PythonScope<py::class_<torch_xla::OpSharding>>(m, "OpSharding")
.def_init([](const py::list& tile_assignment,
const py::list& group_assignment,
const py::list& replication_groups, int sharding_type) {
@@ -2559,16 +2561,16 @@ void InitXlaModuleBindings(py::module m) {
}
})
.def("_xla_mark_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
[](const at::Tensor& input, torch_xla::OpSharding sharding) {
ShardingUtil::XlaMarkSharding(input, sharding);
})
.def("_xla_annotate_custom_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
[](const at::Tensor& input, torch_xla::OpSharding sharding) {
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input));
ShardingUtil::XlaAnnotateCustomSharding(xtensor, sharding);
})
.def("_mark_manual_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
[](const at::Tensor& input, torch_xla::OpSharding sharding) {
XLA_CHECK(IsNonDeviceDataIR(input))
<< "Marking any data tensors as manual is not supported";
ShardingUtil::XlaMarkSharding(input, sharding);
@@ -2588,13 +2590,14 @@ void InitXlaModuleBindings(py::module m) {
xtensor->CreateFrom(torch_xla::MakeNode<CustomSharding>(
xtensor->GetIrValue(), shard_shape,
CustomSharding::Type::kSPMDFullToShardShape));
output->SetShardingSpec(XLATensor::ShardingSpec(
xla::HloSharding::Manual().ToProto(), shard_shape));
torch_xla::OpSharding sharding(xla::HloSharding::Manual().ToProto(),
sharding_spec->sharding.GetDenormalizedTileAssignment());
output->SetShardingSpec(XLATensor::ShardingSpec(sharding, shard_shape));
return bridge::AtenFromXlaTensor(output);
})
.def(
"_spmd_shard_to_full_shape",
[](const at::Tensor& input, const xla::OpSharding& sharding,
[](const at::Tensor& input, const torch_xla::OpSharding& sharding,
const std::vector<int64_t>& output_shape,
const py::object& output_dtype) -> at::Tensor {
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input));
@@ -2628,7 +2631,7 @@ void InitXlaModuleBindings(py::module m) {
return GetXLAShardingSpec(xtensor);
})
.def("_get_xla_op_sharding",
[](const at::Tensor& input) -> std::optional<xla::OpSharding> {
[](const at::Tensor& input) -> std::optional<torch_xla::OpSharding> {
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input));
XLATensor::ShardingSpecPtr sharding_spec =
xtensor ? xtensor->sharding_spec() : nullptr;
@@ -2668,7 +2671,7 @@ void InitXlaModuleBindings(py::module m) {
// `torch_xla.runtime.local_runtime_devices()`.
"_global_tensor_from_cpu_shards",
[](const std::vector<at::Tensor>& shards,
const xla::OpSharding& sharding,
const torch_xla::OpSharding& sharding,
std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
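
Usage note (not part of the diff): call sites that previously handed the `xla::OpSharding` proto straight to XLA now unwrap the `torch_xla::OpSharding` first. A hedged sketch of that recurring pattern, with an assumed helper name, mirroring `GetXLAShardingSpec` above:

```cpp
// Assumed helper name; mirrors the unwrap pattern used throughout this PR.
#include <string>

#include "torch_xla/csrc/tensor.h"    // XLATensor::ShardingSpecPtr
#include "xla/hlo/ir/hlo_sharding.h"  // xla::HloSharding

std::string ShardingSpecToString(
    const torch_xla::XLATensor::ShardingSpecPtr& spec) {
  if (spec == nullptr) {
    return std::string();
  }
  // torch_xla::OpSharding wraps the proto, so unwrap it before handing it
  // to xla::HloSharding::FromProto().
  auto hlo_sharding =
      xla::HloSharding::FromProto(spec->sharding.GetXlaOpSharding());
  return hlo_sharding->ToString();
}
```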
14 changes: 9 additions & 5 deletions torch_xla/csrc/ir.cpp
@@ -13,6 +13,7 @@
#include "torch_xla/csrc/runtime/cache.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/torch_xla_op_sharding.h"

namespace torch_xla {
namespace {
@@ -167,12 +168,12 @@ torch::lazy::hash_t XlaNode::GetOpHash(torch::lazy::OpKind op,
return torch::lazy::HashCombine(h, hash_seed);
}

void XlaNode::SetSharding(const xla::OpSharding& sharding, size_t index) {
void XlaNode::SetSharding(const torch_xla::OpSharding& sharding, size_t index) {
if (output_shardings_.size() == 0) {
output_shardings_ =
std::vector<std::shared_ptr<xla::OpSharding>>(num_outputs(), nullptr);
output_shardings_ = std::vector<std::shared_ptr<torch_xla::OpSharding>>(
num_outputs(), nullptr);
}
output_shardings_[index] = std::make_shared<xla::OpSharding>(sharding);
output_shardings_[index] = std::make_shared<torch_xla::OpSharding>(sharding);
// TODO(JackCaoG): fix this hashing
UpdateShardingHash();
}
@@ -207,7 +208,10 @@ void XlaNode::UpdateShardingHash() {
for (size_t i = 0; i < output_shardings_.size(); i++) {
// keep the index as part of the hash
sharding_hash_ = torch::lazy::HashCombine(sharding_hash_, (uint32_t)i);
std::shared_ptr<xla::OpSharding> sharding = output_shardings_[i];
std::shared_ptr<torch_xla::OpSharding> sharding =
std::make_shared<torch_xla::OpSharding>(
output_shardings_[i]->GetXlaOpSharding(),
output_shardings_[i]->GetDenormalizedTileAssignment());
// skip the hash compute for empty sharding
if (!sharding) {
continue;
12 changes: 9 additions & 3 deletions torch_xla/csrc/ir.h
@@ -21,6 +21,7 @@
#include "absl/types/span.h"
#include "torch_xla/csrc/dynamic_shape_detector.h"
#include "torch_xla/csrc/runtime/types.h"
#include "torch_xla/csrc/torch_xla_op_sharding.h"
#include "xla/hlo/builder/xla_builder.h"

namespace torch_xla {
@@ -133,14 +134,19 @@ class XlaNode : public torch::lazy::Node {
torch::lazy::hash_t shardingHash() const { return sharding_hash_; }

// The node's outputs get assigned the same HLO sharding
const std::shared_ptr<xla::OpSharding> GetSharding(size_t index) const {
const std::shared_ptr<torch_xla::OpSharding> GetSharding(size_t index) const {
if (output_shardings_.size() == 0) {
return nullptr;
}
return output_shardings_[index];
}

void SetSharding(const xla::OpSharding& sharding, size_t index);
const std::vector<std::shared_ptr<torch_xla::OpSharding>> GetShardings()
const {
return output_shardings_;
}
Comment on lines +144 to +147

Collaborator:

I would keep GetShardings consistent with SetSharding and have its implementation be in ir.cpp.

Author:

Decided to keep GetShardings consistent with GetSharding, which is implemented in ir.h, so it is declared and defined here. Let me know if this is a major concern and I will move it to ir.cpp.


void SetSharding(const torch_xla::OpSharding& sharding, size_t index);

void ClearSharding() {
output_shardings_.clear();
@@ -180,7 +186,7 @@ class XlaNode : public torch::lazy::Node {
torch::lazy::hash_t sharding_hash_ = 0;

// Experimental sharding annotations attached to the IR node.
std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;
std::vector<std::shared_ptr<torch_xla::OpSharding>> output_shardings_;
};

inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
24 changes: 21 additions & 3 deletions torch_xla/csrc/lowering_context.cpp
@@ -19,6 +19,7 @@
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/stack_frame_index_builder.h"
#include "torch_xla/csrc/status.h"
#include "torch_xla/csrc/torch_xla_op_sharding.h"

namespace torch_xla {

@@ -133,9 +134,9 @@ xla::XlaOp LoweringContext::GetParameter(
const std::string param_name = absl::StrCat("p", param_index);
xla::XlaOp param;
if (data->HasSharding()) {
const xla::OpSharding sharding = data->GetSharding();
const xla::XlaScopedShardingAssignment scoped_sharding(builder(),
sharding);
const torch_xla::OpSharding sharding = data->GetSharding();
const xla::XlaScopedShardingAssignment scoped_sharding(
builder(), sharding.GetXlaOpSharding());
param = xla::Parameter(builder(), param_index, shape, param_name);
} else {
param = xla::Parameter(builder(), param_index, shape, param_name);
@@ -237,13 +238,30 @@ xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) {
return it->second;
}

void LoweringContext::ExtractShardingAndSetDenormalizedTileAssignments(
std::vector<std::shared_ptr<torch_xla::OpSharding>> shardings) {
for (auto sharding : shardings) {
std::vector<int64_t> denormalized_tile_assignment =
sharding->GetDenormalizedTileAssignment();
if (!denormalized_tile_assignment.empty()) {
denormalized_tile_assignments_.push_back(denormalized_tile_assignment);
}
}
}

XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node& node) {
XlaOpVector result_ops;
try {
const HloMetadataSetter meta_setter(*this, node);
const XlaNode* const casted = dynamic_cast<const XlaNode*>(&node);

result_ops = casted->Lower(this);
// save the denormalized_tile_assignment from all nodes and then use it
// during Compile
auto shardings = casted->GetShardings();
if (!shardings.empty()) {
ExtractShardingAndSetDenormalizedTileAssignments(shardings);
}
if (!casted->dynamic_dims().empty()) {
const xla::internal::XlaBuilderFriend builder_friend;
auto* const inst = builder_friend.GetInstruction(result_ops[0]);
9 changes: 9 additions & 0 deletions torch_xla/csrc/lowering_context.h
@@ -117,6 +117,14 @@ class LoweringContext : public torch::lazy::LoweringContext {
int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source,
int64_t parent_id);

void ExtractShardingAndSetDenormalizedTileAssignments(
std::vector<std::shared_ptr<torch_xla::OpSharding>>);

const std::vector<std::vector<int64_t>>& GetDenormalizedTileAssignments()
const {
return denormalized_tile_assignments_;
}

private:
struct Parameter {
xla::XlaOp param;
@@ -135,6 +143,7 @@ std::string name_;
std::string name_;

std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
std::vector<std::vector<int64_t>> denormalized_tile_assignments_;
}; // namespace torch_xla

} // namespace torch_xla
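
For orientation (this glue lives in the graph executor and is not shown in this diff): the tile assignments recorded during `LowerNode` are meant to be read back via `GetDenormalizedTileAssignments()` and handed to the new `CompileInstance` parameter introduced later in this PR. A hedged sketch under that assumption; `BuildCompileInstance` is a hypothetical helper, and the leading `CompileInstance` constructor arguments are guesses, since they sit outside the visible hunk.

```cpp
// Sketch only: assumed glue code, not part of this PR's visible diff.
#include <string>
#include <utility>
#include <vector>

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/runtime/computation_client.h"

namespace torch_xla {

runtime::ComputationClient::CompileInstance BuildCompileInstance(
    LoweringContext& lowering_ctx,
    const std::vector<const torch::lazy::Node*>& post_order,
    xla::XlaComputation computation,  // leading ctor args are assumptions
    const std::string& compilation_device,
    const std::vector<std::string>& devices, const xla::Shape* output_shape) {
  // LowerNode() records each sharded node's denormalized tile assignment via
  // ExtractShardingAndSetDenormalizedTileAssignments().
  for (const torch::lazy::Node* node : post_order) {
    lowering_ctx.LowerNode(*node);
  }

  // The new CompileInstance argument lets those assignments travel to the
  // runtime client together with the computation.
  return runtime::ComputationClient::CompileInstance(
      std::move(computation), compilation_device, devices, output_shape,
      /*parameter_is_tupled_arguments=*/false,
      /*is_sharded=*/true,
      /*allow_spmd_sharding_propagation_to_output=*/false,
      /*denormalized_tile_assignments=*/
      lowering_ctx.GetDenormalizedTileAssignments());
}

}  // namespace torch_xla
```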
3 changes: 2 additions & 1 deletion torch_xla/csrc/ops/device_data.cpp
@@ -6,6 +6,7 @@
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_xla_op_sharding.h"

namespace torch_xla {

@@ -16,7 +17,7 @@ DeviceData::DeviceData(std::shared_ptr<torch::lazy::BackendData> data)
/*num_outputs=*/1,
/*hash_seed=*/(uint32_t)101),
data_(std::move(data)) {
std::optional<xla::OpSharding> op_sharding =
std::optional<torch_xla::OpSharding> op_sharding =
torch_xla::runtime::GetComputationClientOrDie()->GetDataSharding(
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_));
if (op_sharding.has_value()) {
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -56,6 +56,7 @@ cc_library(
"//torch_xla/csrc:device",
"//torch_xla/csrc:dtype",
"//torch_xla/csrc:status",
"//torch_xla/csrc:torch_xla_op_sharding",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -123,6 +124,7 @@ cc_library(
":tf_logging",
":xla_coordinator",
"//torch_xla/csrc:status",
"//torch_xla/csrc:torch_xla_op_sharding",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
17 changes: 11 additions & 6 deletions torch_xla/csrc/runtime/computation_client.h
@@ -26,6 +26,7 @@
#include "torch_xla/csrc/runtime/types.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/status.h"
#include "torch_xla/csrc/torch_xla_op_sharding.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal_util.h"
@@ -80,7 +81,7 @@ class ComputationClient {

virtual bool HasSharding() const = 0;

virtual xla::OpSharding GetSharding() const = 0;
virtual torch_xla::OpSharding GetSharding() const = 0;

private:
std::string xla_device_;
@@ -228,6 +229,7 @@ class ComputationClient {
std::vector<std::string> devices, const xla::Shape* output_shape,
bool parameter_is_tupled_arguments = false, bool is_sharded = false,
bool allow_spmd_sharding_propagation_to_output = true,
std::vector<std::vector<int64_t>> denormalized_tile_assignments = {},
bool use_auto_spmd_partitioning = false,
std::vector<int64_t> auto_spmd_mesh_shape = {},
std::vector<int64_t> auto_spmd_mesh_ids = {}, bool eager_mode = false)
@@ -239,6 +241,7 @@
is_sharded(is_sharded),
allow_spmd_sharding_propagation_to_output(
allow_spmd_sharding_propagation_to_output),
denormalized_tile_assignments(denormalized_tile_assignments),
use_auto_spmd_partitioning(use_auto_spmd_partitioning),
auto_spmd_mesh_shape(auto_spmd_mesh_shape),
auto_spmd_mesh_ids(auto_spmd_mesh_ids),
@@ -248,6 +251,7 @@
std::string compilation_device;
std::vector<std::string> devices;
const xla::Shape* output_shape = nullptr;
std::vector<std::vector<int64_t>> denormalized_tile_assignments;
bool parameter_is_tupled_arguments;
bool is_sharded;
bool allow_spmd_sharding_propagation_to_output;
@@ -273,7 +277,7 @@
// will be populated in an asynchronous fashion.
virtual DataPtr CreateDataPlaceholder(
std::string device, xla::Shape shape,
std::optional<xla::OpSharding> sharding = std::nullopt) = 0;
std::optional<torch_xla::OpSharding> sharding = std::nullopt) = 0;

// Returns data shards. We expect this to be called on PjRtShardedData to
// retrieve the shards. If other data type is passed, it returns the input
@@ -286,11 +290,12 @@
// Returns wrapped data shards as PjRtShardedData.
virtual DataPtr WrapDataShards(absl::Span<const DataPtr> shards,
std::string device, xla::Shape shape,
xla::OpSharding sharding) = 0;
torch_xla::OpSharding sharding) = 0;

// Returns OpSharding attached to PjRtShardedData. The returned optional
// structure will be empty if there is no sharding, like with PjRtData.
virtual std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) = 0;
virtual std::optional<torch_xla::OpSharding> GetDataSharding(
DataPtr handle) = 0;

virtual std::string PjRtDeviceToString(
xla::PjRtDevice* const device) const = 0;
@@ -303,13 +308,13 @@
// input sharding spec is identical to the target `sharding` sharding spec.
virtual std::vector<DataPtr> ReshardData(
absl::Span<const DataPtr> handles,
absl::Span<const xla::OpSharding> shardings) = 0;
absl::Span<const torch_xla::OpSharding> shardings) = 0;

// Transfers local sharded tensor values to the TPU devices and returns a
// `PjRtShardedData`.
virtual DataPtr TransferShardsToDevice(
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) = 0;
std::string device, xla::Shape shape, torch_xla::OpSharding sharding) = 0;

// Copies `data->buffer` to `dst` device buffer.
virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0;