Skip to content

Commit bfeb70a

Browse files
committed
Extend status propagation to XlaDataToTensors and update callers
1 parent 687356b commit bfeb70a

File tree

6 files changed

+11
-8
lines changed

6 files changed

+11
-8
lines changed

test/cpp/test_xla_sharding.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a,
2929
torch::lazy::BackendDataPtr b,
3030
at::ScalarType element_type) {
3131
std::vector<at::Tensor> tensors =
32-
XlaDataToTensors({a, b}, {element_type, element_type});
32+
GetValueOrThrow(XlaDataToTensors({a, b}, {element_type, element_type}));
3333
return TensorCompare(tensors[0], tensors[1]);
3434
}
3535
} // namespace

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2712,7 +2712,7 @@ void InitXlaModuleBindings(py::module m) {
27122712
}
27132713

27142714
std::vector<at::Tensor> cpu_shards =
2715-
XlaDataToTensors(WrapXlaData(handles), element_types);
2715+
GetValueOrThrow(XlaDataToTensors(WrapXlaData(handles), element_types));
27162716
// Populate the resulting vector of shards and device strings
27172717
std::vector<std::vector<std::pair<at::Tensor, std::string>>> result;
27182718
int shards_per_tensor =

torch_xla/csrc/tensor.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
4141
#include "torch_xla/csrc/runtime/sys_util.h"
4242
#include "torch_xla/csrc/runtime/xla_util.h"
43+
#include "torch_xla/csrc/status.h"
4344
#include "torch_xla/csrc/tensor_util.h"
4445
#include "torch_xla/csrc/torch_util.h"
4546
#include "torch_xla/csrc/xla_graph_executor.h"
@@ -512,7 +513,7 @@ at::Tensor XLATensor::ToTensor(bool detached) {
512513
// The GetXlaData() call will trigger an ApplyPendingGraph() if an IR
513514
// XlaNode is available on the tensor.
514515
std::vector<at::Tensor> tensors =
515-
XlaDataToTensors({GetXlaData()}, {dtype()});
516+
GetValueOrThrow(XlaDataToTensors({GetXlaData()}, {dtype()}));
516517
tensor = std::move(tensors.front());
517518
if (!detached) {
518519
SetTensorData(tensor);

torch_xla/csrc/tensor_util.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -922,11 +922,11 @@ absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
922922
return literals;
923923
}
924924

925-
std::vector<at::Tensor> XlaDataToTensors(
925+
absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
926926
absl::Span<const torch::lazy::BackendDataPtr> xla_data,
927927
absl::Span<const at::ScalarType> dest_element_type) {
928-
std::vector<xla::Literal> literals =
929-
GetValueOrThrow(ReleaseGilAndTransferData(xla_data));
928+
XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
929+
ReleaseGilAndTransferData(xla_data));
930930
std::vector<at::Tensor> tensors(literals.size());
931931
absl::BlockingCounter counter(literals.size());
932932
for (size_t i = 0; i < tensors.size(); ++i) {

torch_xla/csrc/tensor_util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
3232
absl::Span<const torch::lazy::BackendDataPtr> xla_data);
3333

3434
// TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice
35-
std::vector<at::Tensor> XlaDataToTensors(
35+
absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
3636
absl::Span<const torch::lazy::BackendDataPtr> xla_data,
3737
absl::Span<const at::ScalarType> dest_element_type);
3838

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include "torch_xla/csrc/runtime/computation_client.h"
1111
#include "torch_xla/csrc/runtime/debug_macros.h"
1212
#include "torch_xla/csrc/runtime/runtime.h"
13+
#include "torch_xla/csrc/status.h"
14+
#include "torch_xla/csrc/tensor_util.h"
1315

1416
namespace at {
1517
// This function is defined in the codegenerated RegisterDispatchKey.cpp file.
@@ -92,7 +94,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
9294
const torch::lazy::BackendDataPtr data,
9395
std::optional<at::ScalarType> logical_scalar_type) const override {
9496
// TODO(JackCaoG): handle the logical_scalar_type == nullptr case
95-
return XlaDataToTensors({data}, {*logical_scalar_type})[0];
97+
return GetValueOrThrow(XlaDataToTensors({data}, {*logical_scalar_type}))[0];
9698
}
9799

98100
std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(

0 commit comments

Comments (0)