@@ -30,9 +30,11 @@ def mock_add_rms_norm(x, residual, weight, eps):
30
30
[None , torch .randn (4 , 8 , dtype = torch .float32 )])
31
31
@patch ("torch_npu.npu_rms_norm" , side_effect = mock_rms_norm )
32
32
@patch ("torch_npu.npu_add_rms_norm" , side_effect = mock_add_rms_norm )
33
+ @patch ("torch.ops.vllm.maybe_wait_prefetch_done" , side_effect = lambda x : None )
33
34
@patch ("torch.ops.vllm.maybe_chunk_residual" ,
34
35
side_effect = mock_maybe_chunk_residual )
35
- def test_RMSNorm_forward (mock_maybe_chunk_residual , mock_add_rmsnorm ,
36
+ def test_RMSNorm_forward (mock_maybe_chunk_residual ,
37
+ mock_maybe_wait_prefetch_done , mock_add_rmsnorm ,
36
38
mock_rmsnorm , is_310p_return , residual , dummy_tensor ):
37
39
38
40
with patch ("vllm_ascend.utils.is_310p" , return_value = is_310p_return ):
@@ -45,13 +47,17 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
45
47
expected_out_x = expected_arg_x + 1
46
48
expected_out_residual = expected_arg_x .to (residual .dtype )
47
49
50
+ mock_maybe_chunk_residual .assert_called_once ()
48
51
mock_rmsnorm .assert_called_once ()
52
+ mock_maybe_wait_prefetch_done .assert_called_once ()
49
53
assert torch .allclose (out_x , expected_out_x )
50
54
assert torch .allclose (out_residual , expected_out_residual )
51
55
else :
52
56
expected_out_x = 2 * dummy_tensor
53
57
expected_out_residual = 2 * residual
58
+ mock_maybe_chunk_residual .assert_called_once ()
54
59
mock_add_rmsnorm .assert_called_once ()
60
+ mock_maybe_wait_prefetch_done .assert_called_once ()
55
61
assert torch .allclose (out_x , expected_out_x )
56
62
assert torch .allclose (out_residual , expected_out_residual )
57
63
else :
@@ -64,9 +70,11 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
64
70
65
71
@patch ("vllm_ascend.utils.is_310p" , return_value = False )
66
72
@patch ("torch_npu.npu_add_rms_norm" , side_effect = mock_add_rms_norm )
73
+ @patch ("torch.ops.vllm.maybe_wait_prefetch_done" , side_effect = lambda x : None )
67
74
@patch ("torch.ops.vllm.maybe_chunk_residual" ,
68
75
side_effect = mock_maybe_chunk_residual )
69
76
def test_RMSNorm_forward_with_flashcomm_v1 (mock_maybe_chunk_residual ,
77
+ mock_maybe_wait_prefetch_done ,
70
78
mock_add_rms_norm , mock_is310p ):
71
79
x = torch .randn (4 , 512 , dtype = torch .bfloat16 )
72
80
residual = torch .randn (16 , 512 , dtype = torch .bfloat16 )
@@ -79,6 +87,7 @@ def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
79
87
80
88
mock_maybe_chunk_residual .assert_called_once ()
81
89
mock_add_rms_norm .assert_called_once ()
90
+ mock_maybe_wait_prefetch_done .assert_called_once ()
82
91
assert out_residual .size (0 ) == 4
83
92
assert torch .allclose (out_x , expected_out_x )
84
93
assert torch .allclose (out_residual , expected_out_residual )
0 commit comments