
Commit 30d9f62

Make ExecuteComputation and ExecuteReplicated return StatusOr<T>
Key changes:
- Updated the base `ComputationClient` interface to return `absl::StatusOr<std::vector<DataPtr>>`
- Modified the IFRT and PjRt implementations to use proper error propagation
- Replaced raw `.value()` calls with `XLA_ASSIGN_OR_RETURN_WITH_LOCATION` macros
- Updated all call sites to use `GetValueOrThrow` for exception-based error handling
1 parent 29ae4c7 commit 30d9f62

9 files changed: 69 additions & 57 deletions
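For readers outside the torch_xla tree, the sketch below illustrates the error-handling split this commit adopts, using only Abseil's absl::StatusOr. RunOnDevice and UnwrapOrThrow are hypothetical stand-ins for ExecuteComputation/ExecuteReplicated and GetValueOrThrow, and the explicit if/return check stands in for the XLA_ASSIGN_OR_RETURN_WITH_LOCATION macro used in the diff; it is an illustration of the pattern, not the real API.

// Sketch only: StatusOr-based error flow similar to what this commit adopts.
// RunOnDevice and UnwrapOrThrow are hypothetical stand-ins, not torch_xla APIs.
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Layer that can report failure: return a StatusOr instead of CHECK-failing
// or calling .value() on an intermediate result.
absl::StatusOr<std::vector<int>> RunOnDevice(const std::string& device) {
  if (device.empty()) {
    // Analogous to returning absl::UnimplementedError(...) in the IFRT client
    // instead of XLA_ERROR().
    return absl::InvalidArgumentError("no device given");
  }
  return std::vector<int>{1, 2, 3};
}

// Boundary that still wants exceptions: unwrap the StatusOr and throw on
// error, which is the role GetValueOrThrow plays at the updated call sites.
template <typename T>
T UnwrapOrThrow(absl::StatusOr<T> result) {
  if (!result.ok()) {
    throw std::runtime_error(std::string(result.status().message()));
  }
  return *std::move(result);
}

int main() {
  // Call-site shape after the change: wrap the StatusOr-returning call.
  std::vector<int> outputs = UnwrapOrThrow(RunOnDevice("cpu:0"));
  return outputs.size() == 3 ? 0 : 1;
}

The same split runs through the commit: StatusOr is propagated inside the runtime layer via the assign-or-return macro, and call sites that still rely on exceptions unwrap it with GetValueOrThrow.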

torch_xla/csrc/runtime/computation_client.h

Lines changed: 3 additions & 2 deletions
@@ -16,6 +16,7 @@
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "torch_xla/csrc/device.h"
@@ -346,7 +347,7 @@ class ComputationClient {
   // The passed device must match the common device of the arguments Data.
   // If options.explode_tuple is true, the output tuple will be decomposed into
   // its single elements.
-  virtual std::vector<DataPtr> ExecuteComputation(
+  virtual absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       const std::string& device,
       const ExecuteComputationOptions& options =
@@ -357,7 +358,7 @@ class ComputationClient {
   // as `devices`. If options.explode_tuple is true, the output tuples will be
   // decomposed into their single elements. Returns a vector of outputs, each
   // of which is sharded in the same order as `devices`.
-  virtual std::vector<DataPtr> ExecuteReplicated(
+  virtual absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
       const Computation& computation, absl::Span<const DataPtr> arguments,
       absl::Span<const std::string> devices,
       const ExecuteReplicatedOptions& options) = 0;

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 11 additions & 11 deletions
@@ -4,6 +4,7 @@
 #include <unordered_set>
 #include <vector>
 
+#include "absl/log/absl_check.h"
 #include "absl/strings/ascii.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/types/span.h"
@@ -416,8 +417,8 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
 
-  auto sharded_results = ExecuteReplicated(*computations.front(), {{handle}},
-                                           GetLocalDevices(), execute_options);
+  auto sharded_results = GetValueOrThrow(ExecuteReplicated(
+      *computations.front(), {{handle}}, GetLocalDevices(), execute_options));
   auto replicated_output =
       std::dynamic_pointer_cast<IfrtData>(sharded_results[0])
           ->buffer->FullyReplicatedShard(
@@ -537,16 +538,16 @@ std::vector<ComputationClient::ComputationPtr> IfrtComputationClient::Compile(
   return computations;
 }
 
-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 IfrtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
     const std::string& device, const ExecuteComputationOptions& options) {
   // TODO: Implement sharded exec in IFRT
-  XLA_ERROR() << __FUNCTION__ << " not implemented";
+  return absl::UnimplementedError("ExecuteComputation not implemented");
 }
 
-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 IfrtComputationClient::ExecuteReplicated(
     const ComputationClient::Computation& computation,
     const absl::Span<const ComputationClient::DataPtr> arguments,
@@ -591,11 +592,10 @@ IfrtComputationClient::ExecuteReplicated(
   TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for "
              << spmd_device_str << " Done";
 
-  xla::ifrt::LoadedExecutable::ExecuteResult result =
-      ifrt_computation.executable
-          ->Execute(absl::MakeSpan(argument_handles), execute_options,
-                    std::nullopt)
-          .value();
+  XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
+      xla::ifrt::LoadedExecutable::ExecuteResult result,
+      ifrt_computation.executable->Execute(absl::MakeSpan(argument_handles),
+                                           execute_options, std::nullopt));
 
   result.status.OnReady(std::move([timed, op_tracker = std::move(op_tracker)](
                                       absl::Status status) mutable {
@@ -612,7 +612,7 @@ IfrtComputationClient::ExecuteReplicated(
           ? *ifrt_computation.output_shardings_
           : std::vector(outputs.size(),
                         xla::HloSharding::Replicate().ToProto());
-  XLA_CHECK_EQ(output_shardings.size(), outputs.size());
+  ABSL_CHECK_EQ(output_shardings.size(), outputs.size());
 
   std::vector<ComputationClient::DataPtr> data_handles(outputs.size());
   {

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 2 additions & 2 deletions
@@ -78,12 +78,12 @@ class IfrtComputationClient : public ComputationClient {
   std::vector<ComputationPtr> Compile(
       std::vector<CompileInstance> instances) override;
 
-  std::vector<DataPtr> ExecuteComputation(
+  absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
      const Computation& computation, absl::Span<const DataPtr> arguments,
      const std::string& device,
      const ExecuteComputationOptions& options) override;
 
-  std::vector<DataPtr> ExecuteReplicated(
+  absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
      const Computation& computation, const absl::Span<const DataPtr> arguments,
      absl::Span<const std::string> devices,
      const ExecuteReplicatedOptions& options) override;

torch_xla/csrc/runtime/ifrt_computation_client_test.cpp

Lines changed: 4 additions & 3 deletions
@@ -64,9 +64,10 @@ TEST(PjRtComputationClientTest, Init) {
       std::make_shared<LiteralSource>(std::move(literal_y), device)};
 
   // Execute the graph.
-  std::vector<ComputationClient::DataPtr> results = client->ExecuteReplicated(
-      *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
-      {device}, options);
+  std::vector<ComputationClient::DataPtr> results =
+      GetValueOrThrow(client->ExecuteReplicated(
+          *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
+          {device}, options));
 
   // Copy the output from device back to host and assert correctness..
   ASSERT_EQ(results.size(), 1);

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 20 additions & 17 deletions
@@ -387,8 +387,8 @@ PjRtComputationClient::ReplicateShardedData(
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
   auto sharded_results =
-      ExecuteReplicated(*computations.front(), {sharded_data},
-                        GetLocalDevices(), execute_options);
+      GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data},
+                                        GetLocalDevices(), execute_options));
   XLA_CHECK(sharded_results.size() > 0)
       << "empty ExecuteReplicated results returned.";
   XLA_CHECK(sharded_results.size() == 1)
@@ -474,8 +474,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
 
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
-  auto resharded_results = ExecuteReplicated(
-      *computation, handles, GetLocalDevices(), execute_options);
+  auto resharded_results = GetValueOrThrow(ExecuteReplicated(
+      *computation, handles, GetLocalDevices(), execute_options));
   return resharded_results;
 }
 
@@ -722,7 +722,7 @@ torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() {
   return comp_env_hash_;
 }
 
-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 PjRtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
@@ -742,14 +742,14 @@ PjRtComputationClient::ExecuteComputation(
       dynamic_cast<const PjRtComputation&>(computation);
 
   xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device);
-  XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
+  ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
 
   std::vector<xla::PjRtBuffer*> buffers;
   buffers.reserve(arguments.size());
   for (auto& argument : arguments) {
     const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());
 
-    XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
+    ABSL_CHECK(pjrt_device == pjrt_data->buffer->device())
         << "The device currently being used : " << pjrt_device->DebugString()
         << " is different from the device where the buffer resides: "
        << pjrt_data->buffer->device()->DebugString();
@@ -769,8 +769,9 @@ PjRtComputationClient::ExecuteComputation(
              << " Done";
 
   std::optional<xla::PjRtFuture<>> returned_future;
-  std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
-      GetValueOrThrow(pjrt_computation.executable->ExecuteSharded(
+  XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
+      std::vector<std::unique_ptr<xla::PjRtBuffer>> results,
+      pjrt_computation.executable->ExecuteSharded(
           buffers, pjrt_device, execute_options, returned_future));
 
   returned_future->OnReady(std::move(
@@ -795,7 +796,7 @@ PjRtComputationClient::ExecuteComputation(
   return datas;
 }
 
-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 PjRtComputationClient::ExecuteReplicated(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
@@ -829,15 +830,15 @@ PjRtComputationClient::ExecuteReplicated(
     for (int32_t i = start; i < end; ++i) {
       auto pjrt_data =
           std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
-      XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size())
+      ABSL_CHECK_EQ(pjrt_data->shards.size(), devices.size())
          << "Expected one shard per device";
 
      for (int32_t d = 0; d < devices.size(); d++) {
        std::shared_ptr<PjRtData> shard = pjrt_data->shards[d];
 
        xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]);
-       XLA_CHECK_EQ(shard->buffer->device(), pjrt_device);
-       XLA_CHECK(pjrt_device->IsAddressable())
+       ABSL_CHECK_EQ(shard->buffer->device(), pjrt_device);
+       ABSL_CHECK(pjrt_device->IsAddressable())
           << pjrt_device->DebugString();
 
       argument_handles[d][i] = shard->buffer.get();
@@ -873,8 +874,10 @@ PjRtComputationClient::ExecuteReplicated(
     tsl::profiler::TraceMe activity(
         "PjRtComputationClient::ExecuteReplicated_execute",
         tsl::profiler::TraceMeLevel::kInfo);
-    results = GetValueOrThrow(pjrt_computation.executable->Execute(
-        std::move(argument_handles), execute_options, returned_futures));
+    XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
+        results,
+        pjrt_computation.executable->Execute(
+            std::move(argument_handles), execute_options, returned_futures));
 
     (*returned_futures)[0].OnReady(
         std::move([timed, op_tracker = std::move(op_tracker)](
@@ -897,7 +900,7 @@ PjRtComputationClient::ExecuteReplicated(
   const std::vector<xla::Shape>& output_shapes =
       result_shape.IsTuple() ? result_shape.tuple_shapes()
                              : std::vector<xla::Shape>({result_shape});
-  XLA_CHECK_EQ(output_shapes.size(), num_outputs);
+  ABSL_CHECK_EQ(output_shapes.size(), num_outputs);
 
   const std::vector<xla::OpSharding>& output_shardings =
       pjrt_computation.output_shardings_.has_value() && num_outputs > 0
@@ -906,7 +909,7 @@ PjRtComputationClient::ExecuteReplicated(
           // Without an explicit sharding annotation, the output is implicitly
           // replicated, and we mark explicitly replicated here.
           std::vector<xla::OpSharding>(num_outputs);
-  XLA_CHECK_EQ(output_shardings.size(), num_outputs);
+  ABSL_CHECK_EQ(output_shardings.size(), num_outputs);
 
   absl::BlockingCounter counter(num_outputs);
 
torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ class PjRtComputationClient : public ComputationClient {
8585

8686
ComputationPtr DeserializeComputation(const std::string& serialized) override;
8787

88-
std::vector<DataPtr> ExecuteComputation(
88+
absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
8989
const Computation& computation, absl::Span<const DataPtr> arguments,
9090
const std::string& device,
9191
const ExecuteComputationOptions& options) override;
9292

93-
std::vector<DataPtr> ExecuteReplicated(
93+
absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
9494
const Computation& computation, absl::Span<const DataPtr> arguments,
9595
absl::Span<const std::string> devices,
9696
const ExecuteReplicatedOptions& options) override;

torch_xla/csrc/runtime/pjrt_computation_client_test.cpp

Lines changed: 5 additions & 3 deletions
@@ -114,9 +114,11 @@ TEST_F(PjRtComputationClientTest, Init) {
       std::make_shared<LiteralSource>(std::move(literal_y), device_)};
 
   // Execute the graph.
-  std::vector<ComputationClient::DataPtr> results = client_->ExecuteComputation(
-      *computations[0], client_->TransferToDevice(absl::MakeConstSpan(args)),
-      device_, options);
+  std::vector<ComputationClient::DataPtr> results =
+      GetValueOrThrow(client_->ExecuteComputation(
+          *computations[0],
+          client_->TransferToDevice(absl::MakeConstSpan(args)), device_,
+          options));
 
   // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 3 additions & 2 deletions
@@ -10,6 +10,7 @@
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
+#include "torch_xla/csrc/status.h"
 
 namespace at {
 // This function is defined in the codegenerated RegisterDispatchKey.cpp file.
@@ -161,11 +162,11 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       torch::lazy::ComputationPtr computation,
       c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
       const torch::lazy::BackendDevice& device) const override {
-    std::vector<runtime::ComputationClient::DataPtr> results =
+    std::vector<runtime::ComputationClient::DataPtr> results = GetValueOrThrow(
        runtime::GetComputationClientOrDie()->ExecuteComputation(
            *std::dynamic_pointer_cast<runtime::ComputationClient::Computation>(
                computation),
-           UnwrapXlaData(arguments), device.toString());
+           UnwrapXlaData(arguments), device.toString()));
    return WrapXlaData(results);
  }
 

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 19 additions & 15 deletions
@@ -844,10 +844,11 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
     // tensor results. Both sharded and unsharded results should be
     // "Assign"ed to the corresponding data placeholders.
     std::vector<runtime::ComputationClient::DataPtr> outputs =
-        runtime::GetComputationClientOrDie()->ExecuteReplicated(
-            *async->cached_computation->computation,
-            UnwrapXlaData(async->parameters_data), devices,
-            execute_options);
+        GetValueOrThrow(
+            runtime::GetComputationClientOrDie()->ExecuteReplicated(
+                *async->cached_computation->computation,
+                UnwrapXlaData(async->parameters_data), devices,
+                execute_options));
     results = WrapXlaData(outputs);
     TF_VLOG(3) << "Executing Dynamo IR sharded graph hash "
                << torch::lazy::HashToString(hash) << " on devices "
@@ -942,8 +943,8 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::ExecuteStablehlo(
   }
 
   std::vector<runtime::ComputationClient::DataPtr> result_data =
-      runtime::GetComputationClientOrDie()->ExecuteComputation(
-          *computations[0], UnwrapXlaData(arguments), device.toString());
+      GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteComputation(
+          *computations[0], UnwrapXlaData(arguments), device.toString()));
 
   return WrapXlaData(result_data);
 }
@@ -1119,10 +1120,11 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
       // tensor results. Both sharded and unsharded results should be
       // "Assign"ed to the corresponding data placeholders.
       std::vector<runtime::ComputationClient::DataPtr> outputs =
-          runtime::GetComputationClientOrDie()->ExecuteReplicated(
-              *async->cached_computation->computation,
-              UnwrapXlaData(async->parameters_data), devices,
-              execute_options);
+          GetValueOrThrow(
+              runtime::GetComputationClientOrDie()->ExecuteReplicated(
+                  *async->cached_computation->computation,
+                  UnwrapXlaData(async->parameters_data), devices,
+                  execute_options));
       results = WrapXlaData(outputs);
       TORCH_LAZY_COUNTER("ExecuteReplicated", 1);
       TF_VLOG(3) << "Executing IR graph hash "
@@ -1134,11 +1136,13 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
                  << torch::lazy::HashToString(hash) << " on device "
                  << async->device << " ...";
       std::vector<runtime::ComputationClient::DataPtr> outputs =
-          runtime::GetComputationClientOrDie()->ExecuteComputation(
-              *async->cached_computation->computation,
-              UnwrapXlaData(async->parameters_data), async->device.toString(),
-              {/*explode_tuple=*/true,
-               /*eager_mode=*/use_eager_mode});
+          GetValueOrThrow(
+              runtime::GetComputationClientOrDie()->ExecuteComputation(
+                  *async->cached_computation->computation,
+                  UnwrapXlaData(async->parameters_data),
+                  async->device.toString(),
+                  {/*explode_tuple=*/true,
+                   /*eager_mode=*/use_eager_mode}));
       results = WrapXlaData(outputs);
       TORCH_LAZY_COUNTER("ExecuteComputation", 1);
       TF_VLOG(3) << "Executing IR graph hash "