Skip to content

Commit b0ffc49

Browse files
authored
Error Handling: refactor ExecuteComputation and ExecuteReplicated to propagate status. (#9445)
1 parent 1ed6b46 commit b0ffc49

11 files changed

+74
-62
lines changed

test/cpp/cpp_test_util.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,11 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
295295
std::move(instances));
296296

297297
torch_xla::runtime::ComputationClient::ExecuteComputationOptions options;
298-
return torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
299-
*computations.front(), UnwrapXlaData(lowering_ctx.GetParametersData()),
300-
device.toString(), options);
298+
return GetValueOrThrow(
299+
torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
300+
*computations.front(),
301+
UnwrapXlaData(lowering_ctx.GetParametersData()), device.toString(),
302+
options));
301303
}
302304

303305
std::vector<at::Tensor> Fetch(

test/cpp/test_replication.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ void TestSingleReplication(
6565
torch_xla::runtime::ComputationClient::ExecuteComputationOptions exec_options;
6666
for (size_t i = 0; i < device_strings.size(); ++i) {
6767
auto executor = [&, i]() {
68-
results[i] =
68+
results[i] = GetValueOrThrow(
6969
torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
7070
*compiled_computations[i],
7171
{std::dynamic_pointer_cast<
7272
torch_xla::runtime::ComputationClient::Data>(
7373
tensors_data[i])},
74-
device_strings[i], exec_options);
74+
device_strings[i], exec_options));
7575
counter.DecrementCount();
7676
};
7777
torch_xla::thread::Schedule(std::move(executor));

torch_xla/csrc/runtime/computation_client.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <vector>
1717

1818
#include "absl/container/flat_hash_map.h"
19+
#include "absl/status/statusor.h"
1920
#include "absl/types/optional.h"
2021
#include "absl/types/span.h"
2122
#include "torch_xla/csrc/device.h"
@@ -346,7 +347,7 @@ class ComputationClient {
346347
// The passed device must match the common device of the arguments Data.
347348
// If options.explode_tuple is true, the output tuple will be decomposed into
348349
// its single elements.
349-
virtual std::vector<DataPtr> ExecuteComputation(
350+
virtual absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
350351
const Computation& computation, absl::Span<const DataPtr> arguments,
351352
const std::string& device,
352353
const ExecuteComputationOptions& options =
@@ -357,7 +358,7 @@ class ComputationClient {
357358
// as `devices`. If options.explode_tuple is true, the output tuples will be
358359
// decomposed into their single elements. Returns a vector of outputs, each
359360
// of which is sharded in the same order as `devices`.
360-
virtual std::vector<DataPtr> ExecuteReplicated(
361+
virtual absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
361362
const Computation& computation, absl::Span<const DataPtr> arguments,
362363
absl::Span<const std::string> devices,
363364
const ExecuteReplicatedOptions& options) = 0;

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <unordered_set>
55
#include <vector>
66

7+
#include "absl/log/absl_check.h"
78
#include "absl/strings/ascii.h"
89
#include "absl/synchronization/blocking_counter.h"
910
#include "absl/types/span.h"
@@ -416,8 +417,8 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
416417
torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
417418
execute_options;
418419

419-
auto sharded_results = ExecuteReplicated(*computations.front(), {{handle}},
420-
GetLocalDevices(), execute_options);
420+
auto sharded_results = GetValueOrThrow(ExecuteReplicated(
421+
*computations.front(), {{handle}}, GetLocalDevices(), execute_options));
421422
auto replicated_output =
422423
std::dynamic_pointer_cast<IfrtData>(sharded_results[0])
423424
->buffer->FullyReplicatedShard(
@@ -537,16 +538,16 @@ std::vector<ComputationClient::ComputationPtr> IfrtComputationClient::Compile(
537538
return computations;
538539
}
539540

540-
std::vector<ComputationClient::DataPtr>
541+
absl::StatusOr<std::vector<ComputationClient::DataPtr>>
541542
IfrtComputationClient::ExecuteComputation(
542543
const ComputationClient::Computation& computation,
543544
absl::Span<const ComputationClient::DataPtr> arguments,
544545
const std::string& device, const ExecuteComputationOptions& options) {
545546
// TODO: Implement sharded exec in IFRT
546-
XLA_ERROR() << __FUNCTION__ << " not implemented";
547+
return absl::UnimplementedError("ExecuteComputation not implemented");
547548
}
548549

549-
std::vector<ComputationClient::DataPtr>
550+
absl::StatusOr<std::vector<ComputationClient::DataPtr>>
550551
IfrtComputationClient::ExecuteReplicated(
551552
const ComputationClient::Computation& computation,
552553
const absl::Span<const ComputationClient::DataPtr> arguments,
@@ -591,11 +592,10 @@ IfrtComputationClient::ExecuteReplicated(
591592
TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for "
592593
<< spmd_device_str << " Done";
593594

594-
xla::ifrt::LoadedExecutable::ExecuteResult result =
595-
ifrt_computation.executable
596-
->Execute(absl::MakeSpan(argument_handles), execute_options,
597-
std::nullopt)
598-
.value();
595+
XLA_ASSIGN_OR_RETURN(
596+
xla::ifrt::LoadedExecutable::ExecuteResult result,
597+
ifrt_computation.executable->Execute(absl::MakeSpan(argument_handles),
598+
execute_options, std::nullopt));
599599

600600
result.status.OnReady(std::move([timed, op_tracker = std::move(op_tracker)](
601601
absl::Status status) mutable {
@@ -612,7 +612,7 @@ IfrtComputationClient::ExecuteReplicated(
612612
? *ifrt_computation.output_shardings_
613613
: std::vector(outputs.size(),
614614
xla::HloSharding::Replicate().ToProto());
615-
XLA_CHECK_EQ(output_shardings.size(), outputs.size());
615+
ABSL_CHECK_EQ(output_shardings.size(), outputs.size());
616616

617617
std::vector<ComputationClient::DataPtr> data_handles(outputs.size());
618618
{

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ class IfrtComputationClient : public ComputationClient {
7878
std::vector<ComputationPtr> Compile(
7979
std::vector<CompileInstance> instances) override;
8080

81-
std::vector<DataPtr> ExecuteComputation(
81+
absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
8282
const Computation& computation, absl::Span<const DataPtr> arguments,
8383
const std::string& device,
8484
const ExecuteComputationOptions& options) override;
8585

86-
std::vector<DataPtr> ExecuteReplicated(
86+
absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
8787
const Computation& computation, const absl::Span<const DataPtr> arguments,
8888
absl::Span<const std::string> devices,
8989
const ExecuteReplicatedOptions& options) override;

torch_xla/csrc/runtime/ifrt_computation_client_test.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,10 @@ TEST(PjRtComputationClientTest, Init) {
6464
std::make_shared<LiteralSource>(std::move(literal_y), device)};
6565

6666
// Execute the graph.
67-
std::vector<ComputationClient::DataPtr> results = client->ExecuteReplicated(
68-
*computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
69-
{device}, options);
67+
std::vector<ComputationClient::DataPtr> results =
68+
GetValueOrThrow(client->ExecuteReplicated(
69+
*computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
70+
{device}, options));
7071

7172
// Copy the output from device back to host and assert correctness.
7273
ASSERT_EQ(results.size(), 1);

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,8 @@ PjRtComputationClient::ReplicateShardedData(
387387
torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
388388
execute_options;
389389
auto sharded_results =
390-
ExecuteReplicated(*computations.front(), {sharded_data},
391-
GetLocalDevices(), execute_options);
390+
GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data},
391+
GetLocalDevices(), execute_options));
392392
XLA_CHECK(sharded_results.size() > 0)
393393
<< "empty ExecuteReplicated results returned.";
394394
XLA_CHECK(sharded_results.size() == 1)
@@ -474,8 +474,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
474474

475475
torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
476476
execute_options;
477-
auto resharded_results = ExecuteReplicated(
478-
*computation, handles, GetLocalDevices(), execute_options);
477+
auto resharded_results = GetValueOrThrow(ExecuteReplicated(
478+
*computation, handles, GetLocalDevices(), execute_options));
479479
return resharded_results;
480480
}
481481

@@ -722,7 +722,7 @@ torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() {
722722
return comp_env_hash_;
723723
}
724724

725-
std::vector<ComputationClient::DataPtr>
725+
absl::StatusOr<std::vector<ComputationClient::DataPtr>>
726726
PjRtComputationClient::ExecuteComputation(
727727
const ComputationClient::Computation& computation,
728728
absl::Span<const ComputationClient::DataPtr> arguments,
@@ -742,14 +742,14 @@ PjRtComputationClient::ExecuteComputation(
742742
dynamic_cast<const PjRtComputation&>(computation);
743743

744744
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device);
745-
XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
745+
ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
746746

747747
std::vector<xla::PjRtBuffer*> buffers;
748748
buffers.reserve(arguments.size());
749749
for (auto& argument : arguments) {
750750
const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());
751751

752-
XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
752+
ABSL_CHECK(pjrt_device == pjrt_data->buffer->device())
753753
<< "The device currently being used : " << pjrt_device->DebugString()
754754
<< " is different from the device where the buffer resides: "
755755
<< pjrt_data->buffer->device()->DebugString();
@@ -769,8 +769,9 @@ PjRtComputationClient::ExecuteComputation(
769769
<< " Done";
770770

771771
std::optional<xla::PjRtFuture<>> returned_future;
772-
std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
773-
GetValueOrThrow(pjrt_computation.executable->ExecuteSharded(
772+
XLA_ASSIGN_OR_RETURN(
773+
std::vector<std::unique_ptr<xla::PjRtBuffer>> results,
774+
pjrt_computation.executable->ExecuteSharded(
774775
buffers, pjrt_device, execute_options, returned_future));
775776

776777
returned_future->OnReady(std::move(
@@ -795,7 +796,7 @@ PjRtComputationClient::ExecuteComputation(
795796
return datas;
796797
}
797798

798-
std::vector<ComputationClient::DataPtr>
799+
absl::StatusOr<std::vector<ComputationClient::DataPtr>>
799800
PjRtComputationClient::ExecuteReplicated(
800801
const ComputationClient::Computation& computation,
801802
absl::Span<const ComputationClient::DataPtr> arguments,
@@ -829,15 +830,15 @@ PjRtComputationClient::ExecuteReplicated(
829830
for (int32_t i = start; i < end; ++i) {
830831
auto pjrt_data =
831832
std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
832-
XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size())
833+
ABSL_CHECK_EQ(pjrt_data->shards.size(), devices.size())
833834
<< "Expected one shard per device";
834835

835836
for (int32_t d = 0; d < devices.size(); d++) {
836837
std::shared_ptr<PjRtData> shard = pjrt_data->shards[d];
837838

838839
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]);
839-
XLA_CHECK_EQ(shard->buffer->device(), pjrt_device);
840-
XLA_CHECK(pjrt_device->IsAddressable())
840+
ABSL_CHECK_EQ(shard->buffer->device(), pjrt_device);
841+
ABSL_CHECK(pjrt_device->IsAddressable())
841842
<< pjrt_device->DebugString();
842843

843844
argument_handles[d][i] = shard->buffer.get();
@@ -873,8 +874,9 @@ PjRtComputationClient::ExecuteReplicated(
873874
tsl::profiler::TraceMe activity(
874875
"PjRtComputationClient::ExecuteReplicated_execute",
875876
tsl::profiler::TraceMeLevel::kInfo);
876-
results = GetValueOrThrow(pjrt_computation.executable->Execute(
877-
std::move(argument_handles), execute_options, returned_futures));
877+
XLA_ASSIGN_OR_RETURN(results, pjrt_computation.executable->Execute(
878+
std::move(argument_handles),
879+
execute_options, returned_futures));
878880

879881
(*returned_futures)[0].OnReady(
880882
std::move([timed, op_tracker = std::move(op_tracker)](
@@ -897,7 +899,7 @@ PjRtComputationClient::ExecuteReplicated(
897899
const std::vector<xla::Shape>& output_shapes =
898900
result_shape.IsTuple() ? result_shape.tuple_shapes()
899901
: std::vector<xla::Shape>({result_shape});
900-
XLA_CHECK_EQ(output_shapes.size(), num_outputs);
902+
ABSL_CHECK_EQ(output_shapes.size(), num_outputs);
901903

902904
const std::vector<xla::OpSharding>& output_shardings =
903905
pjrt_computation.output_shardings_.has_value() && num_outputs > 0
@@ -906,7 +908,7 @@ PjRtComputationClient::ExecuteReplicated(
906908
// Without an explicit sharding annotation, the output is implicitly
907909
// replicated, and we mark explicitly replicated here.
908910
std::vector<xla::OpSharding>(num_outputs);
909-
XLA_CHECK_EQ(output_shardings.size(), num_outputs);
911+
ABSL_CHECK_EQ(output_shardings.size(), num_outputs);
910912

911913
absl::BlockingCounter counter(num_outputs);
912914

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ class PjRtComputationClient : public ComputationClient {
8585

8686
ComputationPtr DeserializeComputation(const std::string& serialized) override;
8787

88-
std::vector<DataPtr> ExecuteComputation(
88+
absl::StatusOr<std::vector<DataPtr>> ExecuteComputation(
8989
const Computation& computation, absl::Span<const DataPtr> arguments,
9090
const std::string& device,
9191
const ExecuteComputationOptions& options) override;
9292

93-
std::vector<DataPtr> ExecuteReplicated(
93+
absl::StatusOr<std::vector<DataPtr>> ExecuteReplicated(
9494
const Computation& computation, absl::Span<const DataPtr> arguments,
9595
absl::Span<const std::string> devices,
9696
const ExecuteReplicatedOptions& options) override;

torch_xla/csrc/runtime/pjrt_computation_client_test.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,11 @@ TEST_F(PjRtComputationClientTest, Init) {
114114
std::make_shared<LiteralSource>(std::move(literal_y), device_)};
115115

116116
// Execute the graph.
117-
std::vector<ComputationClient::DataPtr> results = client_->ExecuteComputation(
118-
*computations[0], client_->TransferToDevice(absl::MakeConstSpan(args)),
119-
device_, options);
117+
std::vector<ComputationClient::DataPtr> results =
118+
GetValueOrThrow(client_->ExecuteComputation(
119+
*computations[0],
120+
client_->TransferToDevice(absl::MakeConstSpan(args)), device_,
121+
options));
120122

121123
// Copy the output from device back to host and assert correctness.
122124
ASSERT_EQ(results.size(), 1);

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,11 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
163163
torch::lazy::ComputationPtr computation,
164164
c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
165165
const torch::lazy::BackendDevice& device) const override {
166-
std::vector<runtime::ComputationClient::DataPtr> results =
166+
std::vector<runtime::ComputationClient::DataPtr> results = GetValueOrThrow(
167167
runtime::GetComputationClientOrDie()->ExecuteComputation(
168168
*std::dynamic_pointer_cast<runtime::ComputationClient::Computation>(
169169
computation),
170-
UnwrapXlaData(arguments), device.toString());
170+
UnwrapXlaData(arguments), device.toString()));
171171
return WrapXlaData(results);
172172
}
173173

0 commit comments

Comments
 (0)