Commit ecfb820

Author: wangxiaoxin-sherie (committed)

optimize upstream in fullgraph
Signed-off-by: wangxiaoxin-sherie <[email protected]>
1 parent 789ba4c commit ecfb820

File tree

1 file changed: +92, -89 lines

vllm_ascend/compilation/acl_graph.py

Lines changed: 92 additions & 89 deletions
@@ -196,27 +196,27 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-        torch_npu_check = version_check()
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+            torch_npu_check = version_check()
+
             torch.npu.graph_task_update_begin(update_stream, handle)
             if torch_npu_check:
                 torch_npu._npu_paged_attention(
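
Note: the same restructuring is applied in the three hunks below as well. The `with torch.npu.stream(update_stream):` context is hoisted out of the per-layer loop, so the update stream is activated once per call instead of once per attention layer, while each layer's graph task update still runs on the update stream. A minimal, hardware-independent sketch of the idea, using a hypothetical counting context manager as a stand-in for `torch.npu.stream(update_stream)` and a plain range as a stand-in for the zipped per-layer parameters:

from contextlib import contextmanager

enters = 0

@contextmanager
def stream_ctx():
    # Hypothetical stand-in for torch.npu.stream(update_stream):
    # count how often the context (i.e. the stream switch) is entered.
    global enters
    enters += 1
    yield

layers = range(48)  # stand-in for the zipped per-layer (key, param, handle, event) tuples

# Before this commit: the context is entered inside the loop, once per layer.
enters = 0
for _layer in layers:
    with stream_ctx():
        pass  # update this layer's graph task args
print(enters)  # 48

# After this commit: the loop runs inside a single context.
enters = 0
with stream_ctx():
    for _layer in layers:
        pass  # update this layer's graph task args
print(enters)  # 1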
@@ -250,30 +250,32 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
-         spec_attn_mask, sparse_mode, scale, block_table, block_size,
-         seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
-        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            actual_seq_lengths = forward_context.attn_metadata[
-                key].decode.actual_seq_lengths_q
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_lens_list = seq_lens_list + [0] * (
-                runtime_shape // spec_multiple - len(seq_lens_list))
-            actual_seq_lengths = [
-                spec_multiple * (i + 1)
-                for i in range(runtime_shape // spec_multiple)
-            ]
-        else:
-            seq_lens_list = seq_lens_list + [0] * (runtime_shape -
-                                                   len(seq_lens_list))
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+             spec_attn_mask, sparse_mode, scale, block_table, block_size,
+             seq_lens_list, actual_seq_lengths, attn_output,
+             softmax_lse) = param
+            seq_lens_list = forward_context.attn_metadata[
+                key].decode.seq_lens_list
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                actual_seq_lengths = forward_context.attn_metadata[
+                    key].decode.actual_seq_lengths_q
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_lens_list = seq_lens_list + [0] * (
+                    runtime_shape // spec_multiple - len(seq_lens_list))
+                actual_seq_lengths = [
+                    spec_multiple * (i + 1)
+                    for i in range(runtime_shape // spec_multiple)
+                ]
+            else:
+                seq_lens_list = seq_lens_list + [0] * (runtime_shape -
+                                                       len(seq_lens_list))
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.npu_fused_infer_attention_score.out(
@@ -305,26 +307,27 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
-         block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
-         dcp_rank, dcp_size) = param
-        actual_seq_lengths_kv = forward_context.attn_metadata[
-            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
-                                                            dcp_rank]
-        pad_length = runtime_shape - len(actual_seq_lengths_kv)
-        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
-        actual_seq_lengths_kv = np.concatenate(
-            [actual_seq_lengths_kv, pad_tensor])
-        if dcp_size > 1:
-            num_heads = num_heads * dcp_size
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
+             block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
+             dcp_rank, dcp_size) = param
+            actual_seq_lengths_kv = forward_context.attn_metadata[
+                key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+                                                                dcp_rank]
+            pad_length = runtime_shape - len(actual_seq_lengths_kv)
+            pad_tensor = np.zeros(pad_length,
+                                  dtype=actual_seq_lengths_kv.dtype)
+            actual_seq_lengths_kv = np.concatenate(
+                [actual_seq_lengths_kv, pad_tensor])
+            if dcp_size > 1:
+                num_heads = num_heads * dcp_size
+
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.npu_fused_infer_attention_score.out(
@@ -354,30 +357,30 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
-         num_kv_heads, attn_output, softmax_lse) = param
-
-        decode_meta = forward_context.attn_metadata[key].decode
-        seq_len = decode_meta.cp_seq_len
-
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
-                                       len(seq_len))
-        else:
-            pad_length = runtime_shape - len(seq_len)
-            pad_tensor = torch.zeros(pad_length,
-                                     dtype=seq_len.dtype,
-                                     device=seq_len.device)
-            seq_len = torch.cat([seq_len, pad_tensor], dim=0)
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
+             scale, num_kv_heads, attn_output, softmax_lse) = param
+
+            decode_meta = forward_context.attn_metadata[key].decode
+            seq_len = decode_meta.cp_seq_len
+
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
+                                           len(seq_len))
+            else:
+                pad_length = runtime_shape - len(seq_len)
+                pad_tensor = torch.zeros(pad_length,
+                                         dtype=seq_len.dtype,
+                                         device=seq_len.device)
+                seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.atb.npu_multi_head_latent_attention(
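
Note: the deepseek_mtp branch in update_mla_attn_params pads seq_lens_list to runtime_shape // spec_multiple request slots and rebuilds actual_seq_lengths as multiples of spec_multiple (plausibly the cumulative query token count per request slot); update_mla_attn_dcp_pcp_params pads seq_len the same way. A self-contained sketch of that arithmetic, with made-up values that are illustrative only and not taken from the code:

# Illustrative numbers only; the logic mirrors the deepseek_mtp branch above.
num_speculative_tokens = 1
spec_multiple = num_speculative_tokens + 1      # tokens per request slot (1 + draft tokens)
runtime_shape = 8                               # padded token count captured by the graph
seq_lens_list = [5, 7, 3]                       # per-request KV lengths for this step

# Pad the per-request lengths out to the number of request slots in the graph.
seq_lens_list = seq_lens_list + [0] * (
    runtime_shape // spec_multiple - len(seq_lens_list))
assert seq_lens_list == [5, 7, 3, 0]

# Query lengths advance by spec_multiple per request slot.
actual_seq_lengths = [
    spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)
]
assert actual_seq_lengths == [2, 4, 6, 8]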
