Commit 7c4a3cd

Author kvshbg-aws committed:
use tensors to get denormalized_tile_assignment directly instead of po_data
1 parent 0072e85 commit 7c4a3cd

File tree: 8 files changed, +47 −37 lines changed

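Summary of the change: the compile path no longer threads po_data->parameters_data through CompileInstance and re-derives the sharding from the first parameter inside the runtime clients. Instead, callers read the denormalized tile assignment straight from each tensor's sharding spec and hand the collected vectors to CompileInstance. A minimal sketch of the new collection pattern, with names taken from the diffs below:

  // `tensors` is a std::vector<XLATensorPtr>; tensors without a sharding
  // spec are simply skipped.
  std::vector<std::vector<int64_t>> denormalized_tile_assignments;
  for (const auto& tensor : tensors) {
    auto sharding_spec = tensor->sharding_spec();
    if (sharding_spec) {
      denormalized_tile_assignments.push_back(
          sharding_spec->sharding.GetDenormalizedTileAssignment());
    }
  }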

test/cpp/test_xla_sharding.cpp

Lines changed: 19 additions & 14 deletions
@@ -413,20 +413,28 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   xla::XlaComputation xla_computation =
       GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false));
 
-  std::vector<torch::lazy::BackendDataPtr> parameters_data;
-  parameters_data.push_back(
+  std::vector<XLATensorPtr> tensors{XLATensor::Create(
       torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
-          bridge::GetDefaultDevice()->toString(), std::move(shape)));
+          bridge::GetDefaultDevice()->toString(), std::move(shape)))};
+  std::vector<std::vector<int64_t>> denormalized_tile_assignments;
+  for (auto tensor : tensors) {
+    auto sharding_spec = tensor->sharding_spec();
+    if (sharding_spec) {
+      denormalized_tile_assignments.push_back(
+          sharding_spec->sharding.GetDenormalizedTileAssignment());
+    }
+  }
 
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
-  instances.push_back({std::move(xla_computation),
-                       bridge::GetDefaultDevice()->toString(),
-                       {bridge::GetDefaultDevice()->toString()},
-                       &shape,
-                       /*should_wrap_parameter=*/false,
-                       /*is_sharded=*/true,
-                       /*allow_spmd_sharding_propagation_to_output=*/true,
-                       /*parameters_data=*/parameters_data});
+  instances.push_back(
+      {std::move(xla_computation),
+       bridge::GetDefaultDevice()->toString(),
+       {bridge::GetDefaultDevice()->toString()},
+       &shape,
+       /*should_wrap_parameter=*/false,
+       /*is_sharded=*/true,
+       /*allow_spmd_sharding_propagation_to_output=*/true,
+       /*denormalized_tile_assignments=*/denormalized_tile_assignments});
 
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -437,9 +445,6 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
       "add", std::move(computations[0]->move_computation()));
 
   // Prepare output sharding propagation, expect a sharded output placeholder.
-  std::vector<XLATensorPtr> tensors{XLATensor::Create(
-      torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
-          bridge::GetDefaultDevice()->toString(), std::move(shape)))};
   std::vector<torch::lazy::BackendDataPtr> data_placeholders;
   std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
   ShardingUtil::PrepareOutputShardingPropagation(

torch_xla/csrc/ir.cpp

Lines changed: 4 additions & 3 deletions
@@ -208,9 +208,10 @@ void XlaNode::UpdateShardingHash() {
   for (size_t i = 0; i < output_shardings_.size(); i++) {
     // keep the index as part of the hash
    sharding_hash_ = torch::lazy::HashCombine(sharding_hash_, (uint32_t)i);
-    std::shared_ptr<xla::OpSharding> sharding =
-        std::make_shared<xla::OpSharding>(
-            output_shardings_[i]->GetXlaOpSharding());
+    std::shared_ptr<torch_xla::OpSharding> sharding =
+        std::make_shared<torch_xla::OpSharding>(
+            output_shardings_[i]->GetXlaOpSharding(),
+            output_shardings_[i]->GetDenormalizedTileAssignment());
     // skip the hash compute for empty sharding
     if (!sharding) {
       continue;
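Note on this hunk: UpdateShardingHash() now wraps each output sharding in torch_xla::OpSharding, passing the denormalized tile assignment alongside the xla::OpSharding proto, so the per-node sharding hash is computed from a wrapper that carries both. A minimal illustration of the two-argument construction used above (the `proto` variable is hypothetical and only stands in for GetXlaOpSharding()):

  // Two wrappers over the same proto but with different denormalized tile
  // assignments are distinct inputs to the downstream hash computation.
  std::vector<int64_t> assignment_a = {0, 1, 2, 3};
  std::vector<int64_t> assignment_b = {2, 3, 0, 1};
  torch_xla::OpSharding a(proto, assignment_a);
  torch_xla::OpSharding b(proto, assignment_b);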

torch_xla/csrc/runtime/computation_client.h

Lines changed: 3 additions & 3 deletions
@@ -228,7 +228,7 @@ class ComputationClient {
         std::vector<std::string> devices, const xla::Shape* output_shape,
         bool parameter_is_tupled_arguments = false, bool is_sharded = false,
         bool allow_spmd_sharding_propagation_to_output = true,
-        std::vector<torch::lazy::BackendDataPtr> parameters_data = {},
+        std::vector<std::vector<int64_t>> denormalized_tile_assignments = {},
         bool use_auto_spmd_partitioning = false,
         std::vector<int64_t> auto_spmd_mesh_shape = {},
         std::vector<int64_t> auto_spmd_mesh_ids = {}, bool eager_mode = false)
@@ -240,7 +240,7 @@ class ComputationClient {
           is_sharded(is_sharded),
           allow_spmd_sharding_propagation_to_output(
               allow_spmd_sharding_propagation_to_output),
-          parameters_data(parameters_data),
+          denormalized_tile_assignments(denormalized_tile_assignments),
           use_auto_spmd_partitioning(use_auto_spmd_partitioning),
           auto_spmd_mesh_shape(auto_spmd_mesh_shape),
           auto_spmd_mesh_ids(auto_spmd_mesh_ids),
@@ -250,7 +250,7 @@ class ComputationClient {
     std::string compilation_device;
     std::vector<std::string> devices;
     const xla::Shape* output_shape = nullptr;
-    std::vector<torch::lazy::BackendDataPtr> parameters_data;
+    std::vector<std::vector<int64_t>> denormalized_tile_assignments;
     bool parameter_is_tupled_arguments;
     bool is_sharded;
     bool allow_spmd_sharding_propagation_to_output;
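For context, CompileInstance now carries the raw per-tensor tile assignments instead of parameter data handles. A minimal construction sketch mirroring the test above; the `xla_computation`, `shape`, and `device` (a device string) variables are assumed to already exist, and the single assignment value is hypothetical:

  std::vector<std::vector<int64_t>> denormalized_tile_assignments = {
      {0, 1, 2, 3}};  // hypothetical flattened device order for one tensor
  std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
  instances.push_back(
      {std::move(xla_computation), device, {device}, &shape,
       /*should_wrap_parameter=*/false,
       /*is_sharded=*/true,
       /*allow_spmd_sharding_propagation_to_output=*/true,
       /*denormalized_tile_assignments=*/denormalized_tile_assignments});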

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 2 additions & 8 deletions
@@ -471,14 +471,8 @@ std::vector<ComputationClient::ComputationPtr> IfrtComputationClient::Compile(
 
   for (auto& instance : instances) {
     std::vector<int64_t> denormalized_tile_assignment;
-    if (!instance.parameters_data.empty() && instance.parameters_data[0]) {
-      auto sharding_opt = GetDataSharding(
-          std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
-              instance.parameters_data[0]));
-      if (sharding_opt.has_value()) {
-        denormalized_tile_assignment =
-            sharding_opt.value().GetDenormalizedTileAssignment();
-      }
+    if (!instance.denormalized_tile_assignments.empty()) {
+      denormalized_tile_assignment = instance.denormalized_tile_assignments[0];
     }
 
     xla::CompileOptions compile_options;

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 2 additions & 8 deletions
@@ -547,14 +547,8 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
 
   for (auto& instance : instances) {
     std::vector<int64_t> denormalized_tile_assignment;
-    if (!instance.parameters_data.empty() && instance.parameters_data[0]) {
-      auto sharding_opt = GetDataSharding(
-          std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
-              instance.parameters_data[0]));
-      if (sharding_opt.has_value()) {
-        denormalized_tile_assignment =
-            sharding_opt.value().GetDenormalizedTileAssignment();
-      }
+    if (!instance.denormalized_tile_assignments.empty()) {
+      denormalized_tile_assignment = instance.denormalized_tile_assignments[0];
     }
     xla::CompileOptions compile_options;
     if (enable_cm_in_mp) {

torch_xla/csrc/torch_xla_op_sharding.cpp

Lines changed: 5 additions & 0 deletions
@@ -99,6 +99,11 @@ OpSharding::iota_transpose_perm() const {
   return op_sharding_->iota_transpose_perm();
 }
 
+const ::google::protobuf::RepeatedField<int32_t>& OpSharding::last_tile_dims()
+    const {
+  return op_sharding_->last_tile_dims();
+}
+
 const xla::ShapeProto& OpSharding::tile_shape() const {
   return op_sharding_->tile_shape();
 }

torch_xla/csrc/torch_xla_op_sharding.h

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ class OpSharding {
   const ::google::protobuf::RepeatedField<int64_t>& tile_assignment_devices()
       const;
   const ::google::protobuf::RepeatedField<int32_t>& iota_transpose_perm() const;
+  const ::google::protobuf::RepeatedField<int32_t>& last_tile_dims() const;
   const xla::ShapeProto& tile_shape() const;
 
   // Access to underlying xla::OpSharding
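The new last_tile_dims() accessor simply forwards to the wrapped proto (see torch_xla_op_sharding.cpp above), so callers can stay on the torch_xla::OpSharding wrapper instead of unwrapping the xla::OpSharding first. A hedged usage sketch; the `sharding` object and how a caller interprets the values are assumptions, not part of this commit:

  // Iterate the trailing tile dimensions recorded in the underlying proto.
  for (int32_t dim_kind : sharding.last_tile_dims()) {
    // e.g. inspect how trailing dimensions of the tile assignment are used
  }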

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 11 additions & 1 deletion
@@ -1435,14 +1435,24 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
   xla::Shape shape = MakeShapeWithDeviceLayout(
       program_shape.result(), static_cast<XlaDeviceType>(coll.device.type()));
 
+  std::vector<std::vector<int64_t>> denormalized_tile_assignments;
+  for (auto tensor : tensors) {
+    auto sharding_spec = tensor->sharding_spec();
+    if (sharding_spec) {
+      denormalized_tile_assignments.push_back(
+          sharding_spec->sharding.GetDenormalizedTileAssignment());
+    } else {
+      TF_VLOG(5) << "no sharding spec for tensor - " << tensor;
+    }
+  }
   std::vector<runtime::ComputationClient::CompileInstance> instances;
   instances.push_back(
       {std::move(computation), coll.device.toString(),
        runtime::GetComputationClientOrDie()->GetCompilationDevices(
            coll.device.toString(), devices),
        &shape, should_wrap_parameter, is_sharded,
       /*allow_spmd_sharding_propagation_to_output=*/true,
-       /*parameters_data=*/po_data->parameters_data});
+       /*denormalized_tile_assignments=*/denormalized_tile_assignments});
   instances.front().eager_mode = UseEagerMode();
   if (use_autosharding) {
     TF_VLOG(5) << "use_auto_spmd_partitioning is set.";
