diff --git a/test/legacy_test/test_gather_op.py b/test/legacy_test/test_gather_op.py index 7f13a2ece92d1..6d16404e861f4 100644 --- a/test/legacy_test/test_gather_op.py +++ b/test/legacy_test/test_gather_op.py @@ -747,6 +747,28 @@ def test_out2(self): np.testing.assert_allclose(result, expected_output, rtol=1e-05) +@unittest.skipIf( + not (core.is_compiled_with_cuda() or is_custom_device()), + "only support compiled with CUDA.", +) +class TestGatherGPUCPUConsistency(unittest.TestCase): + def test_gpu_cpu_consistency(self): + paddle.disable_static() + np.random.seed(42) + x = np.random.rand(1000, 128).astype("float32") + index = np.random.randint(0, 1000, size=(100,)) + cpu_out = paddle.gather( + paddle.to_tensor(x, place=paddle.CPUPlace()), + paddle.to_tensor(index), + ) + gpu_out = paddle.gather( + paddle.to_tensor(x, place=paddle.CUDAPlace(0)), + paddle.to_tensor(index), + ) + np.testing.assert_allclose(cpu_out.numpy(), gpu_out.numpy(), rtol=1e-6) + paddle.enable_static() + + class API_TestDygraphGather(unittest.TestCase): def test_out1(self): paddle.disable_static()