diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index 22e778945f9..92aca7c145e 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -242,6 +242,7 @@ function run_xla_op_tests3 { #run_test "$_TEST_DIR/spmd/test_xla_virtual_device.py" #run_test "$_TEST_DIR/spmd/test_dynamo_spmd.py" run_test "$_TEST_DIR/spmd/test_spmd_debugging.py" + run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py" #=run_test "$_TEST_DIR/spmd/test_xla_distributed_checkpoint.py" run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py" #run_test "$_TEST_DIR/spmd/test_dtensor_integration.py" diff --git a/test/run_tests.sh b/test/run_tests.sh index 93f4cb33c06..4f9406167e5 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -268,6 +268,7 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/test_persistent_cache.py" run_test "$_TEST_DIR/test_devices.py" run_test "$_TEST_DIR/test_manual_xla_registration.py" + run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py" # NOTE: this line below is testing export and don't care about GPU PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$_TEST_DIR/test_core_aten_ops.py" run_test "$_TEST_DIR/test_pallas.py" diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 518e4203b45..654ece0066f 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -176,7 +176,7 @@ def test_dynamo_input_sharding_threashold(self): print('catch') # it is hard to catch the C++ runtime error in python, instead we can check if # after printing that dynamo_res is still a placeholder then it means C++ crashed. - self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res)) + self.assertTrue(torch_xla._XLAC._is_placeholder(dynamo_res)) if saved_var != None: os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = saved_var else: diff --git a/test/spmd/test_spmd_placeholder.py b/test/spmd/test_spmd_placeholder.py new file mode 100644 index 00000000000..124881eb1ba --- /dev/null +++ b/test/spmd/test_spmd_placeholder.py @@ -0,0 +1,90 @@ +import sys +import unittest +import torch +import torch_xla +from torch_xla.core.xla_builder import create_placeholder_tensor +import torch_xla.debug.metrics as met +import re +import torch_xla.runtime as xr +import torch_xla.distributed.spmd as xs + +import test_xla_sharding_base + + +class TestSPMDPlaceholder(test_xla_sharding_base.XlaShardingTest): + + def setUp(self): + super().setUpClass() + + def test_create_placeholder(self): + num_devices = self.n_devices + for shape, dtype in zip( + ((num_devices, num_devices), (num_devices, num_devices, 2), + (num_devices, num_devices, 2, 2)), + (torch.float32, torch.bfloat16, torch.int8), + ): + model_axis = max(1, self.n_devices // 2) + data_axis = self.n_devices // model_axis + mesh_shape = (data_axis, model_axis) + (1,) * (len(shape) - 2) + axis_names = ('x', 'y') + tuple(f'z{i}' for i in range(1, len(shape) - 1)) + mesh = self._get_mesh(mesh_shape, axis_names=axis_names) + + p = create_placeholder_tensor(shape, dtype) + xs.mark_sharding(p, mesh, axis_names) + assert isinstance(p, torch.Tensor) + assert p.device == torch_xla.device() + self.assertEqual(p.dtype, dtype) + self.assertEqual(p.shape, shape) + self.assertTrue(torch_xla._XLAC._is_placeholder(p)) + + def test_read_value_crashes(self): + mesh = self._get_mesh((self.n_devices,), axis_names=('x',)) + p = create_placeholder_tensor((self.n_devices,), torch.bfloat16) + xs.mark_sharding(p, mesh, ('x',)) + with self.assertRaises(RuntimeError): + p.cpu() + + def test_trace_graph(self): + met.clear_all() + self.assertFalse(met.metric_data("TransferToDeviceTime")) + + model_axis = max(1, self.n_devices // 2) + data_axis = self.n_devices // model_axis + mesh_shape = (data_axis, model_axis) + mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) + + p1 = create_placeholder_tensor((128, 32), torch.bfloat16) + xs.mark_sharding(p1, mesh, ('x', 'y')) + a = torch.sin(p1) + + p2 = create_placeholder_tensor((32, 64), torch.bfloat16) + xs.mark_sharding(p2, mesh, ('x', 'y')) + # We use p1 once and p2 twice. But the graph should still only have two parameters. + b = (a @ p2) @ p2.T + ir: str = torch_xla._XLAC._get_xla_tensors_text([b]) + self.assertEqual(ir.count("xla::device_data()"), 2) + self.assertEqual(ir.count("bf16[32,64]{1,0} xla::device_data()"), 1) + self.assertEqual(ir.count("bf16[128,32]{1,0} xla::device_data()"), 1) + hlo: str = torch_xla._XLAC._get_xla_tensors_hlo([b]) + regex = r'\(p.*: bf16\[32,64\], p.*: bf16\[128,32\]\) -> \(bf16\[128,32\]\)' + assert re.search(regex, hlo) is not None + + # There should be no buffers transferred to the device during tracing + self.assertFalse(met.metric_data("TransferToDeviceTime")) + + def test_placeholder_handle_unique(self): + mesh = self._get_mesh((self.n_devices,), axis_names=('x',)) + + p1 = create_placeholder_tensor((self.n_devices,), torch.bfloat16) + xs.mark_sharding(p1, mesh, ('x',)) + + p2 = create_placeholder_tensor((self.n_devices,), torch.bfloat16) + xs.mark_sharding(p2, mesh, ('x',)) + + h1, h2 = torch_xla._XLAC._get_tensors_handle([p1, p2]) + self.assertNotEqual(h1, h2) + + +if __name__ == "__main__": + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py index 3f20f9d25c9..2a161494b03 100644 --- a/test/test_input_output_aliases.py +++ b/test/test_input_output_aliases.py @@ -247,7 +247,7 @@ def test_user_config_donation_with_ltc_donation(self): # We surface the C++ runtime error by checking that the backend data is # no longer present for the IR node. - self.assertTrue(torch_xla._XLAC._is_placecholder(t0)) + self.assertTrue(torch_xla._XLAC._is_placeholder(t0)) self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) @parameterized.parameters(True, False) @@ -272,7 +272,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync( # We surface the C++ runtime error by checking that the backend data is # no longer present for the IR node. self.assertEqual( - torch_xla._XLAC._is_placecholder(t0), enable_buffer_donor_config) + torch_xla._XLAC._is_placeholder(t0), enable_buffer_donor_config) self.assertEqual( met.metric_data("InputOutputAliasCount")[1], enable_buffer_donor_config) diff --git a/test/test_placeholder.py b/test/test_placeholder.py index d5506bfacd5..a201a59286f 100644 --- a/test/test_placeholder.py +++ b/test/test_placeholder.py @@ -22,7 +22,7 @@ def test_create_placeholder(self): assert p.device == torch_xla.device() self.assertEqual(p.dtype, dtype) self.assertEqual(p.shape, shape) - self.assertTrue(torch_xla._XLAC._is_placecholder(p)) + self.assertTrue(torch_xla._XLAC._is_placeholder(p)) def test_read_value_crashes(self): p = create_placeholder_tensor((1,), torch.bfloat16) @@ -64,7 +64,7 @@ def test_cannot_get_handle_from_deleted_pjrt_buffer(self): _ = t0 + t1 torch_xla.sync(wait=True) - self.assertTrue(torch_xla._XLAC._is_placecholder(t0)) + self.assertTrue(torch_xla._XLAC._is_placeholder(t0)) with self.assertRaises(RuntimeError, msg='is deleted'): torch_xla._XLAC._get_tensors_handle([t0]) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 1f6f5249b93..f90c4ddbeb9 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -70,6 +70,7 @@ run_test "$_TEST_DIR/test_grad_checkpoint.py" "$@" --test_autocast run_test "$_TEST_DIR/dynamo/test_dynamo.py" run_test "$_TEST_DIR/dynamo/test_dynamo_dynamic_shape.py" run_test "$_TEST_DIR/spmd/test_spmd_debugging.py" +run_test "$_TEST_DIR/spmd/test_spmd_placeholder.py" XLA_PARAMETER_WRAPPING_THREADSHOLD=1 run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/pjrt/test_dtypes.py" run_test "$_TEST_DIR/pjrt/test_dynamic_plugin_tpu.py" diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ce55969d693..60cad321cf6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2878,7 +2878,7 @@ void InitXlaModuleBindings(py::module m) { auto& coordinator = comp_client->GetCoordinator(); return coordinator.ReachedSyncPoint(step); }) - .def("_is_placecholder", + .def("_is_placeholder", [](at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); return xtensor->CurrentDataHandle() && diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index b7c61e2ec74..681852d7644 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -278,7 +278,12 @@ class PjRtComputationClient : public ComputationClient { sharding(sharding) {} Handle GetHandle() override { - // Always returns `Handle` of the first shard. + // If the data is a placeholder (no shards), use the address of this + // object as the handle. + if (shards.empty()) { + return reinterpret_cast(this); + } + // Always returns `Handle` of the first shard, which is unique. return shards[0]->GetHandle(); } diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 3ddec53d700..4951a3a4caf 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -815,25 +815,41 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input, } } - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((xtensor->CurrentDataHandle() && - xtensor->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. Data does not present on any device."; - std::vector xla_tensors{xtensor}; - auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors); - XLA_CHECK_EQ(tensors.size(), 1); - cpu_tensor = tensors[0]; - } - auto xla_data = CreateTensorsData( - std::vector{cpu_tensor}, - std::vector{new_sharding_spec}, - std::vector{GetVirtualDevice().toString()})[0]; - xtensor->SetXlaData(xla_data); - xtensor->SetShardingSpec(*new_sharding_spec); + if (xtensor->CurrentDataHandle() && + !xtensor->CurrentDataHandle()->HasValue()) { + // For placeholder tensors, we skip the data transfer entirely and + // directly create sharded placeholder data without materializing any CPU + // tensor. This preserves the placeholder nature - it will still fail if + // accessed. + auto xla_data = + runtime::GetComputationClientOrDie()->CreateDataPlaceholder( + GetVirtualDevice().toString(), + MakeShapeWithDeviceLayout( + xtensor->shape(), + static_cast(xtensor->GetDevice().type())), + sharding); + xtensor->SetXlaData(WrapXlaData({xla_data})[0]); + } else { + // If the at::Tensor data is not present, we need to re-download the + // tensor from the physical device to CPU. In that case, the value + // must be present on the backend device. + XLA_CHECK((xtensor->CurrentDataHandle() && + xtensor->CurrentDataHandle()->HasValue()) || + device_data_node != nullptr) + << "Cannot shard tensor. Data does not present on any device."; + std::vector xla_tensors{xtensor}; + auto tensors = XLAGraphExecutor::Get()->GetTensors(&xla_tensors); + XLA_CHECK_EQ(tensors.size(), 1); + cpu_tensor = tensors[0]; + auto xla_data = CreateTensorsData( + std::vector{cpu_tensor}, + std::vector{new_sharding_spec}, + std::vector{GetVirtualDevice().toString()})[0]; + xtensor->SetXlaData(xla_data); + } + } + xtensor->SetShardingSpec(*new_sharding_spec); // Register sharded tensor data. XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); }