Commit 7fc15ea

kvshbg-aws authored and committed
feat: abstraction of xla::OpSharding proto using wrapper class
1 parent 6acc8f3 commit 7fc15ea

20 files changed: +588 −182 lines (4 of the 20 changed files are shown below)
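
The wrapper itself lives in torch_xla/csrc/torch_xla_op_sharding.h/.cpp, which this excerpt does not include. The sketch below is reconstructed solely from the call sites in the diffs that follow: a constructor taking the raw proto plus an optional denormalized tile assignment, and GetXlaOpSharding() / GetDenormalizedTileAssignment() accessors. The member names, the std::vector<int64_t> element type of the optional, and the include path are assumptions, not the committed code.

// Hypothetical reconstruction of torch_xla_op_sharding.h based only on the
// call sites in this commit; names and types below are assumptions.
#include <optional>
#include <utility>
#include <vector>

#include "xla/xla_data.pb.h"  // assumed location of the xla::OpSharding proto

namespace torch_xla {

class OpSharding {
 public:
  // Wraps the raw proto together with an optional denormalized tile
  // assignment; most call sites in this commit pass std::nullopt.
  OpSharding(xla::OpSharding xla_op_sharding,
             std::optional<std::vector<int64_t>> denormalized_tile_assignment)
      : xla_op_sharding_(std::move(xla_op_sharding)),
        denormalized_tile_assignment_(
            std::move(denormalized_tile_assignment)) {}

  // Unwraps the raw proto for code that still talks to XLA directly, e.g.
  // xla::HloSharding::FromProto(sharding.GetXlaOpSharding()).
  const xla::OpSharding& GetXlaOpSharding() const { return xla_op_sharding_; }

  const std::optional<std::vector<int64_t>>& GetDenormalizedTileAssignment()
      const {
    return denormalized_tile_assignment_;
  }

 private:
  xla::OpSharding xla_op_sharding_;
  std::optional<std::vector<int64_t>> denormalized_tile_assignment_;
};

}  // namespace torch_xla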

test/cpp/test_xla_sharding.cpp

Lines changed: 65 additions & 31 deletions
@@ -49,15 +49,18 @@ TEST_F(XLAShardingTest, GetShardShape) {
       {0, 1},
       {2, 3},
   });
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
 
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For tiled sharding, each dimension should be halved
   EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));
 
-  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For replicated sharding, each dimension should be preserved
   EXPECT_EQ(shard_shape, std::vector<int64_t>({8, 7}));
@@ -73,7 +76,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       {0, 1},
       {2, 3},
   });
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
@@ -102,7 +106,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       EXPECT_EQ(slice.step(), 1);
     }
   }
-  sharding = xla::HloSharding::Replicate().ToProto();
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
@@ -125,11 +130,12 @@ TEST_F(XLAShardingTest, ShardTensor) {
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
   xla::Shape tensor_shape =
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
-  xla::OpSharding sharding =
+  xla::OpSharding xla_sharding =
       xla::HloSharding::Tile1D(
           CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
           devices.size())
           .ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -147,7 +153,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
       {0, 1, 2, 3},
       {4, 5, 6, 7},
   });
-  sharding = xla::HloSharding::Tile(mesh).ToProto();
+  xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -159,15 +166,19 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // 3D tiled, the first dim is replicated and the last halved. The last shard
   // size should be smaller in dim=1 because it's not evenly divisible.
   xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
-  sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
+  xla_sharding = xla::HloSharding::Tile(cube).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));
 
   // Replicated, all shards should be identical.
-  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
@@ -181,7 +192,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
   tensor_shape =
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
-  sharding = xla::HloSharding::Tile(tesseract).ToProto();
+  xla_sharding = xla::HloSharding::Tile(tesseract).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -205,7 +217,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
-  sharding = xla::HloSharding::Tile(hypercube).ToProto();
+  xla_sharding = xla::HloSharding::Tile(hypercube).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -233,7 +246,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {4, 5, 0, 1},
      {6, 7, 2, 3},
   });
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   // For devices at the start of the mesh, all shards should have the same
@@ -250,7 +264,9 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {0, 1, 4, 5},
       {2, 3, 6, 7},
   });
-  sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto();
+  xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
   EXPECT_EQ(shards.size(), 4);
@@ -277,7 +293,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
       {{7}},
   });
 
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
       sharding, global_shape, /*minibatch=*/true);
   auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
@@ -291,17 +308,20 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
   xla::Shape tensor_shape =
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
-  XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
-                                       {0, 1, 2, 3},
-                                       {4, 5, 6, 7},
-                                   })
-                                       .ToProto(),
-                                   tensor_shape);
-  XLATensor::ShardingSpec tiled_3d(
-      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(),
-      tensor_shape);
-  XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(),
-                                     tensor_shape);
+  auto xla_sharding = xla::HloSharding::Tile({
+                                                 {0, 1, 2, 3},
+                                                 {4, 5, 6, 7},
+                                             })
+                          .ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape);
+  xla_sharding =
+      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape);
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  XLATensor::ShardingSpec replicated(sharding, tensor_shape);
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
   EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated));
@@ -322,12 +342,17 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
   std::vector<std::string> devices(3);
   std::fill_n(devices.begin(), devices.size(),
               bridge::GetDefaultDevice()->toString());
+  auto replicate_xla_sharding = xla::HloSharding::Replicate().ToProto();
+  auto unknown_xla_sharding = xla::HloSharding::Unknown().ToProto();
+  torch_xla::OpSharding replicate_sharding(replicate_xla_sharding,
+                                           std::nullopt);
+  torch_xla::OpSharding unknown_sharding(unknown_xla_sharding, std::nullopt);
   std::vector<XLATensor::ShardingSpecPtr> shardings = {
       nullptr,
-      std::make_shared<XLATensor::ShardingSpec>(
-          xla::HloSharding::Replicate().ToProto(), tensor_shape),
-      std::make_shared<XLATensor::ShardingSpec>(
-          xla::HloSharding::Unknown().ToProto(), tensor_shape)};
+      std::make_shared<XLATensor::ShardingSpec>(replicate_sharding,
+                                                tensor_shape),
+      std::make_shared<XLATensor::ShardingSpec>(unknown_sharding,
+                                                tensor_shape)};
   std::vector<torch::lazy::BackendDataPtr> tensors_data =
       CreateTensorsData(tensors, shardings, devices);
 
@@ -386,13 +411,21 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
   xla::XlaComputation xla_computation =
       ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false));
+
+  std::vector<torch::lazy::BackendDataPtr> parameters_data;
+  parameters_data.push_back(
+      torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
+          bridge::GetDefaultDevice()->toString(), std::move(shape)));
+
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back({std::move(xla_computation),
                        bridge::GetDefaultDevice()->toString(),
                        {bridge::GetDefaultDevice()->toString()},
                        &shape,
                        /*should_wrap_parameter=*/false,
-                       /*is_sharded=*/true});
+                       /*is_sharded=*/true,
+                       /*allow_spmd_sharding_propagation_to_output=*/true,
+                       /*parameters_data=*/parameters_data});
 
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -416,11 +449,12 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   if (n_devices > 1) {
     // Tiled sharding requires multiple devices.
     EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
-        tiled, sharding_specs[0]->sharding));
+        tiled, sharding_specs[0]->sharding.GetXlaOpSharding()));
   } else {
     // Sincle device execution defaults to replication sharding.
     EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
-        xla::HloSharding::Replicate().ToProto(), sharding_specs[0]->sharding));
+        xla::HloSharding::Replicate().ToProto(),
+        sharding_specs[0]->sharding.GetXlaOpSharding()));
   }
 
   // Check if the placeholder is on a SPMD device (sharded) with no real values.
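
Every hunk in this test file applies the same mechanical substitution, so the migration pattern condenses to a few lines (variable names as in the hunks above):

// Before: the raw proto was stored directly in the sharding spec.
auto sharding = xla::HloSharding::Tile(mesh).ToProto();

// After: the proto is wrapped first; std::nullopt means no denormalized
// tile assignment accompanies this sharding.
auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
auto sharding_spec =
    std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);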

torch_xla/csrc/BUILD

Lines changed: 12 additions & 0 deletions
@@ -125,6 +125,7 @@ ptxla_cc_library(
         ":layout_manager",
         ":shape_builder",
         ":shape_helper",
+        ":torch_xla_op_sharding",
         ":version",
         "//torch_xla/csrc:hash_util",
         "//torch_xla/csrc:thread_pool",
@@ -311,6 +312,7 @@ ptxla_cc_library(
         ":device",
         ":shape_helper",
         ":unwrap_data",
+        ":torch_xla_op_sharding",
         "//torch_xla/csrc/runtime:cache",
         "//torch_xla/csrc/runtime:computation_client",
         "@com_google_absl//absl/log:absl_check",
@@ -379,3 +381,13 @@ cc_library(
         "@com_google_absl//absl/status:statusor",
     ],
 )
+
+cc_library(
+    name = "torch_xla_op_sharding",
+    srcs = ["torch_xla_op_sharding.cpp"],
+    hdrs = ["torch_xla_op_sharding.h"],
+    deps = [
+        "//torch_xla/csrc/runtime:debug_macros",
+        "@xla//xla/hlo/builder:xla_builder",
+    ],
+)

torch_xla/csrc/debug_util.cpp

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,7 @@
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
+#include "torch_xla/csrc/torch_xla_op_sharding.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
 
 namespace torch_xla {
@@ -217,7 +218,8 @@ void DebugUtil::SaveOutputShardingInfo(std::vector<XLATensorPtr>* tensors,
     auto xtensor = (*tensors)[indices[i]];
     ss << xtensor->shape().get().ToString() << " ";
     if (xtensor->sharding_spec()) {
-      ss << xla::HloSharding::FromProto(xtensor->sharding_spec()->sharding)
+      ss << xla::HloSharding::FromProto(
+                xtensor->sharding_spec()->sharding.GetXlaOpSharding())
                ->ToString();
     } else {
       ss << xla::HloSharding::FromProto(xla::HloSharding::Unknown().ToProto())
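
The debug path shows the unwrapping side of the abstraction: code that feeds a sharding to proto-based XLA helpers now calls GetXlaOpSharding() first. Condensed from the hunk above:

// sharding_spec->sharding is a torch_xla::OpSharding wrapper; unwrap the
// raw proto before calling proto-based XLA helpers.
auto hlo_sharding = xla::HloSharding::FromProto(
    sharding_spec->sharding.GetXlaOpSharding());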

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 14 additions & 10 deletions
@@ -70,6 +70,7 @@
 #include "torch_xla/csrc/tensor_methods.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
+#include "torch_xla/csrc/torch_xla_op_sharding.h"
 #include "torch_xla/csrc/version.h"
 #include "torch_xla/csrc/xla_backend_impl.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
@@ -707,7 +708,8 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
 std::string GetXLAShardingSpec(const XLATensorPtr xtensor) {
   auto sharding_spec = xtensor->sharding_spec();
   if (sharding_spec != nullptr) {
-    auto hlo_sharding = xla::HloSharding::FromProto(sharding_spec->sharding);
+    auto hlo_sharding =
+        xla::HloSharding::FromProto(sharding_spec->sharding.GetXlaOpSharding());
     return hlo_sharding->ToString();
   }
   return std::string();
@@ -1503,7 +1505,7 @@ void InitXlaModuleBindings(py::module m) {
       runtime::ComputationClient::ComputationPtr>(m, "XlaComputation");
 
   // Define the _XLAC.OpSharding class.
-  PythonScope<py::class_<xla::OpSharding>>(m, "OpSharding")
+  PythonScope<py::class_<torch_xla::OpSharding>>(m, "OpSharding")
       .def_init([](const py::list& tile_assignment,
                    const py::list& group_assignment,
                    const py::list& replication_groups, int sharding_type) {
@@ -2268,6 +2270,7 @@ void InitXlaModuleBindings(py::module m) {
          [](const std::vector<at::Tensor>& tensors, const std::string& device,
             const std::vector<std::string>& devices,
             bool emit_bytecode) -> py::bytes {
+           NoGilSection nogil;
           EmitMode mode = emit_bytecode ? EmitMode::kStableHloBytecode
                                         : EmitMode::kStableHloReadable;
           std::vector<XLATensorPtr> xtensors;
@@ -2504,16 +2507,16 @@ void InitXlaModuleBindings(py::module m) {
            }
          })
      .def("_xla_mark_sharding",
-          [](const at::Tensor& input, xla::OpSharding sharding) {
+          [](const at::Tensor& input, torch_xla::OpSharding sharding) {
             ShardingUtil::XlaMarkSharding(input, sharding);
           })
      .def("_xla_annotate_custom_sharding",
-          [](const at::Tensor& input, xla::OpSharding sharding) {
+          [](const at::Tensor& input, torch_xla::OpSharding sharding) {
             XLATensorPtr xtensor = bridge::GetXlaTensor(input);
             ShardingUtil::XlaAnnotateCustomSharding(xtensor, sharding);
           })
      .def("_mark_manual_sharding",
-          [](const at::Tensor& input, xla::OpSharding sharding) {
+          [](const at::Tensor& input, torch_xla::OpSharding sharding) {
             XLA_CHECK(IsNonDeviceDataIR(input))
                 << "Marking any data tensors as manual is not supported";
             ShardingUtil::XlaMarkSharding(input, sharding);
@@ -2533,13 +2536,14 @@ void InitXlaModuleBindings(py::module m) {
                xtensor->CreateFrom(torch_xla::MakeNode<CustomSharding>(
                    xtensor->GetIrValue(), shard_shape,
                    CustomSharding::Type::kSPMDFullToShardShape));
-           output->SetShardingSpec(XLATensor::ShardingSpec(
-               xla::HloSharding::Manual().ToProto(), shard_shape));
+           torch_xla::OpSharding sharding(xla::HloSharding::Manual().ToProto(),
+               sharding_spec->sharding.GetDenormalizedTileAssignment());
+           output->SetShardingSpec(XLATensor::ShardingSpec(sharding, shard_shape));
            return bridge::AtenFromXlaTensor(output);
          })
      .def(
          "_spmd_shard_to_full_shape",
-          [](const at::Tensor& input, const xla::OpSharding& sharding,
+          [](const at::Tensor& input, const torch_xla::OpSharding& sharding,
             const std::vector<int64_t>& output_shape,
             const py::object& output_dtype) -> at::Tensor {
            XLATensorPtr xtensor = bridge::GetXlaTensor(input);
@@ -2578,7 +2582,7 @@ void InitXlaModuleBindings(py::module m) {
            XLATensor::ShardingSpecPtr sharding_spec =
                xtensor ? xtensor->sharding_spec() : nullptr;
            if (sharding_spec != nullptr) {
-             return sharding_spec->sharding;
+             return sharding_spec->sharding.GetXlaOpSharding();
            }
            return std::nullopt;
          })
@@ -2613,7 +2617,7 @@ void InitXlaModuleBindings(py::module m) {
          // `torch_xla.runtime.local_runtime_devices()`.
          "_global_tensor_from_cpu_shards",
          [](const std::vector<at::Tensor>& shards,
-            const xla::OpSharding& sharding,
+            const torch_xla::OpSharding& sharding,
             std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
            XLA_CHECK(UseVirtualDevice())
                << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
