 import torch
 import torch.nn.functional as F
+import torch_npu
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 from vllm.forward_context import get_forward_context
 from vllm.utils import direct_register_custom_op

+import vllm_ascend.envs as envs_ascend
+

 def _maybe_chunk_residual_impl(x: torch.Tensor,
                                residual: torch.Tensor) -> torch.Tensor:
-    if get_forward_context().flashcomm_v1_enabled:
+    # When flashcomm_v1 pads and shards the hidden states, x and residual
+    # can disagree on the token dimension; pad residual to match.
+    if x.size(0) != residual.size(0):
+        flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+        assert flashcomm_v1_enabled is True, (
+            "Currently, this situation only occurs "
+            "when flashcomm_v1 is enabled")
         pad_size = get_forward_context().pad_size
         if pad_size > 0:
             # F.pad with (0, 0, 0, pad_size) appends pad_size rows to dim 0.
             residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -44,6 +51,76 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
     return tensor_model_parallel_all_reduce(x)


+def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
+                                          prefix: str) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    prefetch_model = forward_context.prefetch_model
+    prefetch_stream = forward_context.prefetch_stream
+    layer_idx = int(prefix.split('.')[2])
+
+    # Start point of the gate_up_proj weight prefetch: arm it when called
+    # from the layer's self_attn module.
+    if prefix.split('.')[-2] == "self_attn":
+        forward_context.prefetch_mlp_gate_up_proj = True
+    if forward_context.prefetch_mlp_gate_up_proj:
+        # Order the side stream after the compute stream, then prefetch the
+        # weight once x_dependency is ready.
+        prefetch_stream.wait_stream(torch.npu.current_stream())
+
+        with torch.npu.stream(prefetch_stream):
+            MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
+            torch_npu.npu_prefetch(
+                prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight,
+                x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
+    return
+
+
+def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
+                                               prefix: str) -> None:
+    return
+
+
+def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    forward_context.prefetch_mlp_down_proj = True
+    prefetch_model = forward_context.prefetch_model
+    prefetch_stream = forward_context.prefetch_stream
+    layer_idx = forward_context.layer_idx
+
+    # Start point of the down_proj weight prefetch.
+    prefetch_stream.wait_stream(torch.npu.current_stream())
+
+    with torch.npu.stream(prefetch_stream):
+        MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
+        torch_npu.npu_prefetch(
+            prefetch_model.model.layers[layer_idx].mlp.down_proj.weight,
+            x_dependency, MLP_DOWN_PREFETCH_SIZE)
+    forward_context.layer_idx += 1
+    return
+
+
+def _maybe_prefetch_mlp_down_proj_impl_fake(
+        x_dependency: torch.Tensor) -> None:
+    return
+
+
+def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    if forward_context.prefetch_mlp_gate_up_proj or \
+            forward_context.prefetch_mlp_down_proj:
+        prefetch_stream = forward_context.prefetch_stream
+        # Block the compute stream until the weight prefetch has finished.
+        torch.npu.current_stream().wait_stream(prefetch_stream)
+        forward_context.prefetch_mlp_gate_up_proj = False
+        forward_context.prefetch_mlp_down_proj = False
+    return
+
+
+def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
+    return
+
+
 direct_register_custom_op(op_name="maybe_chunk_residual",
                           op_func=_maybe_chunk_residual_impl,
                           fake_impl=lambda x, residual: residual,
@@ -60,4 +137,22 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
                           op_func=_maybe_pad_and_reduce_impl,
                           fake_impl=lambda x: x,
                           mutates_args=[],
-                          dispatch_key="PrivateUse1")
+                          dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
+                          op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
+                          fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
+                          op_func=_maybe_prefetch_mlp_down_proj_impl,
+                          fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="maybe_wait_prefetch_done",
+                          op_func=_maybe_wait_prefetch_done_impl,
+                          fake_impl=_maybe_wait_prefetch_done_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
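
For context, a minimal sketch of how a decoder layer might call the registered
ops. It is not part of this diff: it assumes direct_register_custom_op exposes
the ops under vLLM's default torch.ops.vllm namespace, that prefix has the
"model.layers.N...." shape the impls parse, and that the projections are plain
callables (in real vLLM they return (output, bias) tuples).

    import torch

    def decoder_layer_forward(layer, x, residual, prefix):
        # Re-chunk residual when flashcomm_v1 has padded/sharded the tokens
        # (the fake impl simply returns residual during tracing).
        residual = torch.ops.vllm.maybe_chunk_residual(x, residual)

        # After attention produces its output, arm the gate_up_proj prefetch;
        # attn_out only carries the data dependency that orders the side stream.
        attn_out = layer.self_attn(x)
        torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(
            attn_out, f"{prefix}.self_attn.o_proj")

        # Drain the side stream before the MLP touches the prefetched weight.
        torch.ops.vllm.maybe_wait_prefetch_done(attn_out)
        h = layer.mlp.act_fn(layer.mlp.gate_up_proj(attn_out))

        # Same pattern for the down_proj weight.
        torch.ops.vllm.maybe_prefetch_mlp_down_proj(h)
        torch.ops.vllm.maybe_wait_prefetch_done(h)
        return layer.mlp.down_proj(h) + residual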