@@ -896,7 +896,7 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape,
   return literal;
 }
 
-std::vector<xla::Literal> ReleaseGilAndTransferData(
+absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data) {
   // HACK: This method may be called outside of python (mainly in C++ tests) or
   // when the GIL is already released, so we must check both cases here. If
@@ -909,20 +909,24 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
   if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
     save = PyEval_SaveThread();
   }
-  std::vector<xla::Literal> literals =
-      GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice(
-          UnwrapXlaData(xla_data)));
+
+  XLA_ASSIGN_OR_RETURN(runtime::ComputationClient* client,
+                       runtime::GetComputationClient());
+  XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
+                       client->TransferFromDevice(UnwrapXlaData(xla_data)));
+
   if (save) {
     PyEval_RestoreThread(save);
   }
 
   return literals;
 }
 
-std::vector<at::Tensor> XlaDataToTensors(
+absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
     absl::Span<const at::ScalarType> dest_element_type) {
-  std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
+  XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
+                       ReleaseGilAndTransferData(xla_data));
   std::vector<at::Tensor> tensors(literals.size());
   absl::BlockingCounter counter(literals.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
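
For callers, the practical effect of these signature changes is that transfer errors now flow through absl::StatusOr instead of being thrown at the transfer site. Below is a minimal caller-side sketch, not part of this diff: FirstTensor and FirstTensorOrThrow are hypothetical helpers invented for illustration, while XLA_ASSIGN_OR_RETURN and GetValueOrThrow are the macro and helper that already appear in the changed lines above, and torch_xla's internal headers are assumed to be available.

// Hypothetical caller-side sketch; not part of this diff.
// Inside another StatusOr-returning function, a failed transfer
// propagates as a status instead of throwing:
absl::StatusOr<at::Tensor> FirstTensor(
    absl::Span<const torch::lazy::BackendDataPtr> xla_data,
    at::ScalarType dest_type) {
  XLA_ASSIGN_OR_RETURN(std::vector<at::Tensor> tensors,
                       XlaDataToTensors(xla_data, {dest_type}));
  return tensors.at(0);
}

// At an API boundary (e.g. a Python binding), the status can still be
// converted back into an exception, matching the old throwing behavior:
at::Tensor FirstTensorOrThrow(
    absl::Span<const torch::lazy::BackendDataPtr> xla_data,
    at::ScalarType dest_type) {
  return GetValueOrThrow(FirstTensor(xla_data, dest_type));
}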