2
2
3
3
#include < ATen/DLConvertor.h>
4
4
5
+ #include < memory>
6
+ #include < utility>
7
+ #include < vector>
8
+
9
+ #include " absl/log/absl_check.h"
5
10
#include " absl/status/status.h"
6
11
#include " absl/types/span.h"
7
12
#include " torch_xla/csrc/aten_xla_bridge.h"
11
16
#include " torch_xla/csrc/runtime/pjrt_computation_client.h"
12
17
#include " torch_xla/csrc/runtime/runtime.h"
13
18
#include " torch_xla/csrc/runtime/tf_logging.h"
19
+ #include " torch_xla/csrc/status.h"
14
20
#include " torch_xla/csrc/tensor.h"
15
21
#include " torch_xla/csrc/tensor_util.h"
16
22
#include " torch_xla/csrc/unwrap_data.h"
@@ -115,32 +121,30 @@ std::vector<int64_t> StridesForShape(xla::PrimitiveType element_type,
115
121
116
122
// Convert an XLA tensor to a dlPack tensor.
117
123
DLManagedTensor* toDLPack (const at::Tensor& input) {
118
- XLA_CHECK (bridge::IsXlaTensor (input)) << " The input should be an XLA tensor" ;
124
+ ABSL_CHECK (bridge::IsXlaTensor (input)) << " The input should be an XLA tensor" ;
119
125
std::shared_ptr<runtime::ComputationClient::Data> handle =
120
126
get_data_handle (input);
121
- XLA_CHECK (handle != nullptr )
127
+ ABSL_CHECK (handle != nullptr )
122
128
<< " Could not extract a valid data handle from the input tensor" ;
123
129
124
130
std::shared_ptr<xla::PjRtBuffer> pjrt_buffer =
125
131
runtime::GetComputationClientOrDie ()->GetPjRtBuffer (handle);
126
- XLA_CHECK (pjrt_buffer != nullptr ) << " Could not get a valid pjrt_buffer" ;
132
+ ABSL_CHECK (pjrt_buffer != nullptr ) << " Could not get a valid pjrt_buffer" ;
127
133
128
- XLA_CHECK (!pjrt_buffer->IsTuple ())
134
+ ABSL_CHECK (!pjrt_buffer->IsTuple ())
129
135
<< " Unimplemented. BufferToDLPackManagedTensor is not "
130
136
" implemented for tuple buffers." ;
131
- XLA_CHECK (!pjrt_buffer->has_dynamic_dimensions ())
137
+ ABSL_CHECK (!pjrt_buffer->has_dynamic_dimensions ())
132
138
<< " Unimplemented. DynamicShape is not implemented in DLPack." ;
133
139
134
140
auto pack = std::make_unique<DLPackTensor>();
135
141
DLTensor& dt = pack->tensor .dl_tensor ;
136
142
{
137
143
// AcquireExternalReference may block
138
- auto external_ref = pjrt_buffer->AcquireExternalReference ();
139
- XLA_CHECK_OK (external_ref.status ());
140
- pack->external_reference = std::move (external_ref.value ());
144
+ pack->external_reference =
145
+ GetValueOrThrow (pjrt_buffer->AcquireExternalReference ());
141
146
xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture ();
142
- absl::Status status = future.Await ();
143
- XLA_CHECK_OK (status);
147
+ MaybeThrow (future.Await ());
144
148
}
145
149
pack->buffer_reference = pjrt_buffer;
146
150
@@ -299,7 +303,7 @@ absl::StatusOr<std::vector<int64_t>> StridesToLayout(
299
303
}
300
304
301
305
at::Tensor fromDLPack (DLManagedTensor* dlmt) {
302
- XLA_CHECK (dlmt->dl_tensor .ndim >= 0 )
306
+ ABSL_CHECK (dlmt->dl_tensor .ndim >= 0 )
303
307
<< " Number of dimensions in DLManagedTensor must be nonnegative, got "
304
308
<< dlmt->dl_tensor .ndim ;
305
309
xla::PjRtDevice* device = DeviceForDLDevice (dlmt->dl_tensor .device ).value ();
@@ -325,18 +329,17 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
325
329
if (dlmt->deleter ) {
326
330
on_delete_callback = [dlmt]() { dlmt->deleter (dlmt); };
327
331
}
328
- absl::StatusOr< std::unique_ptr<xla::PjRtBuffer> > pjrt_buffer =
329
- device->client ()->CreateViewOfDeviceBuffer (
332
+ std::unique_ptr<xla::PjRtBuffer> pjrt_buffer =
333
+ GetValueOrThrow ( device->client ()->CreateViewOfDeviceBuffer (
330
334
static_cast <char *>(dlmt->dl_tensor .data ) +
331
335
dlmt->dl_tensor .byte_offset ,
332
- shape, *device->default_memory_space (), on_delete_callback);
333
- XLA_CHECK_OK (pjrt_buffer.status ()) << " Failed to create a pjrt buffer." ;
334
- XLA_CHECK (pjrt_buffer.value () != nullptr ) << " pjrt buffer is null." ;
336
+ shape, *device->default_memory_space (), on_delete_callback));
337
+ ABSL_CHECK (pjrt_buffer.get () != nullptr ) << " pjrt buffer is null." ;
335
338
336
339
runtime::ComputationClient::DataPtr data =
337
340
runtime::PjRtComputationClient::CreateData (
338
341
runtime::GetComputationClientOrDie ()->PjRtDeviceToString (device),
339
- shape, std::move (pjrt_buffer. value () ));
342
+ shape, std::move (pjrt_buffer));
340
343
341
344
at::ScalarType tensor_type = at::toScalarType (dlmt->dl_tensor .dtype );
342
345
XLATensorPtr xla_tensor = XLATensor::Create (data, tensor_type);
0 commit comments