
Commit 88aa24f

Support SPMD placeholder tensors
1 parent 8b95c5d
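
For context: a placeholder tensor (created via create_placeholder_tensor) has a device, shape, and dtype but no backing device buffer. This commit lets such tensors participate in SPMD, i.e. they can be annotated with mark_sharding and traced into a graph. Below is a minimal sketch distilled from the new test added in this commit; the explicit use_spmd() call and Mesh construction are assumptions standing in for the XlaShardingTest harness helpers the test uses.

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.core.xla_builder import create_placeholder_tensor

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices,), ('x',))

# The placeholder allocates no device buffer, yet can be sharded and
# traced like any other XLA tensor. Materializing it (e.g. p.cpu())
# raises RuntimeError, per test_read_value_crashes below.
p = create_placeholder_tensor((num_devices,), torch.bfloat16)
xs.mark_sharding(p, mesh, ('x',))
assert torch_xla._XLAC._is_placeholder(p)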

File tree

9 files changed: +106 -7 lines changed


test/neuron/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -242,6 +242,7 @@ function run_xla_op_tests3 {
 #run_test "$_TEST_DIR/spmd/test_xla_virtual_device.py"
 #run_test "$_TEST_DIR/spmd/test_dynamo_spmd.py"
 run_test "$_TEST_DIR/spmd/test_spmd_debugging.py"
+run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py"
 #run_test "$_TEST_DIR/spmd/test_xla_distributed_checkpoint.py"
 run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
 #run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -268,6 +268,7 @@ function run_xla_op_tests3 {
 run_test "$_TEST_DIR/test_persistent_cache.py"
 run_test "$_TEST_DIR/test_devices.py"
 run_test "$_TEST_DIR/test_manual_xla_registration.py"
+run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py"
 # NOTE: this line below is testing export and don't care about GPU
 PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$_TEST_DIR/test_core_aten_ops.py"
 run_test "$_TEST_DIR/test_pallas.py"

test/spmd/test_dynamo_spmd.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ def test_dynamo_input_sharding_threashold(self):
       print('catch')
     # it is hard to catch the C++ runtime error in python, instead we can check if
     # after printing that dynamo_res is still a placeholder then it means C++ crashed.
-    self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res))
+    self.assertTrue(torch_xla._XLAC._is_placeholder(dynamo_res))
     if saved_var != None:
       os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = saved_var
     else:

test/spmd/test_spmd_placeholder.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+import sys
+import unittest
+import torch
+import torch_xla
+from torch_xla.core.xla_builder import create_placeholder_tensor
+import torch_xla.debug.metrics as met
+import re
+import torch_xla.runtime as xr
+import torch_xla.distributed.spmd as xs
+
+import test_xla_sharding_base
+
+
+class TestSPMDPlaceholder(test_xla_sharding_base.XlaShardingTest):
+
+  def setUp(self):
+    super().setUpClass()
+
+  def test_create_placeholder(self):
+    num_devices = self.n_devices
+    for shape, dtype in zip(
+        ((num_devices, num_devices), (num_devices, num_devices, 2),
+         (num_devices, num_devices, 2, 2)),
+        (torch.float32, torch.bfloat16, torch.int8),
+    ):
+      model_axis = max(1, self.n_devices // 2)
+      data_axis = self.n_devices // model_axis
+      mesh_shape = (data_axis, model_axis) + (1,) * (len(shape) - 2)
+      axis_names = ('x', 'y') + tuple(f'z{i}' for i in range(1, len(shape) - 1))
+      mesh = self._get_mesh(mesh_shape, axis_names=axis_names)
+
+      p = create_placeholder_tensor(shape, dtype)
+      xs.mark_sharding(p, mesh, axis_names)
+      assert isinstance(p, torch.Tensor)
+      assert p.device == torch_xla.device()
+      self.assertEqual(p.dtype, dtype)
+      self.assertEqual(p.shape, shape)
+      self.assertTrue(torch_xla._XLAC._is_placeholder(p))
+
+  def test_read_value_crashes(self):
+    mesh = self._get_mesh((self.n_devices,), axis_names=('x',))
+    p = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
+    xs.mark_sharding(p, mesh, ('x',))
+    with self.assertRaises(RuntimeError):
+      p.cpu()
+
+  def test_trace_graph(self):
+    met.clear_all()
+    self.assertFalse(met.metric_data("TransferToDeviceTime"))
+
+    model_axis = max(1, self.n_devices // 2)
+    data_axis = self.n_devices // model_axis
+    mesh_shape = (data_axis, model_axis)
+    mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))
+
+    p1 = create_placeholder_tensor((128, 32), torch.bfloat16)
+    xs.mark_sharding(p1, mesh, ('x', 'y'))
+    a = torch.sin(p1)
+
+    p2 = create_placeholder_tensor((32, 64), torch.bfloat16)
+    xs.mark_sharding(p2, mesh, ('x', 'y'))
+    # We use p1 once and p2 twice, but the graph should still only have two parameters.
+    b = (a @ p2) @ p2.T
+    ir: str = torch_xla._XLAC._get_xla_tensors_text([b])
+    self.assertEqual(ir.count("xla::device_data()"), 2)
+    self.assertEqual(ir.count("bf16[32,64]{1,0} xla::device_data()"), 1)
+    self.assertEqual(ir.count("bf16[128,32]{1,0} xla::device_data()"), 1)
+    hlo: str = torch_xla._XLAC._get_xla_tensors_hlo([b])
+    print(hlo)
+    regex = r'\(p.*: bf16\[32,64\], p.*: bf16\[128,32\]\) -> \(bf16\[128,32\]\)'
+    assert re.search(regex, hlo) is not None
+
+    # There should be no buffers transferred to the device during tracing.
+    self.assertFalse(met.metric_data("TransferToDeviceTime"))
+
+  def test_placeholder_handle_unique(self):
+    mesh = self._get_mesh((self.n_devices,), axis_names=('x',))
+
+    p1 = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
+    xs.mark_sharding(p1, mesh, ('x',))
+
+    p2 = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
+    xs.mark_sharding(p2, mesh, ('x',))
+
+    h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
+    self.assertNotEqual(h1, h2)
+
+
+if __name__ == "__main__":
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
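
A property worth calling out from test_trace_graph above: marking a placeholder as sharded and computing through it only builds IR, so no buffers move to the device during tracing. A quick standalone check (a sketch, not part of the commit; assumes any PJRT device is configured and mirrors the TransferToDeviceTime assertion in the test):

import torch
import torch_xla
import torch_xla.debug.metrics as met
from torch_xla.core.xla_builder import create_placeholder_tensor

met.clear_all()
p = create_placeholder_tensor((8, 8), torch.float32)
out = torch.sin(p)  # records an IR node only; nothing executes yet
# No host-to-device transfer should have happened during tracing.
assert not met.metric_data("TransferToDeviceTime")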

test/test_input_output_aliases.py

Lines changed: 2 additions & 2 deletions
@@ -247,7 +247,7 @@ def test_user_config_donation_with_ltc_donation(self):

     # We surface the C++ runtime error by checking that the backend data is
     # no longer present for the IR node.
-    self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
+    self.assertTrue(torch_xla._XLAC._is_placeholder(t0))
     self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)

   @parameterized.parameters(True, False)
@@ -272,7 +272,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync(
     # We surface the C++ runtime error by checking that the backend data is
     # no longer present for the IR node.
     self.assertEqual(
-        torch_xla._XLAC._is_placecholder(t0), enable_buffer_donor_config)
+        torch_xla._XLAC._is_placeholder(t0), enable_buffer_donor_config)
     self.assertEqual(
         met.metric_data("InputOutputAliasCount")[1],
         enable_buffer_donor_config)

test/test_placeholder.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ def test_create_placeholder(self):
       assert p.device == torch_xla.device()
       self.assertEqual(p.dtype, dtype)
       self.assertEqual(p.shape, shape)
-      self.assertTrue(torch_xla._XLAC._is_placecholder(p))
+      self.assertTrue(torch_xla._XLAC._is_placeholder(p))

   def test_read_value_crashes(self):
     p = create_placeholder_tensor((1,), torch.bfloat16)
@@ -64,7 +64,7 @@ def test_cannot_get_handle_from_deleted_pjrt_buffer(self):
     _ = t0 + t1
     torch_xla.sync(wait=True)

-    self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
+    self.assertTrue(torch_xla._XLAC._is_placeholder(t0))
     with self.assertRaises(RuntimeError, msg='is deleted'):
       torch_xla._XLAC._get_tensors_handle([t0])

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ run_test "$_TEST_DIR/test_grad_checkpoint.py" "$@" --test_autocast
 run_test "$_TEST_DIR/dynamo/test_dynamo.py"
 run_test "$_TEST_DIR/dynamo/test_dynamo_dynamic_shape.py"
 run_test "$_TEST_DIR/spmd/test_spmd_debugging.py"
+run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py"
 XLA_PARAMETER_WRAPPING_THREADSHOLD=1 run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
 run_test "$_TEST_DIR/pjrt/test_dtypes.py"
 run_test "$_TEST_DIR/pjrt/test_dynamic_plugin_tpu.py"

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 1 addition & 1 deletion
@@ -2878,7 +2878,7 @@ void InitXlaModuleBindings(py::module m) {
              auto& coordinator = comp_client->GetCoordinator();
              return coordinator.ReachedSyncPoint(step);
            })
-      .def("_is_placecholder",
+      .def("_is_placeholder",
           [](at::Tensor& input) {
             XLATensorPtr xtensor = bridge::GetXlaTensor(input);
             return xtensor->CurrentDataHandle() &&

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 6 additions & 1 deletion
@@ -278,7 +278,12 @@ class PjRtComputationClient : public ComputationClient {
           sharding(sharding) {}

     Handle GetHandle() override {
-      // Always returns `Handle` of the first shard.
+      // If the data is a placeholder (no shards), use the address of this
+      // object as the handle.
+      if (shards.empty()) {
+        return reinterpret_cast<std::uintptr_t>(this);
+      }
+      // Always returns `Handle` of the first shard, which is unique.
       return shards[0]->GetHandle();
     }

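
In Python terms, this GetHandle() change means a sharded placeholder, which owns no PJRT shards, falls back to the address of its ShardedData object as its handle, so handles remain unique across placeholders; this is exactly what test_placeholder_handle_unique above verifies. A sketch under the same assumptions as the first example:

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.core.xla_builder import create_placeholder_tensor

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(n)), (n,), ('x',))

# Neither placeholder owns device shards, yet each reports a distinct
# handle (the address of its backing ShardedData object).
p1 = create_placeholder_tensor((n,), torch.bfloat16)
xs.mark_sharding(p1, mesh, ('x',))
p2 = create_placeholder_tensor((n,), torch.bfloat16)
xs.mark_sharding(p2, mesh, ('x',))
h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
assert h1 != h2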
