1 change: 1 addition & 0 deletions test/neuron/run_tests.sh
@@ -242,6 +242,7 @@ function run_xla_op_tests3 {
#run_test "$_TEST_DIR/spmd/test_xla_virtual_device.py"
#run_test "$_TEST_DIR/spmd/test_dynamo_spmd.py"
run_test "$_TEST_DIR/spmd/test_spmd_debugging.py"
run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py"
#=run_test "$_TEST_DIR/spmd/test_xla_distributed_checkpoint.py"
run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
#run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -268,6 +268,7 @@ function run_xla_op_tests3 {
run_test "$_TEST_DIR/test_persistent_cache.py"
run_test "$_TEST_DIR/test_devices.py"
run_test "$_TEST_DIR/test_manual_xla_registration.py"
run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py"
# NOTE: the line below tests export and does not care about GPU
PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$_TEST_DIR/test_core_aten_ops.py"
run_test "$_TEST_DIR/test_pallas.py"
2 changes: 1 addition & 1 deletion test/spmd/test_dynamo_spmd.py
@@ -176,7 +176,7 @@ def test_dynamo_input_sharding_threashold(self):
print('catch')
# It is hard to catch the C++ runtime error in Python. Instead, we check whether
# dynamo_res is still a placeholder after the print; if so, the C++ side crashed.
self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res))
self.assertTrue(torch_xla._XLAC._is_placeholder(dynamo_res))
if saved_var != None:
os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = saved_var
else:
90 changes: 90 additions & 0 deletions test/spmd/test_spmd_placeholder.py
@@ -0,0 +1,90 @@
import sys
import unittest
import torch
import torch_xla
from torch_xla.core.xla_builder import create_placeholder_tensor
import torch_xla.debug.metrics as met
import re
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

import test_xla_sharding_base


class TestSPMDPlaceholder(test_xla_sharding_base.XlaShardingTest):

def setUp(self):
super().setUpClass()

def test_create_placeholder(self):
num_devices = self.n_devices
for shape, dtype in zip(
((num_devices, num_devices), (num_devices, num_devices, 2),
(num_devices, num_devices, 2, 2)),
(torch.float32, torch.bfloat16, torch.int8),
):
model_axis = max(1, self.n_devices // 2)
data_axis = self.n_devices // model_axis
mesh_shape = (data_axis, model_axis) + (1,) * (len(shape) - 2)
axis_names = ('x', 'y') + tuple(f'z{i}' for i in range(1, len(shape) - 1))
mesh = self._get_mesh(mesh_shape, axis_names=axis_names)

p = create_placeholder_tensor(shape, dtype)
xs.mark_sharding(p, mesh, axis_names)
assert isinstance(p, torch.Tensor)
assert p.device == torch_xla.device()
self.assertEqual(p.dtype, dtype)
self.assertEqual(p.shape, shape)
self.assertTrue(torch_xla._XLAC._is_placeholder(p))

def test_read_value_crashes(self):
mesh = self._get_mesh((self.n_devices,), axis_names=('x',))
p = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
xs.mark_sharding(p, mesh, ('x',))
with self.assertRaises(RuntimeError):
p.cpu()

def test_trace_graph(self):
met.clear_all()
self.assertFalse(met.metric_data("TransferToDeviceTime"))

model_axis = max(1, self.n_devices // 2)
data_axis = self.n_devices // model_axis
mesh_shape = (data_axis, model_axis)
mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))

p1 = create_placeholder_tensor((128, 32), torch.bfloat16)
xs.mark_sharding(p1, mesh, ('x', 'y'))
a = torch.sin(p1)

p2 = create_placeholder_tensor((32, 64), torch.bfloat16)
xs.mark_sharding(p2, mesh, ('x', 'y'))
# We use p1 once and p2 twice. But the graph should still only have two parameters.
b = (a @ p2) @ p2.T
ir: str = torch_xla._XLAC._get_xla_tensors_text([b])
self.assertEqual(ir.count("xla::device_data()"), 2)
self.assertEqual(ir.count("bf16[32,64]{1,0} xla::device_data()"), 1)
self.assertEqual(ir.count("bf16[128,32]{1,0} xla::device_data()"), 1)
hlo: str = torch_xla._XLAC._get_xla_tensors_hlo([b])
regex = r'\(p.*: bf16\[32,64\], p.*: bf16\[128,32\]\) -> \(bf16\[128,32\]\)'
assert re.search(regex, hlo) is not None

# There should be no buffers transferred to the device during tracing
self.assertFalse(met.metric_data("TransferToDeviceTime"))

def test_placeholder_handle_unique(self):
mesh = self._get_mesh((self.n_devices,), axis_names=('x',))

p1 = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
xs.mark_sharding(p1, mesh, ('x',))

p2 = create_placeholder_tensor((self.n_devices,), torch.bfloat16)
xs.mark_sharding(p2, mesh, ('x',))

h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2])
self.assertNotEqual(h1, h2)


if __name__ == "__main__":
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
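
Usage note (not part of this PR): the same building blocks exercised by the tests above can be combined to inspect the sharded HLO a computation would produce without ever uploading real weights. The sketch below is a minimal illustration only; it assumes an SPMD-capable runtime, and the shapes, dtypes, and mesh layout are arbitrary choices for the example.

import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.core.xla_builder import create_placeholder_tensor

xr.use_spmd()  # enable SPMD mode before creating tensors

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

# Placeholder "weights" and "activations": correct shape/dtype, no buffers uploaded.
w = create_placeholder_tensor((1024, 1024), torch.bfloat16)
xs.mark_sharding(w, mesh, ('data', 'model'))
x = create_placeholder_tensor((8, 1024), torch.bfloat16)
xs.mark_sharding(x, mesh, (None, 'data'))

y = torch.relu(x @ w)

# Trace-only inspection: prints the sharded HLO, with no data transfer or execution.
print(torch_xla._XLAC._get_xla_tensors_hlo([y]))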
4 changes: 2 additions & 2 deletions test/test_input_output_aliases.py
@@ -247,7 +247,7 @@ def test_user_config_donation_with_ltc_donation(self):

# We surface the C++ runtime error by checking that the backend data is
# no longer present for the IR node.
self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
self.assertTrue(torch_xla._XLAC._is_placeholder(t0))
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)

@parameterized.parameters(True, False)
@@ -272,7 +272,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync(
# We surface the C++ runtime error by checking that the backend data is
# no longer present for the IR node.
self.assertEqual(
torch_xla._XLAC._is_placecholder(t0), enable_buffer_donor_config)
torch_xla._XLAC._is_placeholder(t0), enable_buffer_donor_config)
self.assertEqual(
met.metric_data("InputOutputAliasCount")[1],
enable_buffer_donor_config)
4 changes: 2 additions & 2 deletions test/test_placeholder.py
@@ -22,7 +22,7 @@ def test_create_placeholder(self):
assert p.device == torch_xla.device()
self.assertEqual(p.dtype, dtype)
self.assertEqual(p.shape, shape)
self.assertTrue(torch_xla._XLAC._is_placecholder(p))
self.assertTrue(torch_xla._XLAC._is_placeholder(p))

def test_read_value_crashes(self):
p = create_placeholder_tensor((1,), torch.bfloat16)
@@ -64,7 +64,7 @@ def test_cannot_get_handle_from_deleted_pjrt_buffer(self):
_ = t0 + t1
torch_xla.sync(wait=True)

self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
self.assertTrue(torch_xla._XLAC._is_placeholder(t0))
with self.assertRaises(RuntimeError, msg='is deleted'):
torch_xla._XLAC._get_tensors_handle([t0])

1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -70,6 +70,7 @@ run_test "$_TEST_DIR/test_grad_checkpoint.py" "$@" --test_autocast
run_test "$_TEST_DIR/dynamo/test_dynamo.py"
run_test "$_TEST_DIR/dynamo/test_dynamo_dynamic_shape.py"
run_test "$_TEST_DIR/spmd/test_spmd_debugging.py"
run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py"
XLA_PARAMETER_WRAPPING_THREADSHOLD=1 run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$_TEST_DIR/pjrt/test_dtypes.py"
run_test "$_TEST_DIR/pjrt/test_dynamic_plugin_tpu.py"
2 changes: 1 addition & 1 deletion torch_xla/csrc/init_python_bindings.cpp
@@ -2878,7 +2878,7 @@ void InitXlaModuleBindings(py::module m) {
auto& coordinator = comp_client->GetCoordinator();
return coordinator.ReachedSyncPoint(step);
})
.def("_is_placecholder",
.def("_is_placeholder",
[](at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
return xtensor->CurrentDataHandle() &&
7 changes: 6 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -278,7 +278,12 @@ class PjRtComputationClient : public ComputationClient {
sharding(sharding) {}

Handle GetHandle() override {
// Always returns `Handle` of the first shard.
// If the data is a placeholder (no shards), use the address of this
// object as the handle.
if (shards.empty()) {
return reinterpret_cast<std::uintptr_t>(this);
}
// Always returns `Handle` of the first shard, which is unique.
return shards[0]->GetHandle();
}

52 changes: 34 additions & 18 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -815,25 +815,41 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
}
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors);
XLA_CHECK_EQ(tensors.size(), 1);
cpu_tensor = tensors[0];
}
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);
if (xtensor->CurrentDataHandle() &&
!xtensor->CurrentDataHandle()->HasValue()) {
// For placeholder tensors, we skip the data transfer entirely and
// directly create sharded placeholder data without materializing any CPU
// tensor. This preserves the placeholder nature - it will still fail if
// accessed.
auto xla_data =
runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
GetVirtualDevice().toString(),
MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())),
sharding);
xtensor->SetXlaData(WrapXlaData({xla_data})[0]);
} else {
// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors);
XLA_CHECK_EQ(tensors.size(), 1);
cpu_tensor = tensors[0];
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
}
}

xtensor->SetShardingSpec(*new_sharding_spec);
// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}
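
A quick way to see the effect of this branch from Python, mirroring what the new tests assert: marking sharding on a placeholder should leave the transfer metric untouched, and the tensor should still fail when read. This is a rough sketch under the assumption that an SPMD runtime is initialized; the shape and mesh are arbitrary.

import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
from torch_xla.core.xla_builder import create_placeholder_tensor

xr.use_spmd()
met.clear_all()

n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n,), ('x',))

p = create_placeholder_tensor((n, 128), torch.float32)
xs.mark_sharding(p, mesh, ('x', None))

# The placeholder branch creates sharded placeholder data directly,
# so no host-to-device transfer is recorded ...
assert not met.metric_data("TransferToDeviceTime")

# ... and the tensor keeps its placeholder nature: reading it raises.
try:
  p.cpu()
except RuntimeError:
  print("placeholder cannot be materialized, as expected")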