@@ -194,26 +194,25 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            seq_lens = forward_context.attn_metadata[key].seq_lens
             torch.npu.graph_task_update_begin(update_stream, handle)
             torch_npu._npu_paged_attention(
                 query=query,
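The hunk above hoists the `torch.npu.stream(update_stream)` guard out of the per-layer loop, so the update stream context is entered once per captured shape instead of once per attention layer. A minimal stand-alone sketch of that control-flow change, using a dummy context manager in place of the NPU stream guard (nothing below is taken from the PR):

```python
from contextlib import contextmanager

@contextmanager
def stream_guard():
    # Stand-in for `with torch.npu.stream(update_stream):`; counts entries.
    stream_guard.entries += 1
    yield

stream_guard.entries = 0
layers = ["layer_0", "layer_1", "layer_2"]

# Old shape: the guard sat inside the loop body, one entry per layer.
for _ in layers:
    with stream_guard():
        pass
print(stream_guard.entries)  # 3

# New shape (this PR): one guard wraps the whole per-layer update loop.
stream_guard.entries = 0
with stream_guard():
    for _ in layers:
        pass
print(stream_guard.entries)  # 1
```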
@@ -236,30 +235,32 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
-         spec_attn_mask, sparse_mode, scale, block_table, block_size,
-         seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
-        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            actual_seq_lengths = forward_context.attn_metadata[
-                key].decode.actual_seq_lengths_q
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_lens_list = seq_lens_list + [0] * (
-                runtime_shape // spec_multiple - len(seq_lens_list))
-            actual_seq_lengths = [
-                spec_multiple * (i + 1)
-                for i in range(runtime_shape // spec_multiple)
-            ]
-        else:
-            seq_lens_list = seq_lens_list + [0] * (runtime_shape -
-                                                   len(seq_lens_list))
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+             spec_attn_mask, sparse_mode, scale, block_table, block_size,
+             seq_lens_list, actual_seq_lengths, attn_output,
+             softmax_lse) = param
+            seq_lens_list = forward_context.attn_metadata[
+                key].decode.seq_lens_list
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                actual_seq_lengths = forward_context.attn_metadata[
+                    key].decode.actual_seq_lengths_q
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_lens_list = seq_lens_list + [0] * (
+                    runtime_shape // spec_multiple - len(seq_lens_list))
+                actual_seq_lengths = [
+                    spec_multiple * (i + 1)
+                    for i in range(runtime_shape // spec_multiple)
+                ]
+            else:
+                seq_lens_list = seq_lens_list + [0] * (runtime_shape -
+                                                       len(seq_lens_list))
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.npu_fused_infer_attention_score.out(
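In the `deepseek_mtp` branch above, the per-request KV lengths are zero-padded to the number of request slots in the captured graph (`runtime_shape // spec_multiple`), and the cumulative query lengths become fixed multiples of `spec_multiple`. A small worked example of that arithmetic with made-up numbers (not taken from the PR):

```python
runtime_shape = 8              # padded decode token count captured in the graph
num_speculative_tokens = 1     # hypothetical deepseek_mtp draft length
spec_multiple = num_speculative_tokens + 1   # query positions per request

seq_lens_list = [5, 9, 3]      # real per-request KV lengths this step
num_slots = runtime_shape // spec_multiple   # request slots in the graph

# Pad the KV lengths with zeros so every graph slot has an entry.
seq_lens_list = seq_lens_list + [0] * (num_slots - len(seq_lens_list))
# Cumulative query lengths: each slot contributes spec_multiple positions.
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_slots)]

print(seq_lens_list)       # [5, 9, 3, 0]
print(actual_seq_lengths)  # [2, 4, 6, 8]
```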
@@ -291,26 +292,27 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
-         block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
-         dcp_rank, dcp_size) = param
-        actual_seq_lengths_kv = forward_context.attn_metadata[
-            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
-                                                            dcp_rank]
-        pad_length = runtime_shape - len(actual_seq_lengths_kv)
-        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
-        actual_seq_lengths_kv = np.concatenate(
-            [actual_seq_lengths_kv, pad_tensor])
-        if dcp_size > 1:
-            num_heads = num_heads * dcp_size
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
+             block_table, block_size, actual_seq_lengths_kv, attn_output,
+             softmax_lse, cp_rank, dcp_rank, dcp_size) = param
+            actual_seq_lengths_kv = forward_context.attn_metadata[
+                key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+                                                                dcp_rank]
+            pad_length = runtime_shape - len(actual_seq_lengths_kv)
+            pad_tensor = np.zeros(pad_length,
+                                  dtype=actual_seq_lengths_kv.dtype)
+            actual_seq_lengths_kv = np.concatenate(
+                [actual_seq_lengths_kv, pad_tensor])
+            if dcp_size > 1:
+                num_heads = num_heads * dcp_size
+
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.npu_fused_infer_attention_score.out(
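The per-rank KV lengths sliced from `num_computed_tokens_of_pcp_dcp` above are zero-padded to `runtime_shape` with NumPy, and the head count passed to the kernel is scaled by `dcp_size` when decode context parallelism is enabled. A small stand-alone illustration with invented values (not taken from the PR):

```python
import numpy as np

runtime_shape = 6
# Hypothetical stand-in for num_computed_tokens_of_pcp_dcp[:, cp_rank, dcp_rank].
actual_seq_lengths_kv = np.array([17, 4, 9], dtype=np.int32)

pad_length = runtime_shape - len(actual_seq_lengths_kv)
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])

num_heads, dcp_size = 8, 2
if dcp_size > 1:
    # Mirror the hunk above: scale the per-rank head count by the DCP size.
    num_heads = num_heads * dcp_size

print(actual_seq_lengths_kv.tolist(), num_heads)  # [17, 4, 9, 0, 0, 0] 16
```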
@@ -340,30 +342,30 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
-         num_kv_heads, attn_output, softmax_lse) = param
-
-        decode_meta = forward_context.attn_metadata[key].decode
-        seq_len = decode_meta.cp_seq_len
-
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
-                                       len(seq_len))
-        else:
-            pad_length = runtime_shape - len(seq_len)
-            pad_tensor = torch.zeros(pad_length,
-                                     dtype=seq_len.dtype,
-                                     device=seq_len.device)
-            seq_len = torch.cat([seq_len, pad_tensor], dim=0)
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
+             scale, num_kv_heads, attn_output, softmax_lse) = param
+
+            decode_meta = forward_context.attn_metadata[key].decode
+            seq_len = decode_meta.cp_seq_len
+
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
+                                           len(seq_len))
+            else:
+                pad_length = runtime_shape - len(seq_len)
+                pad_tensor = torch.zeros(pad_length,
+                                         dtype=seq_len.dtype,
+                                         device=seq_len.device)
+                seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.atb.npu_multi_head_latent_attention(
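In the non-MTP branch of the hunk above, `cp_seq_len` is a tensor, so the padding is done with `torch.zeros` and `torch.cat` rather than list concatenation. A CPU-only sketch of that padding step with made-up values (not taken from the PR):

```python
import torch

runtime_shape = 5
# Hypothetical stand-in for decode_meta.cp_seq_len.
seq_len = torch.tensor([12, 7, 3], dtype=torch.int32)

pad_length = runtime_shape - len(seq_len)
pad_tensor = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
seq_len = torch.cat([seq_len, pad_tensor], dim=0)

print(seq_len.tolist())  # [12, 7, 3, 0, 0]
```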