From 2779a6eb4eb40a4aa2e6fcf1ab9104e97d0796de Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Tue, 1 Jul 2025 10:34:06 -0300
Subject: [PATCH 1/2] Replace `GetValueOrThrow` with status propagation in `ReleaseGilAndTransferData`

---
 torch_xla/csrc/tensor_util.cpp        | 14 +++++++++-----
 torch_xla/csrc/tensor_util.h          |  2 +-
 torch_xla/csrc/xla_graph_executor.cpp |  3 ++-
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index e2cd3a025f59..a769d60aa707 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -896,7 +896,7 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape,
   return literal;
 }
 
-std::vector<xla::Literal> ReleaseGilAndTransferData(
+absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data) {
   // HACK: This method may be called outside of python (mainly in C++ tests) or
   // when the GIL is already released, so we must check both cases here. If
@@ -909,9 +909,12 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
   if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
     save = PyEval_SaveThread();
   }
-  std::vector<xla::Literal> literals =
-      GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice(
-          UnwrapXlaData(xla_data)));
+
+  XLA_ASSIGN_OR_RETURN(runtime::ComputationClient * client,
+                       runtime::GetComputationClient());
+  XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
+                       client->TransferFromDevice(UnwrapXlaData(xla_data)));
+
   if (save) {
     PyEval_RestoreThread(save);
   }
@@ -922,7 +925,8 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
     absl::Span<const at::ScalarType> dest_element_type) {
-  std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
+  std::vector<xla::Literal> literals =
+      GetValueOrThrow(ReleaseGilAndTransferData(xla_data));
   std::vector<at::Tensor> tensors(literals.size());
   absl::BlockingCounter counter(literals.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h
index 0804d3e9f781..7e1a05d974f9 100644
--- a/torch_xla/csrc/tensor_util.h
+++ b/torch_xla/csrc/tensor_util.h
@@ -28,7 +28,7 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
 // Execution and data transfer are async in PJRT, so TransferFromDevice may
 // block until `DataPtr`s are ready. Release the GIL so other threads can
 // proceed and unblock any transfers or collective computations.
-std::vector<xla::Literal> ReleaseGilAndTransferData(
+absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data);
 
 // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 65eee78bc023..0931578047e7 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -497,7 +497,8 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensors(
       async != nullptr ? async->tensors_data
                        : absl::Span<const torch::lazy::BackendDataPtr>());
 
-  std::vector<xla::Literal> literals = ReleaseGilAndTransferData(tensors_data);
+  std::vector<xla::Literal> literals =
+      GetValueOrThrow(ReleaseGilAndTransferData(tensors_data));
 
   return FetchTensors(tensors, literals,
                      async != nullptr ? &async->indices : nullptr);

From f64ae9e505aa5ca99c4e038781b1f071403e2000 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Tue, 1 Jul 2025 10:48:01 -0300
Subject: [PATCH 2/2] Extend status propagation to `XlaDataToTensors` and update callers

---
 test/cpp/test_xla_sharding.cpp          | 2 +-
 torch_xla/csrc/init_python_bindings.cpp | 2 +-
 torch_xla/csrc/tensor.cpp               | 3 ++-
 torch_xla/csrc/tensor_util.cpp          | 6 +++---
 torch_xla/csrc/tensor_util.h            | 2 +-
 torch_xla/csrc/xla_backend_impl.cpp     | 4 +++-
 6 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index 3d276f2dc263..b179c6e523cc 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -29,7 +29,7 @@ bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a,
                         torch::lazy::BackendDataPtr b,
                         at::ScalarType element_type) {
   std::vector<at::Tensor> tensors =
-      XlaDataToTensors({a, b}, {element_type, element_type});
+      GetValueOrThrow(XlaDataToTensors({a, b}, {element_type, element_type}));
   return TensorCompare(tensors[0], tensors[1]);
 }
 }  // namespace
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 5b62d95efd57..8873fb434e0f 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2712,7 +2712,7 @@ void InitXlaModuleBindings(py::module m) {
             }
 
             std::vector<at::Tensor> cpu_shards =
-                XlaDataToTensors(WrapXlaData(handles), element_types);
+                GetValueOrThrow(XlaDataToTensors(WrapXlaData(handles), element_types));
             // Populate the resulting vector of shards and device strings
             std::vector<std::vector<std::pair<at::Tensor, std::string>>> result;
             int shards_per_tensor =
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 1a1a7737ccfe..6459293a87ff 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -40,6 +40,7 @@
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
@@ -512,7 +513,7 @@ at::Tensor XLATensor::ToTensor(bool detached) {
     // The GetXlaData() call will trigger an ApplyPendingGraph() if an IR
     // XlaNode is available on the tensor.
     std::vector<at::Tensor> tensors =
-        XlaDataToTensors({GetXlaData()}, {dtype()});
+        GetValueOrThrow(XlaDataToTensors({GetXlaData()}, {dtype()}));
     tensor = std::move(tensors.front());
     if (!detached) {
       SetTensorData(tensor);
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index a769d60aa707..26c669b1e4f8 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -922,11 +922,11 @@ absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
   return literals;
 }
 
-std::vector<at::Tensor> XlaDataToTensors(
+absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
     absl::Span<const at::ScalarType> dest_element_type) {
-  std::vector<xla::Literal> literals =
-      GetValueOrThrow(ReleaseGilAndTransferData(xla_data));
+  XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
+                       ReleaseGilAndTransferData(xla_data));
   std::vector<at::Tensor> tensors(literals.size());
   absl::BlockingCounter counter(literals.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h
index 7e1a05d974f9..a0f6dea480f1 100644
--- a/torch_xla/csrc/tensor_util.h
+++ b/torch_xla/csrc/tensor_util.h
@@ -32,7 +32,7 @@ absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data);
 
 // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice
-std::vector<at::Tensor> XlaDataToTensors(
+absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
     absl::Span<const at::ScalarType> dest_element_type);
 
diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp
index bf130e1fab73..df52770b11ef 100644
--- a/torch_xla/csrc/xla_backend_impl.cpp
+++ b/torch_xla/csrc/xla_backend_impl.cpp
@@ -10,6 +10,8 @@
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
+#include "torch_xla/csrc/status.h"
+#include "torch_xla/csrc/tensor_util.h"
 
 namespace at {
 // This function is defined in the codegenerated RegisterDispatchKey.cpp file.
@@ -92,7 +94,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       const torch::lazy::BackendDataPtr data,
       std::optional<at::ScalarType> logical_scalar_type) const override {
     // TODO(JackCaoG): handle the logical_scalar_type == nullptr case
-    return XlaDataToTensors({data}, {*logical_scalar_type})[0];
+    return GetValueOrThrow(XlaDataToTensors({data}, {*logical_scalar_type}))[0];
   }
 
   std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
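
Note (not part of the patches): the error-handling idiom applied above is the usual `absl::StatusOr` propagation pattern. Fallible helpers return `absl::StatusOr<T>`; callers that themselves return a status unwrap with an assign-or-return macro (torch_xla's `XLA_ASSIGN_OR_RETURN` in the hunks above); and only boundaries that must keep a throwing signature convert a failed status into an exception with `GetValueOrThrow`. The sketch below illustrates that pattern in isolation; `LoadValue`, `Doubled`, and `DoubledOrDie` are hypothetical names, and this `GetValueOrThrow` is a simplified stand-in for torch_xla's helper, not its actual implementation.

#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical fallible helper: returns either a value or an error status.
absl::StatusOr<int> LoadValue(bool ok) {
  if (!ok) {
    return absl::InternalError("failed to load value");
  }
  return 42;
}

// Status-returning caller: propagates the error instead of throwing.
// torch_xla wraps this unwrap-or-return step in the XLA_ASSIGN_OR_RETURN macro.
absl::StatusOr<int> Doubled(bool ok) {
  absl::StatusOr<int> value = LoadValue(ok);
  if (!value.ok()) {
    return value.status();  // hand the error back to the caller
  }
  return *value * 2;
}

// Boundary helper for callers that must keep a throwing (non-StatusOr)
// signature; simplified stand-in for torch_xla's GetValueOrThrow.
template <typename T>
T GetValueOrThrow(absl::StatusOr<T> result) {
  if (!result.ok()) {
    throw std::runtime_error(std::string(result.status().message()));
  }
  return *std::move(result);
}

// Throwing call site, mirroring how the callers of XlaDataToTensors are
// updated in the second patch.
int DoubledOrDie(bool ok) { return GetValueOrThrow(Doubled(ok)); }

In the patches, `ReleaseGilAndTransferData` and `XlaDataToTensors` become the status-returning layers, while `XLATensor::ToTensor`, the Python bindings, the sharding test helper, and `XlaBackendImpl::MakeTensorFromComputationData` remain throwing boundaries and wrap the calls in `GetValueOrThrow`.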