
Commit 0cdb6fb

Support SPMD placeholder tensors
1 parent 8b95c5d commit 0cdb6fb

File tree

10 files changed: +138 -25 lines changed

  test/neuron/run_tests.sh
  test/run_tests.sh
  test/spmd/test_dynamo_spmd.py
  test/spmd/test_spmd_placeholder.py
  test/test_input_output_aliases.py
  test/test_placeholder.py
  test/tpu/run_tests.sh
  torch_xla/csrc/init_python_bindings.cpp
  torch_xla/csrc/runtime/pjrt_computation_client.h
  torch_xla/csrc/xla_sharding_util.cpp
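
Before the per-file diffs, here is a rough usage sketch of what this commit enables, distilled from the new test/spmd/test_spmd_placeholder.py below. The SPMD setup lines (xr.use_spmd(), the explicit xs.Mesh construction) are assumptions for the sketch, not part of this diff:

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.core.xla_builder import create_placeholder_tensor

xr.use_spmd()  # assumption: SPMD mode is enabled before any XLA tensor exists

# Build a 2D (data, model) mesh over all addressable devices.
n = xr.global_runtime_device_count()
model_axis = max(1, n // 2)
mesh = xs.Mesh(list(range(n)), (n // model_axis, model_axis), ('x', 'y'))

# A placeholder carries shape/dtype/device metadata but owns no device buffer.
p = create_placeholder_tensor((128, 32), torch.bfloat16)
xs.mark_sharding(p, mesh, ('x', 'y'))      # now supported for placeholders
assert torch_xla._XLAC._is_placeholder(p)  # still a placeholder afterwards

# The sharded placeholder participates in lazy tracing like any device tensor.
out = torch.sin(p)
print(torch_xla._XLAC._get_xla_tensors_text([out]))

With this change, mark_sharding re-wraps the placeholder as sharded placeholder data instead of trying to download and re-upload a real buffer (see the xla_sharding_util.cpp hunk at the end).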

test/neuron/run_tests.sh

Lines changed: 1 addition & 0 deletions

@@ -242,6 +242,7 @@ function run_xla_op_tests3 {
   #run_test "$_TEST_DIR/spmd/test_xla_virtual_device.py"
   #run_test "$_TEST_DIR/spmd/test_dynamo_spmd.py"
   run_test "$_TEST_DIR/spmd/test_spmd_debugging.py"
+  run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py"
   #=run_test "$_TEST_DIR/spmd/test_xla_distributed_checkpoint.py"
   run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions

@@ -268,6 +268,7 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/test_persistent_cache.py"
   run_test "$_TEST_DIR/test_devices.py"
   run_test "$_TEST_DIR/test_manual_xla_registration.py"
+  run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py"
   # NOTE: this line below is testing export and don't care about GPU
   PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$_TEST_DIR/test_core_aten_ops.py"
   run_test "$_TEST_DIR/test_pallas.py"

test/spmd/test_dynamo_spmd.py

Lines changed: 1 addition & 1 deletion

@@ -176,7 +176,7 @@ def test_dynamo_input_sharding_threashold(self):
       print('catch')
     # it is hard to catch the C++ runtime error in python, instead we can check if
     # after printing that dynamo_res is still a placeholder then it means C++ crashed.
-    self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res))
+    self.assertTrue(torch_xla._XLAC._is_placeholder(dynamo_res))
     if saved_var != None:
       os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = saved_var
     else:

test/spmd/test_spmd_placeholder.py

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+import sys
+import unittest
+import torch
+import torch_xla
+from torch_xla.core.xla_builder import create_placeholder_tensor
+import torch_xla.debug.metrics as met
+import re
+import torch_xla.runtime as xr
+import torch_xla.distributed.spmd as xs
+
+import test_xla_sharding_base
+
+
+class TestSPMDPlaceholder(test_xla_sharding_base.XlaShardingTest):
+
+  def setUp(self):
+    super().setUpClass()
+
+  def test_create_placeholder(self):
+    num_devices = self.n_devices
+    for shape, dtype in zip(
+        ((num_devices, num_devices), (num_devices, num_devices, 2),
+         (num_devices, num_devices, 2, 2)),
+        (torch.float32, torch.bfloat16, torch.int8),
+    ):
+      model_axis = max(1, self.n_devices // 2)
+      data_axis = self.n_devices // model_axis
+      mesh_shape = (data_axis, model_axis) + (1,) * (len(shape) - 2)
+      axis_names = ('x', 'y') + tuple(f'z{i}' for i in range(1, len(shape) - 1))
+      mesh = self._get_mesh(mesh_shape, axis_names=axis_names)
+
+      p = create_placeholder_tensor(shape, dtype)
+      xs.mark_sharding(p, mesh, axis_names)
+      assert isinstance(p, torch.Tensor)
+      assert p.device == torch_xla.device()
+      self.assertEqual(p.dtype, dtype)
+      self.assertEqual(p.shape, shape)
+      self.assertTrue(torch_xla._XLAC._is_placeholder(p))
+
+  def test_read_value_crashes(self):
+    mesh = self._get_mesh((self.n_devices,), axis_names=('x',))
+    p = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
+    xs.mark_sharding(p, mesh, ('x',))
+    with self.assertRaises(RuntimeError):
+      p.cpu()
+
+  def test_trace_graph(self):
+    met.clear_all()
+    self.assertFalse(met.metric_data("TransferToDeviceTime"))
+
+    model_axis = max(1, self.n_devices // 2)
+    data_axis = self.n_devices // model_axis
+    mesh_shape = (data_axis, model_axis)
+    mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))
+
+    p1 = create_placeholder_tensor((128, 32), torch.bfloat16)
+    xs.mark_sharding(p1, mesh, ('x', 'y'))
+    a = torch.sin(p1)
+
+    p2 = create_placeholder_tensor((32, 64), torch.bfloat16)
+    xs.mark_sharding(p2, mesh, ('x', 'y'))
+    # We use p1 once and p2 twice. But the graph should still only have two parameters.
+    b = (a @ p2) @ p2.T
+    ir: str = torch_xla._XLAC._get_xla_tensors_text([b])
+    self.assertEqual(ir.count("xla::device_data()"), 2)
+    self.assertEqual(ir.count("bf16[32,64]{1,0} xla::device_data()"), 1)
+    self.assertEqual(ir.count("bf16[128,32]{1,0} xla::device_data()"), 1)
+    hlo: str = torch_xla._XLAC._get_xla_tensors_hlo([b])
+    regex = r'\(p.*: bf16\[32,64\], p.*: bf16\[128,32\]\) -> \(bf16\[128,32\]\)'
+    assert re.search(regex, hlo) is not None
+
+    # There should be no buffers transferred to the device during tracing
+    self.assertFalse(met.metric_data("TransferToDeviceTime"))
+
+  def test_placeholder_handle_unique(self):
+    mesh = self._get_mesh((self.n_devices,), axis_names=('x',))
+
+    p1 = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
+    xs.mark_sharding(p1, mesh, ('x',))
+
+    p2 = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
+    xs.mark_sharding(p2, mesh, ('x',))
+
+    h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
+    self.assertNotEqual(h1, h2)
+
+
+if __name__ == "__main__":
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)

test/test_input_output_aliases.py

Lines changed: 2 additions & 2 deletions

@@ -247,7 +247,7 @@ def test_user_config_donation_with_ltc_donation(self):
 
     # We surface the C++ runtime error by checking that the backend data is
    # no longer present for the IR node.
-    self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
+    self.assertTrue(torch_xla._XLAC._is_placeholder(t0))
     self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
 
   @parameterized.parameters(True, False)
@@ -272,7 +272,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync(
     # We surface the C++ runtime error by checking that the backend data is
     # no longer present for the IR node.
     self.assertEqual(
-        torch_xla._XLAC._is_placecholder(t0), enable_buffer_donor_config)
+        torch_xla._XLAC._is_placeholder(t0), enable_buffer_donor_config)
     self.assertEqual(
         met.metric_data("InputOutputAliasCount")[1],
         enable_buffer_donor_config)

test/test_placeholder.py

Lines changed: 2 additions & 2 deletions

@@ -22,7 +22,7 @@ def test_create_placeholder(self):
       assert p.device == torch_xla.device()
       self.assertEqual(p.dtype, dtype)
       self.assertEqual(p.shape, shape)
-      self.assertTrue(torch_xla._XLAC._is_placecholder(p))
+      self.assertTrue(torch_xla._XLAC._is_placeholder(p))
 
   def test_read_value_crashes(self):
     p = create_placeholder_tensor((1,), torch.bfloat16)
@@ -64,7 +64,7 @@ def test_cannot_get_handle_from_deleted_pjrt_buffer(self):
     _ = t0 + t1
     torch_xla.sync(wait=True)
 
-    self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
+    self.assertTrue(torch_xla._XLAC._is_placeholder(t0))
     with self.assertRaises(RuntimeError, msg='is deleted'):
       torch_xla._XLAC._get_tensors_handle([t0])

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions

@@ -70,6 +70,7 @@ run_test "$_TEST_DIR/test_grad_checkpoint.py" "$@" --test_autocast
 run_test "$_TEST_DIR/dynamo/test_dynamo.py"
 run_test "$_TEST_DIR/dynamo/test_dynamo_dynamic_shape.py"
 run_test "$_TEST_DIR/spmd/test_spmd_debugging.py"
+run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py"
 XLA_PARAMETER_WRAPPING_THREADSHOLD=1 run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
 run_test "$_TEST_DIR/pjrt/test_dtypes.py"
 run_test "$_TEST_DIR/pjrt/test_dynamic_plugin_tpu.py"

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 1 addition & 1 deletion

@@ -2878,7 +2878,7 @@ void InitXlaModuleBindings(py::module m) {
             auto& coordinator = comp_client->GetCoordinator();
             return coordinator.ReachedSyncPoint(step);
           })
-      .def("_is_placecholder",
+      .def("_is_placeholder",
           [](at::Tensor& input) {
             XLATensorPtr xtensor = bridge::GetXlaTensor(input);
             return xtensor->CurrentDataHandle() &&

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 6 additions & 1 deletion

@@ -278,7 +278,12 @@ class PjRtComputationClient : public ComputationClient {
           sharding(sharding) {}
 
    Handle GetHandle() override {
-      // Always returns `Handle` of the first shard.
+      // If the data is a placeholder (no shards), use the address of this
+      // object as the handle.
+      if (shards.empty()) {
+        return reinterpret_cast<std::uintptr_t>(this);
+      }
+      // Always returns `Handle` of the first shard, which is unique.
       return shards[0]->GetHandle();
     }

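A Python-level view of this fallback, mirroring test_placeholder_handle_unique above (a sketch; the xr.use_spmd() call and the explicit mesh construction are assumptions, not part of this hunk): a sharded placeholder has no shards, yet each one still reports a distinct handle, because the handle falls back to the address of its own PjRtShardedData object.

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.core.xla_builder import create_placeholder_tensor

xr.use_spmd()  # assumption: SPMD mode enabled up front
n = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(n)), (n,), ('x',))

# Two sharded placeholders: neither owns device buffers, but each one
# gets its own handle from the shards-empty branch of GetHandle().
p1 = create_placeholder_tensor((n,), torch.bfloat16)
xs.mark_sharding(p1, mesh, ('x',))
p2 = create_placeholder_tensor((n,), torch.bfloat16)
xs.mark_sharding(p2, mesh, ('x',))

h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
assert h1 != h2
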
torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 33 additions & 18 deletions

@@ -815,25 +815,40 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
       }
     }
 
-    // If the at::Tensor data is not present, we need to re-download the
-    // tensor from the physical device to CPU. In that case, the value
-    // must be present on the backend device.
-    XLA_CHECK((xtensor->CurrentDataHandle() &&
-               xtensor->CurrentDataHandle()->HasValue()) ||
-              device_data_node != nullptr)
-        << "Cannot shard tensor. Data does not present on any device.";
-    std::vector<XLATensorPtr> xla_tensors{xtensor};
-    auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors);
-    XLA_CHECK_EQ(tensors.size(), 1);
-    cpu_tensor = tensors[0];
-  }
-  auto xla_data = CreateTensorsData(
-      std::vector<at::Tensor>{cpu_tensor},
-      std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
-      std::vector<std::string>{GetVirtualDevice().toString()})[0];
-  xtensor->SetXlaData(xla_data);
-  xtensor->SetShardingSpec(*new_sharding_spec);
+    if (xtensor->CurrentDataHandle() &&
+        !xtensor->CurrentDataHandle()->HasValue()) {
+      // For placeholder tensors, we skip the data transfer entirely and
+      // directly create sharded placeholder data without materializing any CPU
+      // tensor. This preserves the placeholder nature - it will still fail if
+      // accessed.
+      auto xla_data = runtime::GetComputationClient()->CreateDataPlaceholder(
+          GetVirtualDevice().toString(),
+          MakeShapeWithDeviceLayout(
+              xtensor->shape(),
+              static_cast<XlaDeviceType>(xtensor->GetDevice().type())),
+          sharding);
+      xtensor->SetXlaData(WrapXlaData({xla_data})[0]);
+    } else {
+      // If the at::Tensor data is not present, we need to re-download the
+      // tensor from the physical device to CPU. In that case, the value
+      // must be present on the backend device.
+      XLA_CHECK((xtensor->CurrentDataHandle() &&
+                 xtensor->CurrentDataHandle()->HasValue()) ||
+                device_data_node != nullptr)
+          << "Cannot shard tensor. Data does not present on any device.";
+      std::vector<XLATensorPtr> xla_tensors{xtensor};
+      auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors);
+      XLA_CHECK_EQ(tensors.size(), 1);
+      cpu_tensor = tensors[0];
+      auto xla_data = CreateTensorsData(
+          std::vector<at::Tensor>{cpu_tensor},
+          std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
+          std::vector<std::string>{GetVirtualDevice().toString()})[0];
+      xtensor->SetXlaData(xla_data);
+    }
+  }
 
+  xtensor->SetShardingSpec(*new_sharding_spec);
   // Register sharded tensor data.
   XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
 }

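The observable effect of the new placeholder branch, sketched from the Python side (same assumptions as the earlier sketches: an SPMD-initialized runtime and a hand-built mesh): marking a sharding on a placeholder creates sharded placeholder data directly, so nothing is materialized on the host or transferred to the device, and reading the value back still raises.

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.debug.metrics as met
from torch_xla.core.xla_builder import create_placeholder_tensor

xr.use_spmd()  # assumption: SPMD mode enabled up front
n = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(n)), (n,), ('x',))

met.clear_all()
p = create_placeholder_tensor((n,), torch.bfloat16)
xs.mark_sharding(p, mesh, ('x',))  # takes the placeholder branch added above

assert torch_xla._XLAC._is_placeholder(p)           # placeholder preserved
assert not met.metric_data("TransferToDeviceTime")  # no host-to-device copy

try:
  p.cpu()  # materializing the value is still an error
except RuntimeError:
  pass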