diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index f9039bd6df73..ccae6a6e2c83 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -132,6 +132,7 @@ ptxla_cc_library(
         "//torch_xla/csrc/runtime:stablehlo_helper",
         "//torch_xla/csrc/runtime:xla_util",
         "@com_google_absl//absl/hash",
+        "@com_google_absl//absl/log:absl_check",
         "@com_google_absl//absl/log:absl_log",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp
index 3130831b3695..9adb63747dcb 100644
--- a/torch_xla/csrc/dl_convertor.cpp
+++ b/torch_xla/csrc/dl_convertor.cpp
@@ -2,6 +2,11 @@
 
 #include <ATen/DLConvertor.h>
 
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/log/absl_check.h"
 #include "absl/status/status.h"
 #include "absl/types/span.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
@@ -11,6 +16,7 @@
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/unwrap_data.h"
@@ -115,32 +121,30 @@ std::vector<int64_t> StridesForShape(xla::PrimitiveType element_type,
 
 // Convert an XLA tensor to a dlPack tensor.
 DLManagedTensor* toDLPack(const at::Tensor& input) {
-  XLA_CHECK(bridge::IsXlaTensor(input)) << "The input should be an XLA tensor";
+  ABSL_CHECK(bridge::IsXlaTensor(input)) << "The input should be an XLA tensor";
   std::shared_ptr<runtime::ComputationClient::Data> handle =
       get_data_handle(input);
-  XLA_CHECK(handle != nullptr)
+  ABSL_CHECK(handle != nullptr)
       << "Could not extract a valid data handle from the input tensor";
 
   std::shared_ptr<xla::PjRtBuffer> pjrt_buffer =
       runtime::GetComputationClientOrDie()->GetPjRtBuffer(handle);
-  XLA_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";
+  ABSL_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";
 
-  XLA_CHECK(!pjrt_buffer->IsTuple())
+  ABSL_CHECK(!pjrt_buffer->IsTuple())
       << "Unimplemented. BufferToDLPackManagedTensor is not "
          "implemented for tuple buffers.";
-  XLA_CHECK(!pjrt_buffer->has_dynamic_dimensions())
+  ABSL_CHECK(!pjrt_buffer->has_dynamic_dimensions())
       << "Unimplemented. DynamicShape is not implemented in DLPack.";
 
   auto pack = std::make_unique<DLPackTensor>();
   DLTensor& dt = pack->tensor.dl_tensor;
   {
     // AcquireExternalReference may block
-    auto external_ref = pjrt_buffer->AcquireExternalReference();
-    XLA_CHECK_OK(external_ref.status());
-    pack->external_reference = std::move(external_ref.value());
+    pack->external_reference =
+        GetValueOrThrow(pjrt_buffer->AcquireExternalReference());
     xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture();
-    absl::Status status = future.Await();
-    XLA_CHECK_OK(status);
+    MaybeThrow(future.Await());
   }
   pack->buffer_reference = pjrt_buffer;
 
@@ -299,7 +303,7 @@ absl::StatusOr<std::vector<int64_t>> StridesToLayout(
 }
 
 at::Tensor fromDLPack(DLManagedTensor* dlmt) {
-  XLA_CHECK(dlmt->dl_tensor.ndim >= 0)
+  ABSL_CHECK(dlmt->dl_tensor.ndim >= 0)
       << "Number of dimensions in DLManagedTensor must be nonnegative, got "
       << dlmt->dl_tensor.ndim;
   xla::PjRtDevice* device = DeviceForDLDevice(dlmt->dl_tensor.device).value();
@@ -325,18 +329,17 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
   if (dlmt->deleter) {
     on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
   }
-  absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> pjrt_buffer =
-      device->client()->CreateViewOfDeviceBuffer(
+  std::unique_ptr<xla::PjRtBuffer> pjrt_buffer =
+      GetValueOrThrow(device->client()->CreateViewOfDeviceBuffer(
          static_cast<char*>(dlmt->dl_tensor.data) +
              dlmt->dl_tensor.byte_offset,
-          shape, *device->default_memory_space(), on_delete_callback);
-  XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer.";
-  XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null.";
+          shape, *device->default_memory_space(), on_delete_callback));
+  ABSL_CHECK(pjrt_buffer.get() != nullptr) << "pjrt buffer is null.";
 
   runtime::ComputationClient::DataPtr data =
       runtime::PjRtComputationClient::CreateData(
           runtime::GetComputationClientOrDie()->PjRtDeviceToString(device),
-          shape, std::move(pjrt_buffer.value()));
+          shape, std::move(pjrt_buffer));
 
   at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype);
   XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type);
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 7444a1977fe3..146de69924a8 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -401,6 +401,7 @@ cc_library(
     hdrs = ["tensor_source.h"],
     deps = [
         ":debug_macros",
+        "//torch_xla/csrc:status",
         "@torch//:headers",
         "@xla//xla:literal",
         "@xla//xla:shape_util",
diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h
index 19665c5627a4..280bc4f83484 100644
--- a/torch_xla/csrc/runtime/tensor_source.h
+++ b/torch_xla/csrc/runtime/tensor_source.h
@@ -4,10 +4,13 @@
 #include <ATen/Tensor.h>
 #include <torch/csrc/lazy/core/metrics.h>
 
+#include <string>
+#include <utility>
 #include <vector>
 
 #include "torch_xla/csrc/dtype.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/status.h"
 #include "xla/literal.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
@@ -18,7 +21,7 @@ namespace runtime {
 // Owns a contiguous block of data with the shape and layout matching `shape()`.
 class TensorSource {
  public:
-  TensorSource(std::string device) : device_(std::move(device)){};
+  TensorSource(std::string device) : device_(std::move(device)) {}
 
   virtual const void* data() const = 0;
 
@@ -28,7 +31,7 @@ class TensorSource {
 
   virtual std::vector<int64_t> byte_strides() const {
     std::vector<int64_t> byte_strides(shape().dimensions_size());
-    XLA_CHECK_OK(
+    MaybeThrow(
         xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
     return byte_strides;
   }
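
Note on the pattern this diff adopts: call sites that previously unwrapped an `absl::StatusOr` by hand (`XLA_CHECK_OK(x.status())` followed by `std::move(x.value())`) now go through `GetValueOrThrow` and `MaybeThrow` from `torch_xla/csrc/status.h`, which surface recoverable runtime failures as C++ exceptions, while `ABSL_CHECK` remains for genuine invariants that should abort the process. The sketch below is a minimal, self-contained approximation of how such helpers can be written; the names mirror the call sites above, but the bodies are assumptions for illustration, not the actual `torch_xla/csrc/status.h` implementation.

```cpp
// Minimal sketch of throwing status helpers, assuming Abseil is available.
// Stand-ins for torch_xla/csrc/status.h, not the real implementation.
#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Throws if `status` is not OK; no-op otherwise.
inline void MaybeThrow(const absl::Status& status) {
  if (!status.ok()) {
    throw std::runtime_error(std::string(status.message()));
  }
}

// Unwraps a StatusOr<T>, throwing on error instead of CHECK-crashing.
template <typename T>
T GetValueOrThrow(absl::StatusOr<T>&& status_or) {
  MaybeThrow(status_or.status());
  return std::move(status_or).value();
}

// Hypothetical fallible function, standing in for a PJRT call such as
// AcquireExternalReference() or CreateViewOfDeviceBuffer().
absl::StatusOr<int> ParsePositive(int x) {
  if (x <= 0) return absl::InvalidArgumentError("x must be positive");
  return x;
}

int main() {
  // One expression replaces the old three-line
  // "call, XLA_CHECK_OK(status), move(value)" sequence.
  int v = GetValueOrThrow(ParsePositive(42));  // ok: v == 42
  MaybeThrow(absl::OkStatus());                // ok: no-op
  // GetValueOrThrow(ParsePositive(-1));       // would throw std::runtime_error
  return v == 42 ? 0 : 1;
}
```

The practical payoff at a call site like `CreateViewOfDeviceBuffer` is that a failed PJRT call propagates as an exception the bindings can translate into a regular Python error, rather than aborting the whole process the way `XLA_CHECK_OK` or `ABSL_CHECK` would.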