@@ -1,17 +1,24 @@
 import torch
 import torch.nn.functional as F
+import torch_npu
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce,
                               tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
 from vllm.utils import direct_register_custom_op
+import vllm_ascend.envs as envs_ascend
 
 
 def _maybe_chunk_residual_impl(x: torch.Tensor,
                                residual: torch.Tensor) -> torch.Tensor:
-    if get_forward_context().flashcomm_v1_enabled:
+    if x.size(0) != residual.size(0):
+        flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+        assert flashcomm_v1_enabled is True, (
+            "Currently, this situation only occurs "
+            "when flashcomm_v1 is enabled"
+        )
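+        # flashcomm_v1 pads the flattened tokens so they split evenly across
+        # TP ranks; pad the residual the same way before chunking it to match x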
         pad_size = get_forward_context().pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -44,6 +51,75 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
     return tensor_model_parallel_all_reduce(x)
 
 
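+# MLP weight prefetch: a side stream warms upcoming MLP weights in cache while
+# compute is still running on the default stream; the x_dependency tensor
+# orders each prefetch after the op that produces x.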
+def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
+                                          prefix: str) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    prefetch_model = forward_context.prefetch_model
+    prefetch_stream = forward_context.prefetch_stream
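+    # prefix is a module path such as "model.layers.<idx>.self_attn.o_proj",
+    # so its third field is the decoder-layer index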
+    layer_idx = int(prefix.split('.')[2])
+
+    # start point of gate_up_proj weight prefetch
+    if prefix.split('.')[-2] == "self_attn":
+        forward_context.prefetch_mlp_gate_up_proj = True
+    if forward_context.prefetch_mlp_gate_up_proj:
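+        # order the side stream after work already queued on the compute
+        # stream before issuing the prefetch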
+        prefetch_stream.wait_stream(torch.npu.current_stream())
+
+        with torch.npu.stream(prefetch_stream):
+            MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
+            torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight,
+                                   x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
+    return
+
+
+def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
+                                               prefix: str) -> None:
+    return
+
+
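+# Same pattern for down_proj; this op receives no module prefix, so the layer
+# index is a running counter kept on the forward context.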
+def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    forward_context.prefetch_mlp_down_proj = True
+    prefetch_model = forward_context.prefetch_model
+    prefetch_stream = forward_context.prefetch_stream
+    layer_idx = forward_context.layer_idx
+
+    # start point of down_proj weight prefetch
+    prefetch_stream.wait_stream(torch.npu.current_stream())
+
+    with torch.npu.stream(prefetch_stream):
+        MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
+        torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight,
+                               x_dependency, MLP_DOWN_PREFETCH_SIZE)
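+    # advance the counter so the next call prefetches the next decoder layer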
+    forward_context.layer_idx += 1
+    return
+
+
+def _maybe_prefetch_mlp_down_proj_impl_fake(x_dependency: torch.Tensor) -> None:
+    return
+
+
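+# Join point: runs before a prefetched weight is consumed, making the compute
+# stream wait for the side stream, then clears both prefetch flags.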
+def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    if forward_context.prefetch_mlp_gate_up_proj or \
+            forward_context.prefetch_mlp_down_proj:
+        prefetch_stream = get_forward_context().prefetch_stream
+        # wait until prefetch done
+        torch.npu.current_stream().wait_stream(prefetch_stream)
+        forward_context.prefetch_mlp_gate_up_proj = False
+        forward_context.prefetch_mlp_down_proj = False
+    return
+
+
+def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
+    return
+
+
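+# The fake impls are tracing-time stand-ins (no stream work) used when these
+# ops are traced; the real impls bind to the NPU's PrivateUse1 dispatch key.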
 direct_register_custom_op(op_name="maybe_chunk_residual",
                           op_func=_maybe_chunk_residual_impl,
                           fake_impl=lambda x, residual: residual,
@@ -60,4 +136,25 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
                           op_func=_maybe_pad_and_reduce_impl,
                           fake_impl=lambda x: x,
                           mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+
+direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
+                          op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
+                          fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+
+direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
+                          op_func=_maybe_prefetch_mlp_down_proj_impl,
+                          fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+
+direct_register_custom_op(op_name="maybe_wait_prefetch_done",
+                          op_func=_maybe_wait_prefetch_done_impl,
+                          fake_impl=_maybe_wait_prefetch_done_impl_fake,
+                          mutates_args=[],
                           dispatch_key="PrivateUse1")