
Commit ed9d5d7

Author: wangxiaoxin-sherie (committed)

optimize upstream in fullgraph
Signed-off-by: wangxiaoxin-sherie <[email protected]>
1 parent fcc9a0e commit ed9d5d7
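
All four update helpers below get the same restructuring: the with torch.npu.stream(update_stream) context manager is hoisted out of the per-layer for loop, so the update stream is entered once per call and the per-layer argument fix-up (sequence-length padding and the like) now also runs under that stream, instead of re-entering the stream context for every layer. A minimal sketch of the before/after shape, assuming an Ascend build where torch_npu provides torch.npu.stream; prepare_args and launch_update are hypothetical stand-ins for the per-layer work, not functions from acl_graph.py:

import torch
import torch_npu  # noqa: F401  (patches the torch.npu namespace on Ascend builds)


def prepare_args(layer):
    # Hypothetical per-layer metadata fix-up (padding seq lens, etc.).
    return layer


def launch_update(stream, args):
    # Hypothetical graph-task update launch for one layer.
    pass


def update_params_old(update_stream, layers):
    # Before: the update-stream context was re-entered for every layer.
    for layer in layers:
        args = prepare_args(layer)
        with torch.npu.stream(update_stream):
            launch_update(update_stream, args)


def update_params_new(update_stream, layers):
    # After: enter the update stream once; both the per-layer preparation
    # and the launch happen inside it.
    with torch.npu.stream(update_stream):
        for layer in layers:
            args = prepare_args(layer)
            launch_update(update_stream, args)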

File tree: 1 file changed, +90 -88 lines


vllm_ascend/compilation/acl_graph.py

Lines changed: 90 additions & 88 deletions
@@ -194,26 +194,25 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            seq_lens = forward_context.attn_metadata[key].seq_lens
             torch.npu.graph_task_update_begin(update_stream, handle)
             torch_npu._npu_paged_attention(
                 query=query,
@@ -236,30 +235,32 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
-         spec_attn_mask, sparse_mode, scale, block_table, block_size,
-         seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
-        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            actual_seq_lengths = forward_context.attn_metadata[
-                key].decode.actual_seq_lengths_q
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_lens_list = seq_lens_list + [0] * (
-                runtime_shape // spec_multiple - len(seq_lens_list))
-            actual_seq_lengths = [
-                spec_multiple * (i + 1)
-                for i in range(runtime_shape // spec_multiple)
-            ]
-        else:
-            seq_lens_list = seq_lens_list + [0] * (runtime_shape -
-                                                   len(seq_lens_list))
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+             spec_attn_mask, sparse_mode, scale, block_table, block_size,
+             seq_lens_list, actual_seq_lengths, attn_output,
+             softmax_lse) = param
+            seq_lens_list = forward_context.attn_metadata[
+                key].decode.seq_lens_list
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                actual_seq_lengths = forward_context.attn_metadata[
+                    key].decode.actual_seq_lengths_q
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_lens_list = seq_lens_list + [0] * (
+                    runtime_shape // spec_multiple - len(seq_lens_list))
+                actual_seq_lengths = [
+                    spec_multiple * (i + 1)
+                    for i in range(runtime_shape // spec_multiple)
+                ]
+            else:
+                seq_lens_list = seq_lens_list + [0] * (runtime_shape -
+                                                       len(seq_lens_list))
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.npu_fused_infer_attention_score.out(
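
For reference, the deepseek_mtp branch above pads the per-request metadata to the captured graph shape. A small standalone illustration of that arithmetic, using made-up numbers (runtime_shape of 8, one speculative token per request, three live requests):

# Illustration only; the values are invented, the formulas mirror the hunk above.
runtime_shape = 8                # captured ACL graph batch shape
num_speculative_tokens = 1       # from speculative_config
spec_multiple = num_speculative_tokens + 1   # tokens per request slot

seq_lens_list = [5, 9, 3]        # live requests' decode seq lens
# Zero-pad so the list covers every request slot in the captured graph.
seq_lens_list = seq_lens_list + [0] * (
    runtime_shape // spec_multiple - len(seq_lens_list))
# Cumulative query lengths: each slot contributes spec_multiple tokens.
actual_seq_lengths = [
    spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)
]
print(seq_lens_list)       # [5, 9, 3, 0]
print(actual_seq_lengths)  # [2, 4, 6, 8]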
@@ -291,26 +292,27 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
-         block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
-         dcp_rank, dcp_size) = param
-        actual_seq_lengths_kv = forward_context.attn_metadata[
-            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
-                                                            dcp_rank]
-        pad_length = runtime_shape - len(actual_seq_lengths_kv)
-        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
-        actual_seq_lengths_kv = np.concatenate(
-            [actual_seq_lengths_kv, pad_tensor])
-        if dcp_size > 1:
-            num_heads = num_heads * dcp_size
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
+             block_table, block_size, actual_seq_lengths_kv, attn_output,
+             softmax_lse, cp_rank, dcp_rank, dcp_size) = param
+            actual_seq_lengths_kv = forward_context.attn_metadata[
+                key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+                                                                dcp_rank]
+            pad_length = runtime_shape - len(actual_seq_lengths_kv)
+            pad_tensor = np.zeros(pad_length,
+                                  dtype=actual_seq_lengths_kv.dtype)
+            actual_seq_lengths_kv = np.concatenate(
+                [actual_seq_lengths_kv, pad_tensor])
+            if dcp_size > 1:
+                num_heads = num_heads * dcp_size
+
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.npu_fused_infer_attention_score.out(
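
The DCP/PCP variant above does the same shape padding with numpy and additionally scales num_heads when decode context parallelism is enabled. A standalone sketch with invented values, mirroring those two steps:

# Illustration only; array contents and sizes are made up.
import numpy as np

runtime_shape = 6
# Per-request KV lengths for one (cp_rank, dcp_rank) pair.
actual_seq_lengths_kv = np.array([17, 4, 23], dtype=np.int32)

# Zero-pad to the captured graph shape so the kernel always sees
# runtime_shape entries.
pad_length = runtime_shape - len(actual_seq_lengths_kv)
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
print(actual_seq_lengths_kv)   # [17  4 23  0  0  0]

# Mirror the num_heads adjustment from the hunk above.
num_heads, dcp_size = 8, 2
if dcp_size > 1:
    num_heads = num_heads * dcp_size
print(num_heads)               # 16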
@@ -340,30 +342,30 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
-         num_kv_heads, attn_output, softmax_lse) = param
-
-        decode_meta = forward_context.attn_metadata[key].decode
-        seq_len = decode_meta.cp_seq_len
-
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
-                                       len(seq_len))
-        else:
-            pad_length = runtime_shape - len(seq_len)
-            pad_tensor = torch.zeros(pad_length,
-                                     dtype=seq_len.dtype,
-                                     device=seq_len.device)
-            seq_len = torch.cat([seq_len, pad_tensor], dim=0)
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
+             scale, num_kv_heads, attn_output, softmax_lse) = param
+
+            decode_meta = forward_context.attn_metadata[key].decode
+            seq_len = decode_meta.cp_seq_len
+
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
+                                           len(seq_len))
+            else:
+                pad_length = runtime_shape - len(seq_len)
+                pad_tensor = torch.zeros(pad_length,
+                                         dtype=seq_len.dtype,
+                                         device=seq_len.device)
+                seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.atb.npu_multi_head_latent_attention(
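
The MLA DCP/PCP helper above pads a tensor rather than a Python list in the non-speculative branch. A standalone sketch with invented values, runnable on CPU, mirroring that padding:

# Illustration only; the lengths are made up, the padding mirrors the hunk above.
import torch

runtime_shape = 5
seq_len = torch.tensor([12, 7, 30], dtype=torch.int32)  # stand-in for decode_meta.cp_seq_len

pad_length = runtime_shape - len(seq_len)
pad_tensor = torch.zeros(pad_length,
                         dtype=seq_len.dtype,
                         device=seq_len.device)
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
print(seq_len)   # tensor([12,  7, 30,  0,  0], dtype=torch.int32)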
