From 75ecdf36231f7c31a3e23bbbdc9d1203eab71e1b Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Sat, 28 Jun 2025 13:04:58 -0300
Subject: [PATCH 1/7] Propagate status on OOM crashes and exceptions.

---
 test/cpp/cpp_test_util.cpp | 5 ++-
 test/cpp/test_replication.cpp | 5 ++-
 test/test_operations.py | 8 +++++
 torch_xla/csrc/init_python_bindings.cpp | 4 +--
 torch_xla/csrc/runtime/BUILD | 1 +
 torch_xla/csrc/runtime/computation_client.h | 2 +-
 .../csrc/runtime/ifrt_computation_client.cpp | 10 +++---
 .../csrc/runtime/ifrt_computation_client.h | 2 +-
 .../runtime/ifrt_computation_client_test.cpp | 2 +-
 .../csrc/runtime/pjrt_computation_client.cpp | 35 +++++++++----------
 .../csrc/runtime/pjrt_computation_client.h | 2 +-
 .../runtime/pjrt_computation_client_test.cpp | 2 +-
 torch_xla/csrc/tensor_util.cpp | 5 +--
 13 files changed, 44 insertions(+), 39 deletions(-)

diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp
index 03efa4207191..afe573101ebc 100644
--- a/test/cpp/cpp_test_util.cpp
+++ b/test/cpp/cpp_test_util.cpp
@@ -303,9 +303,8 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
 
 std::vector<at::Tensor> Fetch(
     absl::Span<const torch_xla::runtime::ComputationClient::DataPtr> device_data) {
-  std::vector<xla::Literal> literals =
-      torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
-          device_data);
+  std::vector<xla::Literal> literals = GetValueOrThrow(
+      runtime::GetComputationClientOrDie()->TransferFromDevice(device_data));
   std::vector<at::Tensor> tensors;
   for (auto& literal : literals) {
     tensors.push_back(MakeTensorFromXlaLiteral(
diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp
index 88175c2fdbb7..b565dc44cd08 100644
--- a/test/cpp/test_replication.cpp
+++ b/test/cpp/test_replication.cpp
@@ -79,9 +79,8 @@ void TestSingleReplication(
   counter.Wait();
 
   for (size_t i = 0; i < results.size(); ++i) {
-    std::vector<xla::Literal> literals =
-        torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
-            results[i]);
+    std::vector<xla::Literal> literals = GetValueOrThrow(
+        runtime::GetComputationClientOrDie()->TransferFromDevice(results[i]));
     ASSERT_EQ(literals.size(), 1);
 
     // The result must be the original tensor value, multiplied by the number of
diff --git a/test/test_operations.py b/test/test_operations.py
index f037ad4b8cb2..db9246c3075d 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2458,6 +2458,14 @@ def test_add_broadcast_error(self):
       torch.add(a, b)
       torch_xla.sync()
 
+  def test_construct_large_tensor_raises_error(self):
+    a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
+
+    # OOM is raised when we try to bring data from the device.
+    with self.assertRaisesRegex(RuntimeError, r"Out of memory allocating \d* bytes"):
+      b = a.sum()
+      b.cpu()
+
 
 class MNISTComparator(nn.Module):
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index ce55969d693b..5b62d95efd57 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1229,9 +1229,9 @@ class PyLoweringContext {
         lowering_ctx.GetParametersData();
 
     // Fetch this parameter data
-    std::vector<xla::Literal> literals =
+    std::vector<xla::Literal> literals = GetValueOrThrow(
         runtime::GetComputationClientOrDie()->TransferFromDevice(
-            UnwrapXlaData(device_data));
+            UnwrapXlaData(device_data)));
 
     // Create a mapping from parameter id to the tensor data
     std::unordered_map<int64_t, at::Tensor> results;
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index d98329718906..c4760783f4d0 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -123,6 +123,7 @@ cc_library(
         ":tf_logging",
         ":xla_coordinator",
         "//torch_xla/csrc:status",
+        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index c7603c8932af..c2f9389a4a0a 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -318,7 +318,7 @@ class ComputationClient {
   // Note: `TransferFromDevice` call will block until the `DataPtrs` are ready
   // if they were created by `TransferToDevice` or `Execute*`. Calling this from
   // python while holding the GIL can cause deadlocks!
-  virtual std::vector<xla::Literal> TransferFromDevice(
+  virtual absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
       absl::Span<const DataPtr> handles) = 0;
 
   virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0;
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp
index a463f79a226f..f5a6af1b267c 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp
@@ -436,8 +436,8 @@ std::shared_ptr<xla::PjRtBuffer> IfrtComputationClient::GetPjRtBuffer(
   XLA_ERROR() << __FUNCTION__ << " not implemented";
 }
 
-std::vector<xla::Literal> IfrtComputationClient::TransferFromDevice(
-    absl::Span<const DataPtr> handles) {
+absl::StatusOr<std::vector<xla::Literal>>
+IfrtComputationClient::TransferFromDevice(absl::Span<const DataPtr> handles) {
   metrics::TimedSection timed(TransferFromDeviceMetric());
   tsl::profiler::TraceMe activity("IfrtComputationClient::TransferFromDevice",
                                   tsl::profiler::TraceMeLevel::kInfo);
@@ -455,9 +455,9 @@ std::vector<xla::Literal> IfrtComputationClient::TransferFromDevice(
     auto& literal = literals.emplace_back(
         xla::ShapeUtil::DeviceShapeToHostShape(ifrt_data->shape()));
     std::vector<int64_t> byte_strides(literal.shape().dimensions_size());
-    XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(),
-                                             absl::MakeSpan(byte_strides)));
-    XLA_CHECK_OK(
+    XLA_RETURN_IF_ERROR(xla::ShapeUtil::ByteStrides(
+        literal.shape(), absl::MakeSpan(byte_strides)));
+    XLA_RETURN_IF_ERROR(
         replicated_array
             ->CopyToHostBuffer(literal.untyped_data(), byte_strides,
                                xla::ifrt::ArrayCopySemantics::kAlwaysCopy)
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index 9c21d7a8d7fb..46b6343dc10a 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -62,7 +62,7 @@ class IfrtComputationClient : public ComputationClient {
     XLA_ERROR() << __FUNCTION__ << " not implemented";
   }
 
-  std::vector<xla::Literal> TransferFromDevice(
+  absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
       absl::Span<const DataPtr> handles) override;
 
   std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override;
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp
index 7a4741fc1bc5..eb39f9b2e23f 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp
+++ b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp
@@ -70,7 +70,7 @@ TEST(PjRtComputationClientTest, Init) {
 
   // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);
-  auto result_literals = client->TransferFromDevice(results);
+  auto result_literals = GetValueOrThrow(client->TransferFromDevice(results));
   ASSERT_THAT(result_literals, ::testing::SizeIs(1));
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(
       xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
index 8239da35846d..2b9d6bf5edfc 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
@@ -4,6 +4,7 @@
 #include
 #include
 
+#include "absl/log/absl_check.h"
 #include "absl/strings/ascii.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/types/span.h"
@@ -508,8 +509,8 @@ std::shared_ptr<xla::PjRtBuffer> PjRtComputationClient::GetPjRtBuffer(
   }
 }
 
-std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
-    absl::Span<const DataPtr> handles) {
+absl::StatusOr<std::vector<xla::Literal>>
+PjRtComputationClient::TransferFromDevice(absl::Span<const DataPtr> handles) {
   metrics::TimedSection timed(TransferFromDeviceMetric());
   tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice",
                                   tsl::profiler::TraceMeLevel::kInfo);
@@ -522,21 +523,21 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
     // Use XLA replication to reassemble the sharded data. If input handle
     // is not sharded, then it is a no-op.
     std::shared_ptr<PjRtData> pjrt_data = ReplicateShardedData(handle);
-    XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__;
-    XLA_CHECK(pjrt_data->buffer != nullptr)
+    ABSL_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__;
+    ABSL_CHECK(pjrt_data->buffer != nullptr)
         << "PjRt buffer is null in " << __FUNCTION__;
 
-    xla::Literal& literal =
-        literals.emplace_back(host_output_shape(pjrt_data->buffer.get()));
+    // Constructing a literal too large will make the whole program crash.
+    // Instead, we pass allocate_arrays=false, which allows this kind of
+    // error to be handled in the `Await()` call below.
+    xla::Literal& literal = literals.emplace_back(
+        xla::Literal(host_output_shape(pjrt_data->buffer.get()),
+                     /* allocate_arrays= */ false));
     futures.push_back(pjrt_data->buffer->ToLiteral(&literal));
     total_size += literal.size_bytes();
   }
 
-  for (auto& future : futures) {
-    absl::Status status = future.Await();
-    XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in"
-                         << __FUNCTION__;
-  }
+  XLA_RETURN_IF_ERROR(xla::JoinFutures(futures).Await());
 
   InboundDataMetric()->AddSample(total_size);
   return literals;
@@ -773,10 +774,8 @@ PjRtComputationClient::ExecuteComputation(
   std::optional<xla::PjRtFuture<>> returned_future;
   std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
-      pjrt_computation.executable
-          ->ExecuteSharded(buffers, pjrt_device, execute_options,
-                           returned_future)
-          .value();
+      GetValueOrThrow(pjrt_computation.executable->ExecuteSharded(
+          buffers, pjrt_device, execute_options, returned_future));
 
   returned_future->OnReady(std::move(
       [timed, op_tracker = std::move(op_tracker)](absl::Status unused) mutable {
@@ -878,10 +877,8 @@ PjRtComputationClient::ExecuteReplicated(
     tsl::profiler::TraceMe activity(
         "PjRtComputationClient::ExecuteReplicated_execute",
         tsl::profiler::TraceMeLevel::kInfo);
-    results = pjrt_computation.executable
-                  ->Execute(std::move(argument_handles), execute_options,
-                            returned_futures)
-                  .value();
+    results = GetValueOrThrow(pjrt_computation.executable->Execute(
+        std::move(argument_handles), execute_options, returned_futures));
 
     (*returned_futures)[0].OnReady(
         std::move([timed, op_tracker = std::move(op_tracker)](
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index b7c61e2ec74c..3a6b4478f722 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -65,7 +65,7 @@ class PjRtComputationClient : public ComputationClient {
       absl::Span<const DataPtr> handles,
       absl::Span<const xla::OpSharding> shardings) override;
 
-  std::vector<xla::Literal> TransferFromDevice(
+  absl::StatusOr<std::vector<xla::Literal>> TransferFromDevice(
      absl::Span<const DataPtr> handles) override;
 
   std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp
index 3398e61a2782..0fe2b2a70fcb 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp
+++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp
@@ -120,7 +120,7 @@ TEST_F(PjRtComputationClientTest, Init) {
 
   // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);
-  auto result_literals = client_->TransferFromDevice(results);
+  auto result_literals = GetValueOrThrow(client_->TransferFromDevice(results));
   ASSERT_THAT(result_literals, ::testing::SizeIs(1));
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(
       xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index 0a7f184cda77..e2cd3a025f59 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -24,6 +24,7 @@
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "torch_xla/csrc/runtime/util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/thread_pool.h"
 #include "torch_xla/csrc/torch_util.h"
 #include "torch_xla/csrc/xla_backend_impl.h"
@@ -909,8 +910,8 @@ std::vector<at::Tensor> ReleaseGilAndTransferData(
     save = PyEval_SaveThread();
   }
   std::vector<xla::Literal> literals =
-      runtime::GetComputationClientOrDie()->TransferFromDevice(
-          UnwrapXlaData(xla_data));
+      GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice(
+          UnwrapXlaData(xla_data)));
   if (save) {
     PyEval_RestoreThread(save);
   }

From 68dc43be538320d542e7d9768f5bb992d426f5d2 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Thu, 17 Jul 2025 20:23:28 -0300
Subject: [PATCH 2/7] Fix lint error.

---
 test/test_operations.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index db9246c3075d..b6540e362bd7 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2462,7 +2462,8 @@ def test_construct_large_tensor_raises_error(self):
     a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
 
     # OOM is raised when we try to bring data from the device.
-    with self.assertRaisesRegex(RuntimeError, r"Out of memory allocating \d* bytes"):
+    with self.assertRaisesRegex(RuntimeError,
+                                r"Out of memory allocating \d* bytes"):
       b = a.sum()
       b.cpu()
 

From d45c852865d5d885fa82082530fd0e601fed9a0b Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Sat, 19 Jul 2025 08:55:56 -0300
Subject: [PATCH 3/7] Fix segfault.

---
 torch_xla/csrc/runtime/pjrt_computation_client.cpp | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
index 2b9d6bf5edfc..dd4950d87f5e 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
@@ -527,12 +527,8 @@ PjRtComputationClient::TransferFromDevice(absl::Span<const DataPtr> handles) {
     ABSL_CHECK(pjrt_data->buffer != nullptr)
         << "PjRt buffer is null in " << __FUNCTION__;
 
-    // Constructing a literal too large will make the whole program crash.
-    // Instead, we pass allocate_arrays=false, which allows this kind of
-    // error to be handled in the `Await()` call below.
     xla::Literal& literal = literals.emplace_back(
-        xla::Literal(host_output_shape(pjrt_data->buffer.get()),
-                     /* allocate_arrays= */ false));
+        xla::Literal(host_output_shape(pjrt_data->buffer.get())));
 
     futures.push_back(pjrt_data->buffer->ToLiteral(&literal));
     total_size += literal.size_bytes();

From 990c79694d174d69dd9fbfc3b39fa39f437c474d Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Sat, 19 Jul 2025 09:38:36 -0300
Subject: [PATCH 4/7] Fix test.
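
When eager mode is enabled, the OOM fires as soon as the large tensor is
constructed, not at the later transfer back to the host, so the whole test
body has to live inside the `assertRaisesRegex` block for the test to pass
in both modes. A minimal sketch of the two failure points (the eager-mode
toggle below is an assumption for illustration; its exact spelling may vary
across releases):

    import torch
    import torch_xla

    torch_xla.experimental.eager_mode(True)  # assumed experimental toggle
    try:
        # Eager: the allocation is attempted immediately, so OOM is raised here.
        a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
        # Lazy: construction is deferred; OOM only surfaces at the transfer.
        a.sum().cpu()
    except RuntimeError as e:
        assert "Out of memory allocating" in str(e)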
---
 test/test_operations.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index b6540e362bd7..fb4f654c3e83 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2459,12 +2459,12 @@ def test_add_broadcast_error(self):
       torch_xla.sync()
 
   def test_construct_large_tensor_raises_error(self):
-    a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
-
-    # OOM is raised when we try to bring data from the device.
     with self.assertRaisesRegex(RuntimeError,
                                 r"Out of memory allocating \d* bytes"):
+      # When eager-mode is enabled, OOM is triggered here.
+      a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
       b = a.sum()
+      # OOM is raised when we try to bring data from the device.
       b.cpu()
 

From c1c32a5b2b011e8721a9e3609fab652901030915 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Mon, 21 Jul 2025 10:12:29 -0300
Subject: [PATCH 5/7] Address review.

---
 test/test_operations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index fb4f654c3e83..9639996efb0d 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2460,7 +2460,7 @@ def test_add_broadcast_error(self):
 
   def test_construct_large_tensor_raises_error(self):
     with self.assertRaisesRegex(RuntimeError,
-                                r"Out of memory allocating \d* bytes"):
+                                r"Out of memory allocating \d+ bytes"):
       # When eager-mode is enabled, OOM is triggered here.
       a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
       b = a.sum()

From 95d509ba05b43d99d717266aa00bc1a04f4c0b83 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Tue, 22 Jul 2025 12:06:17 -0300
Subject: [PATCH 6/7] Run test only on CPU.

---
 test/test_operations.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/test/test_operations.py b/test/test_operations.py
index 9639996efb0d..cc22275d7694 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -88,6 +88,10 @@ def skipIfFunctionalizationDisabled(reason):
   return _skipIfFunctionalization(value=True, reason=reason)
 
 
+def onlyOnCPU(fn):
+  accelerator = os.environ.get("PJRT_DEVICE").lower()
+  return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn)
+
 def onlyOnCUDA(fn):
   accelerator = os.environ.get("PJRT_DEVICE").lower()
   return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)
@@ -2458,6 +2462,7 @@ def test_add_broadcast_error(self):
       torch.add(a, b)
       torch_xla.sync()
 
+  @onlyOnCPU
   def test_construct_large_tensor_raises_error(self):
     with self.assertRaisesRegex(RuntimeError,
                                 r"Out of memory allocating \d+ bytes"):

From 103cd0fc7b70f0747acf4d02d3b320fae04866ae Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Wed, 23 Jul 2025 14:01:18 -0300
Subject: [PATCH 7/7] Fix lint issues.

---
 test/test_operations.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/test_operations.py b/test/test_operations.py
index cc22275d7694..68aa0b6c2c82 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -92,6 +92,7 @@ def onlyOnCPU(fn):
   accelerator = os.environ.get("PJRT_DEVICE").lower()
   return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn)
 
+
 def onlyOnCUDA(fn):
   accelerator = os.environ.get("PJRT_DEVICE").lower()
   return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)
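
After this series, a device-side OOM during `TransferFromDevice` surfaces as a
catchable Python `RuntimeError` instead of aborting the process. A minimal
sketch of the user-visible behavior, assuming a device whose memory is small
enough for the allocation to fail (illustrative only, not part of the patches):

    import torch
    import torch_xla

    try:
        a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
        a.sum().cpu()  # fetching the result from the device reports the OOM
    except RuntimeError as e:
        print(e)  # "Out of memory allocating ... bytes"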