@@ -196,27 +196,27 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-        torch_npu_check = version_check()
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+            torch_npu_check = version_check()
+
             torch.npu.graph_task_update_begin(update_stream, handle)
             if torch_npu_check:
                 torch_npu._npu_paged_attention(
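
The hunk above only hoists the with torch.npu.stream(update_stream): block out of the per-layer loop, so the update stream context is entered once per runtime shape instead of once per layer; the per-layer argument unpacking and the paged-attention update are unchanged. A minimal, framework-free sketch of that restructuring (CountingStream and the two update_layers_* helpers are illustrative stand-ins, not vLLM-Ascend code):

class CountingStream:
    """Stand-in for torch.npu.stream(update_stream); counts context entries."""

    def __init__(self):
        self.enter_count = 0

    def __enter__(self):
        self.enter_count += 1
        return self

    def __exit__(self, exc_type, exc, tb):
        return False


def update_layers_entering_per_layer(stream, layers):
    # Old structure: re-enter the stream context for every layer.
    for _layer in layers:
        with stream:
            pass  # per-layer graph task update would go here


def update_layers_entering_once(stream, layers):
    # New structure: enter the stream context once, update all layers inside.
    with stream:
        for _layer in layers:
            pass  # per-layer graph task update would go here


old_style, new_style = CountingStream(), CountingStream()
layers = [f"layer_{i}" for i in range(32)]
update_layers_entering_per_layer(old_style, layers)
update_layers_entering_once(new_style, layers)
print(old_style.enter_count, new_style.enter_count)  # 32 vs. 1
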
@@ -250,30 +250,32 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
-         spec_attn_mask, sparse_mode, scale, block_table, block_size,
-         seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
-        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            actual_seq_lengths = forward_context.attn_metadata[
-                key].decode.actual_seq_lengths_q
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_lens_list = seq_lens_list + [0] * (
-                runtime_shape // spec_multiple - len(seq_lens_list))
-            actual_seq_lengths = [
-                spec_multiple * (i + 1)
-                for i in range(runtime_shape // spec_multiple)
-            ]
-        else:
-            seq_lens_list = seq_lens_list + [0] * (runtime_shape -
-                                                   len(seq_lens_list))
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+             spec_attn_mask, sparse_mode, scale, block_table, block_size,
+             seq_lens_list, actual_seq_lengths, attn_output,
+             softmax_lse) = param
+            seq_lens_list = forward_context.attn_metadata[
+                key].decode.seq_lens_list
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                actual_seq_lengths = forward_context.attn_metadata[
+                    key].decode.actual_seq_lengths_q
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_lens_list = seq_lens_list + [0] * (
+                    runtime_shape // spec_multiple - len(seq_lens_list))
+                actual_seq_lengths = [
+                    spec_multiple * (i + 1)
+                    for i in range(runtime_shape // spec_multiple)
+                ]
+            else:
+                seq_lens_list = seq_lens_list + [0] * (runtime_shape -
+                                                       len(seq_lens_list))
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.npu_fused_infer_attention_score.out(
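
For reference, the MTP-aware padding that this hunk re-indents can be read as a small standalone helper. The sketch below mirrors the arithmetic from update_mla_attn_params; pad_mla_decode_lengths is a hypothetical name, not part of the codebase:

def pad_mla_decode_lengths(seq_lens_list, runtime_shape,
                           num_speculative_tokens=None):
    """Pad per-request decode lengths up to the captured graph's batch shape.

    With deepseek_mtp speculation, each request occupies
    (num_speculative_tokens + 1) query slots, so the padded list only needs
    runtime_shape // spec_multiple entries, and actual_seq_lengths becomes the
    cumulative query offsets [spec_multiple, 2 * spec_multiple, ...].
    """
    if num_speculative_tokens is not None:
        spec_multiple = num_speculative_tokens + 1
        padded = seq_lens_list + [0] * (runtime_shape // spec_multiple -
                                        len(seq_lens_list))
        actual_seq_lengths = [
            spec_multiple * (i + 1)
            for i in range(runtime_shape // spec_multiple)
        ]
        return padded, actual_seq_lengths
    return seq_lens_list + [0] * (runtime_shape - len(seq_lens_list)), None


# Example: 3 live requests, graph captured for shape 8, MTP with 1 draft token.
print(pad_mla_decode_lengths([5, 9, 2], 8, num_speculative_tokens=1))
# -> ([5, 9, 2, 0], [2, 4, 6, 8])
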
@@ -305,26 +307,27 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
-         block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
-         dcp_rank, dcp_size) = param
-        actual_seq_lengths_kv = forward_context.attn_metadata[
-            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
-                                                            dcp_rank]
-        pad_length = runtime_shape - len(actual_seq_lengths_kv)
-        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
-        actual_seq_lengths_kv = np.concatenate(
-            [actual_seq_lengths_kv, pad_tensor])
-        if dcp_size > 1:
-            num_heads = num_heads * dcp_size
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
+             block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
+             dcp_rank, dcp_size) = param
+            actual_seq_lengths_kv = forward_context.attn_metadata[
+                key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+                                                                dcp_rank]
+            pad_length = runtime_shape - len(actual_seq_lengths_kv)
+            pad_tensor = np.zeros(pad_length,
+                                  dtype=actual_seq_lengths_kv.dtype)
+            actual_seq_lengths_kv = np.concatenate(
+                [actual_seq_lengths_kv, pad_tensor])
+            if dcp_size > 1:
+                num_heads = num_heads * dcp_size
+
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.npu_fused_infer_attention_score.out(
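
The same pattern applies here with numpy arrays: before the graph task update, the per-rank KV lengths are padded up to the captured shape and num_heads is scaled by dcp_size when decode context parallelism is enabled, exactly as the hunk shows. A standalone sketch of those two adjustments (pad_kv_lengths is an illustrative helper, not vLLM-Ascend API):

import numpy as np


def pad_kv_lengths(actual_seq_lengths_kv, runtime_shape, num_heads, dcp_size):
    # Zero-pad the per-request KV lengths out to the graph's captured batch size.
    pad_length = runtime_shape - len(actual_seq_lengths_kv)
    pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
    actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
    # Scale the head count by dcp_size when decode context parallelism is on,
    # mirroring the branch in update_attn_dcp_pcp_params.
    if dcp_size > 1:
        num_heads = num_heads * dcp_size
    return actual_seq_lengths_kv, num_heads


lengths, heads = pad_kv_lengths(np.array([17, 4, 33], dtype=np.int32),
                                runtime_shape=8, num_heads=16, dcp_size=2)
print(lengths.tolist(), heads)  # [17, 4, 33, 0, 0, 0, 0, 0] 32
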
@@ -354,30 +357,30 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
-         num_kv_heads, attn_output, softmax_lse) = param
-
-        decode_meta = forward_context.attn_metadata[key].decode
-        seq_len = decode_meta.cp_seq_len
-
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
-                                       len(seq_len))
-        else:
-            pad_length = runtime_shape - len(seq_len)
-            pad_tensor = torch.zeros(pad_length,
-                                     dtype=seq_len.dtype,
-                                     device=seq_len.device)
-            seq_len = torch.cat([seq_len, pad_tensor], dim=0)
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
+             scale, num_kv_heads, attn_output, softmax_lse) = param
+
+            decode_meta = forward_context.attn_metadata[key].decode
+            seq_len = decode_meta.cp_seq_len
+
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
+                                           len(seq_len))
+            else:
+                pad_length = runtime_shape - len(seq_len)
+                pad_tensor = torch.zeros(pad_length,
+                                         dtype=seq_len.dtype,
+                                         device=seq_len.device)
+                seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.atb.npu_multi_head_latent_attention(
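
The MLA DCP/PCP variant pads cp_seq_len with torch instead of numpy: the MTP branch extends what it treats as a Python list, while the non-MTP branch pads a tensor with torch.cat, as the hunk above shows. A standalone sketch of that branch logic (pad_cp_seq_len is an illustrative name only):

import torch


def pad_cp_seq_len(seq_len, runtime_shape, num_speculative_tokens=None):
    if num_speculative_tokens is not None:
        # MTP case: each request spans (num_speculative_tokens + 1) query slots,
        # so only runtime_shape // spec_multiple length entries are needed.
        spec_multiple = num_speculative_tokens + 1
        return seq_len + [0] * (runtime_shape // spec_multiple - len(seq_len))
    # Non-MTP case: zero-pad the length tensor up to the captured batch shape.
    pad_length = runtime_shape - len(seq_len)
    pad_tensor = torch.zeros(pad_length,
                             dtype=seq_len.dtype,
                             device=seq_len.device)
    return torch.cat([seq_len, pad_tensor], dim=0)


print(pad_cp_seq_len(torch.tensor([7, 12, 3]), runtime_shape=8).tolist())
# -> [7, 12, 3, 0, 0, 0, 0, 0]
print(pad_cp_seq_len([7, 12], runtime_shape=8, num_speculative_tokens=1))
# -> [7, 12, 0, 0]
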