@@ -387,8 +387,8 @@ PjRtComputationClient::ReplicateShardedData(
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
   auto sharded_results =
-      ExecuteReplicated(*computations.front(), {sharded_data},
-                        GetLocalDevices(), execute_options);
+      GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data},
+                                        GetLocalDevices(), execute_options));
   XLA_CHECK(sharded_results.size() > 0)
       << "empty ExecuteReplicated results returned.";
   XLA_CHECK(sharded_results.size() == 1)
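Note on this hunk: ExecuteReplicated now returns absl::StatusOr (see the signature change further down), so call sites that want the previous throwing behavior unwrap the result with GetValueOrThrow. A minimal sketch of that helper's semantics, assuming <stdexcept>, <utility>, and absl/status/statusor.h are included (the real helper in torch_xla's status utilities may attach extra context):

    // Illustrative only: return the contained value, or throw on error.
    template <typename T>
    T GetValueOrThrow(absl::StatusOr<T> status_or) {
      if (!status_or.ok()) {
        throw std::runtime_error(std::string(status_or.status().message()));
      }
      return std::move(status_or).value();
    }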
@@ -474,8 +474,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(

   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
-  auto resharded_results = ExecuteReplicated(
-      *computation, handles, GetLocalDevices(), execute_options);
+  auto resharded_results = GetValueOrThrow(ExecuteReplicated(
+      *computation, handles, GetLocalDevices(), execute_options));
   return resharded_results;
 }

@@ -722,7 +722,7 @@ torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() {
   return comp_env_hash_;
 }

-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 PjRtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
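Note on this hunk: changing the return type to absl::StatusOr<std::vector<ComputationClient::DataPtr>> makes errors propagate as values rather than escaping from inside the client. A hedged sketch of the two caller patterns this diff relies on (argument lists abbreviated; client is a placeholder for a ComputationClient pointer):

    // Inside StatusOr-returning code: propagate the error to the caller.
    XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
        std::vector<ComputationClient::DataPtr> datas,
        client->ExecuteComputation(computation, arguments, device));

    // At a boundary that still throws: unwrap explicitly instead.
    auto results = GetValueOrThrow(
        client->ExecuteComputation(computation, arguments, device));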
@@ -742,14 +742,14 @@ PjRtComputationClient::ExecuteComputation(
       dynamic_cast<const PjRtComputation&>(computation);

   xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device);
-  XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
+  ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();

   std::vector<xla::PjRtBuffer*> buffers;
   buffers.reserve(arguments.size());
   for (auto& argument : arguments) {
     const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());

-    XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
+    ABSL_CHECK(pjrt_device == pjrt_data->buffer->device())
         << "The device currently being used : " << pjrt_device->DebugString()
         << " is different from the device where the buffer resides: "
         << pjrt_data->buffer->device()->DebugString();
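Note on this hunk: ABSL_CHECK and ABSL_CHECK_EQ come from Abseil's logging library and, like the XLA_CHECK macros they replace, fail fast on violated invariants and accept streamed context, so the message-building call sites stay unchanged:

    #include "absl/log/absl_check.h"  // provides ABSL_CHECK / ABSL_CHECK_EQ

    ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();

Read together with the signature change above, the apparent pattern is: hard invariant violations still crash via a CHECK, while recoverable runtime errors now flow through the absl::StatusOr return value.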
@@ -769,8 +769,9 @@ PjRtComputationClient::ExecuteComputation(
       << " Done";

   std::optional<xla::PjRtFuture<>> returned_future;
-  std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
-      GetValueOrThrow(pjrt_computation.executable->ExecuteSharded(
+  XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
+      std::vector<std::unique_ptr<xla::PjRtBuffer>> results,
+      pjrt_computation.executable->ExecuteSharded(
           buffers, pjrt_device, execute_options, returned_future));

   returned_future->OnReady(std::move(
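Note on this hunk: XLA_ASSIGN_OR_RETURN_WITH_LOCATION replaces the throwing GetValueOrThrow wrapper with early-return error propagation, which is what the new absl::StatusOr return type requires. An illustrative expansion only, assuming the real macro in torch_xla's status utilities differs in detail (unique temporary names, source-location annotation):

    #define XLA_ASSIGN_OR_RETURN_WITH_LOCATION(lhs, rexpr) \
      auto _statusor = (rexpr);                            \
      if (!_statusor.ok()) {                               \
        /* real macro records __FILE__ and __LINE__ */     \
        return _statusor.status();                         \
      }                                                    \
      lhs = std::move(_statusor).value();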
@@ -795,7 +796,7 @@ PjRtComputationClient::ExecuteComputation(
   return datas;
 }

-std::vector<ComputationClient::DataPtr>
+absl::StatusOr<std::vector<ComputationClient::DataPtr>>
 PjRtComputationClient::ExecuteReplicated(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
@@ -829,15 +830,15 @@ PjRtComputationClient::ExecuteReplicated(
   for (int32_t i = start; i < end; ++i) {
     auto pjrt_data =
         std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
-    XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size())
+    ABSL_CHECK_EQ(pjrt_data->shards.size(), devices.size())
         << "Expected one shard per device";

     for (int32_t d = 0; d < devices.size(); d++) {
       std::shared_ptr<PjRtData> shard = pjrt_data->shards[d];

       xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]);
-      XLA_CHECK_EQ(shard->buffer->device(), pjrt_device);
-      XLA_CHECK(pjrt_device->IsAddressable())
+      ABSL_CHECK_EQ(shard->buffer->device(), pjrt_device);
+      ABSL_CHECK(pjrt_device->IsAddressable())
           << pjrt_device->DebugString();

       argument_handles[d][i] = shard->buffer.get();
@@ -873,8 +874,10 @@ PjRtComputationClient::ExecuteReplicated(
     tsl::profiler::TraceMe activity(
         "PjRtComputationClient::ExecuteReplicated_execute",
         tsl::profiler::TraceMeLevel::kInfo);
-    results = GetValueOrThrow(pjrt_computation.executable->Execute(
-        std::move(argument_handles), execute_options, returned_futures));
+    XLA_ASSIGN_OR_RETURN_WITH_LOCATION(
+        results,
+        pjrt_computation.executable->Execute(
+            std::move(argument_handles), execute_options, returned_futures));

     (*returned_futures)[0].OnReady(
         std::move([timed, op_tracker = std::move(op_tracker)](
@@ -897,7 +900,7 @@ PjRtComputationClient::ExecuteReplicated(
   const std::vector<xla::Shape>& output_shapes =
       result_shape.IsTuple() ? result_shape.tuple_shapes()
                              : std::vector<xla::Shape>({result_shape});
-  XLA_CHECK_EQ(output_shapes.size(), num_outputs);
+  ABSL_CHECK_EQ(output_shapes.size(), num_outputs);

   const std::vector<xla::OpSharding>& output_shardings =
       pjrt_computation.output_shardings_.has_value() && num_outputs > 0
@@ -906,7 +909,7 @@ PjRtComputationClient::ExecuteReplicated(
           // Without an explicit sharding annotation, the output is implicitly
           // replicated, and we mark explicitly replicated here.
           std::vector<xla::OpSharding>(num_outputs);
-  XLA_CHECK_EQ(output_shardings.size(), num_outputs);
+  ABSL_CHECK_EQ(output_shardings.size(), num_outputs);

   absl::BlockingCounter counter(num_outputs);
