Commit 4312a92

[feat]dcp pcp support aclgraph (#3731)
### What this PR does / why we need it?

DCP and PCP now support full aclgraph mode, including MLA attention_v1.

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@c9461e0

Signed-off-by: weiguihua2 <[email protected]>
1 parent 8ab8111 commit 4312a92

File tree: 5 files changed, +415 -69 lines

tests/ut/attention/test_mla_v1.py

Lines changed: 46 additions & 3 deletions
@@ -176,16 +176,30 @@ def test_ascend_mla_metadata_default(self):
 
 class TestAscendMLAMetadataBuilder(TestBase):
 
-    def test_ascend_mla_metadata_builder_default(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
+                                                 mock_dcp, mock_get_dcp_group):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
+
         mock_vllm_config.speculative_config = None
 
         ascend_config = MagicMock()
@@ -200,16 +214,31 @@ def test_ascend_mla_metadata_builder_default(self):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.chunked_prefill_enabled)
 
-    def test_ascend_mla_metadata_builder_spec_decode(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
+                                                     mock_dcp,
+                                                     mock_get_dcp_group):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
+
         mock_spec_config = MagicMock()
         mock_spec_config.num_speculative_tokens = 3
         mock_vllm_config.speculative_config = mock_spec_config
@@ -226,16 +255,30 @@ def test_ascend_mla_metadata_builder_spec_decode(self):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.chunked_prefill_enabled)
 
-    def test_reorder_batch(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
+                           mock_get_dcp_group):
         ascend_config = MagicMock()
 
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
+
         mock_vllm_config.speculative_config = None
 
         with patch("vllm_ascend.attention.mla_v1.get_ascend_config",

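The rewritten tests above stand in for the decode-context-parallel (DCP) group with `@patch` decorators and spec-constrained `MagicMock` objects. The snippet below is a minimal, self-contained sketch of that mocking pattern only; `FakeGroupCoordinator` and `build_dcp_group_mock` are hypothetical names for illustration, not part of vLLM or this PR.

```python
# Sketch of the spec-constrained mocking pattern used in the tests above.
# FakeGroupCoordinator is a hypothetical stand-in for the attributes the
# metadata builder reads from vllm's GroupCoordinator.
from unittest.mock import MagicMock


class FakeGroupCoordinator:
    """Hypothetical stand-in declaring the attributes the builder reads."""
    rank_in_group = 0
    world_size = 1
    device_group = None


def build_dcp_group_mock(world_size: int = 1) -> MagicMock:
    # spec= limits attribute *reads* to names declared on the spec class,
    # so the mock cannot silently absorb a mistyped attribute lookup.
    dcp_group = MagicMock(spec=FakeGroupCoordinator)
    dcp_group.rank_in_group = 0
    dcp_group.world_size = world_size
    dcp_group.device_group = MagicMock()
    return dcp_group


if __name__ == "__main__":
    group = build_dcp_group_mock()
    assert group.world_size == 1 and group.rank_in_group == 0
    try:
        _ = group.not_an_attribute
    except AttributeError:
        pass  # reads outside the spec are rejected
```

Using `spec=` keeps the fake group honest about its interface: reading an attribute the real coordinator does not expose raises `AttributeError` instead of returning another mock.
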
vllm_ascend/attention/attention_v1.py

Lines changed: 74 additions & 19 deletions
@@ -865,26 +865,81 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
         num_heads = self.num_heads
 
         # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score(
-            query.view(query.shape[0], 1, query.shape[1], query.shape[2]),
-            # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
-            self.key_cache.view(self.key_cache.shape[0],
-                                self.key_cache.shape[1], -1),
-            self.value_cache.view(self.key_cache.shape[0],
-                                  self.key_cache.shape[1], -1),
-            num_heads=num_heads,
-            num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
-            atten_mask=None,
-            scale=self.scale,
-            antiquant_mode=0,
-            antiquant_scale=None,
-            softmax_lse_flag=True,
-            block_table=attn_metadata.block_tables,
-            block_size=self.key_cache.shape[1],
-            actual_seq_lengths_kv=attn_metadata.decode_meta.
+        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
+        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
+        k_nope = self.key_cache.view(self.key_cache.shape[0],
+                                     self.key_cache.shape[1], -1)
+        value = self.value_cache.view(self.key_cache.shape[0],
+                                      self.key_cache.shape[1], -1)
+        common_kwargs = {
+            'num_heads':
+            num_heads,
+            'num_key_value_heads':
+            self.num_kv_heads,
+            'input_layout':
+            "BSND",
+            'atten_mask':
+            None,
+            'scale':
+            self.scale,
+            'antiquant_mode':
+            0,
+            'antiquant_scale':
+            None,
+            'softmax_lse_flag':
+            True,
+            'block_table':
+            attn_metadata.block_tables,
+            'block_size':
+            self.key_cache.shape[1],
+            "actual_seq_lengths_kv":
+            attn_metadata.decode_meta.
             num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
-        )
+        }
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        num_tokens = query.shape[0]
+        if forward_context.capturing:
+            stream = torch_npu.npu.current_stream()
+
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    q_nope, k_nope, value, **common_kwargs)
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
+            attn_out = torch.empty_like(q_nope)
+            attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
+                                   dtype=torch.float,
+                                   device=q_nope.device)
+
+            graph_params.attn_params[num_tokens].append(
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
+                 self.scale, attn_metadata.block_tables,
+                 self.key_cache.shape[1], attn_metadata.decode_meta.
+                 num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
+                                                self.dcp_rank],
+                 weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
+                 self.pcp_rank, self.dcp_rank, self.dcp_size))
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                value,
+                **common_kwargs,
+                workspace=workspace,
+                out=[attn_out, attn_lse])
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        else:
+            attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
+                q_nope, k_nope, value, **common_kwargs)
 
         attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
                                  attn_out.shape[3])

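At a high level, the new `_forward_decode_pcp_dcp` body splits into a capture branch and an eager branch: while an aclgraph is being captured it reuses one workspace per token count, records the launch arguments and a task-group handle for later replay, and otherwise it simply calls the fused kernel. The sketch below is a device-agnostic simplification of that bookkeeping, runnable without an NPU; `GraphParams`, `get_graph_params`, `fused_attention`, and `forward_decode` here are hypothetical stand-ins, not the vllm_ascend or torch_npu implementations.

```python
# Simplified, device-agnostic sketch of the capture-vs-eager dispatch.
# All names below are illustrative stand-ins for the real graph plumbing.
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class GraphParams:
    # Bookkeeping is keyed by the number of decode tokens, because one
    # graph (and one workspace) exists per captured batch size.
    events: Dict[int, List[Any]] = field(default_factory=lambda: defaultdict(list))
    workspaces: Dict[int, Any] = field(default_factory=dict)
    attn_params: Dict[int, List[Any]] = field(default_factory=lambda: defaultdict(list))
    handles: Dict[int, List[Any]] = field(default_factory=lambda: defaultdict(list))


_GRAPH_PARAMS = GraphParams()


def get_graph_params() -> GraphParams:
    return _GRAPH_PARAMS


def fused_attention(query, key, value):
    """Placeholder for the fused attention kernel; returns (out, lse)."""
    return query, 0.0


def forward_decode(query, key, value, num_tokens: int, capturing: bool):
    graph_params = get_graph_params()
    if capturing:
        # Reuse one workspace per batch size so graph replay never allocates.
        workspace = graph_params.workspaces.setdefault(num_tokens, object())
        # Record launch arguments so they can be updated before each replay.
        graph_params.attn_params[num_tokens].append((query, key, value, workspace))
        out, lse = fused_attention(query, key, value)
        # The real code stores the graph task-group handle here.
        graph_params.handles[num_tokens].append("task-group-handle")
        return out, lse
    # Eager path: call the kernel directly.
    return fused_attention(query, key, value)


if __name__ == "__main__":
    out, _ = forward_decode("q", "k", "v", num_tokens=4, capturing=True)
    assert set(get_graph_params().workspaces) == {4}
```

Keying every structure by `num_tokens` mirrors the diff above, where the workspace, events, recorded attention parameters, and task-group handles are all tracked per captured token count.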