Skip to content

Commit 12e77e7

Browse files
authored
Error Handling: replace XLA_CHECK_OK() with status functions. (#9457)
1 parent 5496a36 commit 12e77e7

File tree

4 files changed

+27
-19
lines changed

4 files changed

+27
-19
lines changed

torch_xla/csrc/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ ptxla_cc_library(
132132
"//torch_xla/csrc/runtime:stablehlo_helper",
133133
"//torch_xla/csrc/runtime:xla_util",
134134
"@com_google_absl//absl/hash",
135+
"@com_google_absl//absl/log:absl_check",
135136
"@com_google_absl//absl/log:absl_log",
136137
"@com_google_absl//absl/memory",
137138
"@com_google_absl//absl/strings",

torch_xla/csrc/dl_convertor.cpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
#include <ATen/DLConvertor.h>
44

5+
#include <memory>
6+
#include <utility>
7+
#include <vector>
8+
9+
#include "absl/log/absl_check.h"
510
#include "absl/status/status.h"
611
#include "absl/types/span.h"
712
#include "torch_xla/csrc/aten_xla_bridge.h"
@@ -11,6 +16,7 @@
1116
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
1217
#include "torch_xla/csrc/runtime/runtime.h"
1318
#include "torch_xla/csrc/runtime/tf_logging.h"
19+
#include "torch_xla/csrc/status.h"
1420
#include "torch_xla/csrc/tensor.h"
1521
#include "torch_xla/csrc/tensor_util.h"
1622
#include "torch_xla/csrc/unwrap_data.h"
@@ -115,32 +121,30 @@ std::vector<int64_t> StridesForShape(xla::PrimitiveType element_type,
115121

116122
// Convert an XLA tensor to a dlPack tensor.
117123
DLManagedTensor* toDLPack(const at::Tensor& input) {
118-
XLA_CHECK(bridge::IsXlaTensor(input)) << "The input should be an XLA tensor";
124+
ABSL_CHECK(bridge::IsXlaTensor(input)) << "The input should be an XLA tensor";
119125
std::shared_ptr<runtime::ComputationClient::Data> handle =
120126
get_data_handle(input);
121-
XLA_CHECK(handle != nullptr)
127+
ABSL_CHECK(handle != nullptr)
122128
<< "Could not extract a valid data handle from the input tensor";
123129

124130
std::shared_ptr<xla::PjRtBuffer> pjrt_buffer =
125131
runtime::GetComputationClientOrDie()->GetPjRtBuffer(handle);
126-
XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";
132+
ABSL_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";
127133

128-
XLA_CHECK(!pjrt_buffer->IsTuple())
134+
ABSL_CHECK(!pjrt_buffer->IsTuple())
129135
<< "Unimplemented. BufferToDLPackManagedTensor is not "
130136
"implemented for tuple buffers.";
131-
XLA_CHECK(!pjrt_buffer->has_dynamic_dimensions())
137+
ABSL_CHECK(!pjrt_buffer->has_dynamic_dimensions())
132138
<< "Unimplemented. DynamicShape is not implemented in DLPack.";
133139

134140
auto pack = std::make_unique<DLPackTensor>();
135141
DLTensor& dt = pack->tensor.dl_tensor;
136142
{
137143
// AcquireExternalReference may block
138-
auto external_ref = pjrt_buffer->AcquireExternalReference();
139-
XLA_CHECK_OK(external_ref.status());
140-
pack->external_reference = std::move(external_ref.value());
144+
pack->external_reference =
145+
GetValueOrThrow(pjrt_buffer->AcquireExternalReference());
141146
xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture();
142-
absl::Status status = future.Await();
143-
XLA_CHECK_OK(status);
147+
MaybeThrow(future.Await());
144148
}
145149
pack->buffer_reference = pjrt_buffer;
146150

@@ -299,7 +303,7 @@ absl::StatusOr<std::vector<int64_t>> StridesToLayout(
299303
}
300304

301305
at::Tensor fromDLPack(DLManagedTensor* dlmt) {
302-
XLA_CHECK(dlmt->dl_tensor.ndim >= 0)
306+
ABSL_CHECK(dlmt->dl_tensor.ndim >= 0)
303307
<< "Number of dimensions in DLManagedTensor must be nonnegative, got "
304308
<< dlmt->dl_tensor.ndim;
305309
xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value();
@@ -325,18 +329,17 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
325329
if (dlmt->deleter) {
326330
on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
327331
}
328-
absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> pjrt_buffer =
329-
device->client()->CreateViewOfDeviceBuffer(
332+
std::unique_ptr<xla::PjRtBuffer> pjrt_buffer =
333+
GetValueOrThrow(device->client()->CreateViewOfDeviceBuffer(
330334
static_cast<char*>(dlmt->dl_tensor.data) +
331335
dlmt->dl_tensor.byte_offset,
332-
shape, *device->default_memory_space(), on_delete_callback);
333-
XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer.";
334-
XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null.";
336+
shape, *device->default_memory_space(), on_delete_callback));
337+
ABSL_CHECK(pjrt_buffer.get() != nullptr) << "pjrt buffer is null.";
335338

336339
runtime::ComputationClient::DataPtr data =
337340
runtime::PjRtComputationClient::CreateData(
338341
runtime::GetComputationClientOrDie()->PjRtDeviceToString(device),
339-
shape, std::move(pjrt_buffer.value()));
342+
shape, std::move(pjrt_buffer));
340343

341344
at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype);
342345
XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type);

torch_xla/csrc/runtime/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ cc_library(
403403
hdrs = ["tensor_source.h"],
404404
deps = [
405405
":debug_macros",
406+
"//torch_xla/csrc:status",
406407
"@torch//:headers",
407408
"@xla//xla:literal",
408409
"@xla//xla:shape_util",

torch_xla/csrc/runtime/tensor_source.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
#include <ATen/Tensor.h>
55
#include <torch/csrc/lazy/core/metrics.h>
66

7+
#include <string>
8+
#include <utility>
79
#include <vector>
810

911
#include "torch_xla/csrc/dtype.h"
1012
#include "torch_xla/csrc/runtime/debug_macros.h"
13+
#include "torch_xla/csrc/status.h"
1114
#include "xla/literal.h"
1215
#include "xla/shape.h"
1316
#include "xla/shape_util.h"
@@ -18,7 +21,7 @@ namespace runtime {
1821
// Owns a contiguous block of data with the shape and layout matching `shape()`.
1922
class TensorSource {
2023
public:
21-
TensorSource(std::string device) : device_(std::move(device)){};
24+
TensorSource(std::string device) : device_(std::move(device)) {}
2225

2326
virtual const void* data() const = 0;
2427

@@ -28,7 +31,7 @@ class TensorSource {
2831

2932
virtual std::vector<int64_t> byte_strides() const {
3033
std::vector<int64_t> byte_strides(shape().dimensions_size());
31-
XLA_CHECK_OK(
34+
MaybeThrow(
3235
xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
3336
return byte_strides;
3437
}

0 commit comments

Comments
 (0)