@@ -94,7 +94,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
-    actual_seq_q_lens: Optional[list[int]] = None
+    actual_seq_lengths_q: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
@@ -131,7 +131,6 @@ class AscendMLAMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.

     enable_dbo_across_dp: bool = False
-    is_mtp_model: bool = False

     query_lens: Optional[list[int]] = None
     # The dimension of the attention heads
@@ -313,7 +312,6 @@ def build_torchair_graph_dummy(
         self,
         num_reqs: int,
         num_actual_tokens: int,
-        is_mtp_model: bool = False,
     ) -> AscendMLAMetadata:
         device = self.runner.device
         _, max_blocks = self.runner.graph_block_tables.shape
@@ -337,7 +335,7 @@ def build_torchair_graph_dummy(
                                 dtype=torch.int32,
                                 device=device)
         if self.runner.speculative_config is not None and \
-                self.runner.speculative_config.method == 'deepseek_mtp' and not is_mtp_model:
+                self.runner.speculative_config.method == 'deepseek_mtp':
             attn_state = AscendAttentionState.SpecDecoding
             num_decode_tokens = 2
         else:
@@ -362,7 +360,7 @@ def build_torchair_graph_dummy(
             seq_lens_list=seq_lens_list,
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
-            actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
+            actual_seq_lengths_q=self.runner.actual_seq_lengths_q[:num_reqs],
             sin=sin,
             cos=cos)
         return self.metadata_cls(  # type: ignore
@@ -380,7 +378,6 @@ def build_torchair_graph_dummy(
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
             block_tables=block_table,
-            is_mtp_model=is_mtp_model,
         )

     def build(
@@ -393,7 +390,6 @@ def build(
         num_token_pad_size: int = -1,
         num_reqs_pad_size: int = 0,
         enable_dbo_across_dp: bool = False,
-        is_mtp_model: bool = False,
     ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

@@ -498,7 +494,7 @@ def build(
         decode_metadata = None
         use_torchair_graph = num_token_pad_size != -1
         if self._num_decodes > 0:
-            actual_seq_q_lens = query_start_loc[1:].tolist()
+            actual_seq_lengths_q = query_start_loc[1:].tolist()
             max_seq_lens = seq_lens[:self._num_decodes].max().item()
             seq_lens = seq_lens[:self._num_decode_tokens]
             input_positions = input_positions[:self._num_decode_tokens]
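For context on the rename: `actual_seq_lengths_q` is built from `query_start_loc[1:]`, so it stores the cumulative end offsets of each request's query tokens in the packed decode batch rather than per-request lengths, which is the form the speculative-decoding attention call below consumes as `actual_seq_lengths`. A minimal sketch of that relationship, using made-up query lengths purely for illustration:

import torch

# Hypothetical per-request decode query lengths (e.g. 2 tokens per request
# when one MTP draft token is verified alongside the real token).
query_lens = torch.tensor([2, 2, 2], dtype=torch.int32)

# query_start_loc is the prefix sum with a leading zero: [0, 2, 4, 6].
query_start_loc = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(query_lens, dim=0, dtype=torch.int32)
])

# Dropping the leading zero leaves the cumulative end offsets [2, 4, 6],
# which is what the metadata stores as actual_seq_lengths_q.
actual_seq_lengths_q = query_start_loc[1:].tolist()
assert actual_seq_lengths_q == [2, 4, 6]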
@@ -534,16 +530,21 @@ def build(
                     dtype=input_positions.dtype,
                     device=input_positions.device)
                 input_positions = torch.cat([input_positions, padding_0])
-                actual_seq_q_lens = query_start_loc[1:].tolist(
-                ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
+                actual_seq_lengths_q = query_start_loc[1:].tolist(
+                ) + self.runner.actual_seq_lengths_q[num_reqs:num_reqs +
                                                      num_reqs_pad_size]
-                # mtp torchair + PD scenario, last element of actual_seq_q_lens must equal to num_padded_token_size
+                # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to num_padded_token_size
                 num_padded_token_size = slot_mapping.size(0)
-                if actual_seq_q_lens[-1] != num_padded_token_size \
+                if actual_seq_lengths_q[-1] != num_padded_token_size \
                     and self.runner.attn_state == AscendAttentionState.SpecDecoding:
-                    actual_seq_q_lens[-1] = num_padded_token_size
+                    actual_seq_lengths_q[-1] = num_padded_token_size
             else:
                 seq_lens_list = seq_lens.tolist()
+                # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to num_padded_token_size
+                num_padded_token_size = slot_mapping.size(0)
+                if actual_seq_lengths_q[-1] != num_padded_token_size \
+                    and self.runner.attn_state == AscendAttentionState.SpecDecoding:
+                    actual_seq_lengths_q[-1] = num_padded_token_size

             cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
                 1).unsqueeze(2)
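The new else-branch lines repeat the alignment already done on the torchair-graph path: when the slot mapping is padded for speculative decoding, the last cumulative entry of `actual_seq_lengths_q` has to be stretched to the padded token count, otherwise the kernel sees fewer query tokens than slots. A standalone sketch of that adjustment; the helper and argument names are illustrative, not functions from this file:

def align_last_to_padding(actual_seq_lengths_q: list[int],
                          num_padded_token_size: int,
                          is_spec_decoding: bool) -> list[int]:
    """Stretch the final cumulative query length to cover padded tokens.

    The fused attention kernel captured by the torchair graph expects the
    last entry to equal the total (padded) number of query tokens it reads.
    """
    if is_spec_decoding and actual_seq_lengths_q \
            and actual_seq_lengths_q[-1] != num_padded_token_size:
        return actual_seq_lengths_q[:-1] + [num_padded_token_size]
    return actual_seq_lengths_q


# 3 requests * 2 spec-decode tokens = 6 real tokens, padded up to 8 slots.
assert align_last_to_padding([2, 4, 6], 8, True) == [2, 4, 8]
assert align_last_to_padding([2, 4, 6], 8, False) == [2, 4, 6]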
@@ -557,7 +558,7 @@ def build(
                 seq_lens_list=seq_lens_list,
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
-                actual_seq_q_lens=actual_seq_q_lens,
+                actual_seq_lengths_q=actual_seq_lengths_q,
                 sin=sin,
                 cos=cos)

@@ -577,7 +578,6 @@ def build(
             block_tables=block_table,
             seq_lens=seq_lens,
             enable_dbo_across_dp=enable_dbo_across_dp,
-            is_mtp_model=is_mtp_model,
         )


@@ -1017,16 +1017,13 @@ def _forward_decode(

         if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
             assert num_tokens % self.spec_token_num == 0
-            if self.enable_kv_nz:
-                input_layout = "TND_NTD"
-            else:
-                input_layout = "TND"
+            input_layout = "TND"
             # [bs * q_seq_len, num_heads_per_rank, dim]
             q_nope = q_nope.view(num_tokens, self.num_heads, -1)
             q_pe = q_pe.view(num_tokens, self.num_heads, -1)
             sparse_mode = 3
             spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
-            actual_seq_lengths = decode_meta.actual_seq_q_lens
+            actual_seq_lengths = decode_meta.actual_seq_lengths_q
         else:
             if self.enable_kv_nz:
                 q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
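With the "TND" input layout the queries stay packed along the token axis as (num_tokens, num_heads, head_dim), and `actual_seq_lengths_q` marks where each request's tokens end. A torch-only sketch of how those cumulative offsets delimit per-request chunks in the packed tensor (illustrative only, not the NPU fused op itself):

import torch

num_heads, head_dim = 4, 8
actual_seq_lengths_q = [2, 4, 6]  # cumulative end offsets for 3 requests
packed_q = torch.randn(actual_seq_lengths_q[-1], num_heads, head_dim)  # T, N, D

start = 0
per_request_q = []
for end in actual_seq_lengths_q:
    per_request_q.append(packed_q[start:end])  # (q_len_i, num_heads, head_dim)
    start = end

assert [chunk.shape[0] for chunk in per_request_q] == [2, 2, 2]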
@@ -1110,8 +1107,6 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        # mtp model is not support for graph mode yet
-        self.torchair_graph_enabled = self.torchair_graph_enabled and not attn_metadata.is_mtp_model
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]