Commit aec1751

[0.9.1][MTP V1] MTP model adapt torchair

fix mtp torchair fix

Signed-off-by: xuyexiong <[email protected]>

1 parent: 3715398

6 files changed (+143 −91 lines)

vllm_ascend/attention/mla_v1.py
Lines changed: 17 additions & 22 deletions

@@ -94,7 +94,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
-    actual_seq_q_lens: Optional[list[int]] = None
+    actual_seq_lengths_q: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
@@ -131,7 +131,6 @@ class AscendMLAMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.

     enable_dbo_across_dp: bool = False
-    is_mtp_model: bool = False

     query_lens: Optional[list[int]] = None
     # The dimension of the attention heads
@@ -313,7 +312,6 @@ def build_torchair_graph_dummy(
         self,
         num_reqs: int,
         num_actual_tokens: int,
-        is_mtp_model: bool = False,
     ) -> AscendMLAMetadata:
         device = self.runner.device
         _, max_blocks = self.runner.graph_block_tables.shape
@@ -337,7 +335,7 @@ def build_torchair_graph_dummy(
                                  dtype=torch.int32,
                                  device=device)
         if self.runner.speculative_config is not None and\
-                self.runner.speculative_config.method == 'deepseek_mtp' and not is_mtp_model:
+                self.runner.speculative_config.method == 'deepseek_mtp':
             attn_state = AscendAttentionState.SpecDecoding
             num_decode_tokens = 2
         else:
@@ -362,7 +360,7 @@ def build_torchair_graph_dummy(
             seq_lens_list=seq_lens_list,
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
-            actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
+            actual_seq_lengths_q=self.runner.actual_seq_lengths_q[:num_reqs],
             sin=sin,
             cos=cos)
         return self.metadata_cls(  # type: ignore
@@ -380,7 +378,6 @@ def build_torchair_graph_dummy(
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
             block_tables=block_table,
-            is_mtp_model=is_mtp_model,
         )

     def build(
@@ -393,7 +390,6 @@ def build(
         num_token_pad_size: int = -1,
         num_reqs_pad_size: int = 0,
         enable_dbo_across_dp: bool = False,
-        is_mtp_model: bool = False,
     ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

@@ -498,7 +494,7 @@ def build(
         decode_metadata = None
         use_torchair_graph = num_token_pad_size != -1
         if self._num_decodes > 0:
-            actual_seq_q_lens = query_start_loc[1:].tolist()
+            actual_seq_lengths_q = query_start_loc[1:].tolist()
             max_seq_lens = seq_lens[:self._num_decodes].max().item()
             seq_lens = seq_lens[:self._num_decode_tokens]
             input_positions = input_positions[:self._num_decode_tokens]
@@ -534,16 +530,21 @@ def build(
                     dtype=input_positions.dtype,
                     device=input_positions.device)
                 input_positions = torch.cat([input_positions, padding_0])
-                actual_seq_q_lens = query_start_loc[1:].tolist(
-                ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
+                actual_seq_lengths_q = query_start_loc[1:].tolist(
+                ) + self.runner.actual_seq_lengths_q[num_reqs:num_reqs +
                                                      num_reqs_pad_size]
-                # mtp torchair + PD scenario, last element of actual_seq_q_lens must equal num_padded_token_size
+                # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal num_padded_token_size
                 num_padded_token_size = slot_mapping.size(0)
-                if actual_seq_q_lens[-1] != num_padded_token_size \
+                if actual_seq_lengths_q[-1] != num_padded_token_size \
                         and self.runner.attn_state == AscendAttentionState.SpecDecoding:
-                    actual_seq_q_lens[-1] = num_padded_token_size
+                    actual_seq_lengths_q[-1] = num_padded_token_size
             else:
                 seq_lens_list = seq_lens.tolist()
+                # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal num_padded_token_size
+                num_padded_token_size = slot_mapping.size(0)
+                if actual_seq_lengths_q[-1] != num_padded_token_size \
+                        and self.runner.attn_state == AscendAttentionState.SpecDecoding:
+                    actual_seq_lengths_q[-1] = num_padded_token_size

             cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
                 1).unsqueeze(2)
@@ -557,7 +558,7 @@ def build(
                 seq_lens_list=seq_lens_list,
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
-                actual_seq_q_lens=actual_seq_q_lens,
+                actual_seq_lengths_q=actual_seq_lengths_q,
                 sin=sin,
                 cos=cos)

@@ -577,7 +578,6 @@ def build(
             block_tables=block_table,
             seq_lens=seq_lens,
             enable_dbo_across_dp=enable_dbo_across_dp,
-            is_mtp_model=is_mtp_model,
         )


@@ -1017,16 +1017,13 @@ def _forward_decode(

         if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
             assert num_tokens % self.spec_token_num == 0
-            if self.enable_kv_nz:
-                input_layout = "TND_NTD"
-            else:
-                input_layout = "TND"
+            input_layout = "TND"
             # [bs * q_seq_len, num_heads_per_rank, dim]
             q_nope = q_nope.view(num_tokens, self.num_heads, -1)
             q_pe = q_pe.view(num_tokens, self.num_heads, -1)
             sparse_mode = 3
             spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
-            actual_seq_lengths = decode_meta.actual_seq_q_lens
+            actual_seq_lengths = decode_meta.actual_seq_lengths_q
         else:
             if self.enable_kv_nz:
                 q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
@@ -1110,8 +1107,6 @@ def forward(
         if attn_metadata is None:
            # Profiling run.
            return output
-        # mtp model is not support for graph mode yet
-        self.torchair_graph_enabled = self.torchair_graph_enabled and not attn_metadata.is_mtp_model
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
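The two mirrored hunks in build() enforce one invariant worth seeing in isolation: actual_seq_lengths_q is cumulative (entry i is the index one past the last query token of request i), and once the torchair graph pads the decode batch, the fused attention kernel walks all num_padded_token_size slots, so the final entry has to be stretched over the padding. A minimal standalone sketch of that clamp; pad_cumulative_query_lens and its arguments are illustrative names, not part of the diff.

# Minimal sketch of the actual_seq_lengths_q padding invariant above.
# All names here are illustrative, not the vllm-ascend API.

def pad_cumulative_query_lens(actual_seq_lengths_q: list,
                              num_padded_token_size: int,
                              is_spec_decoding: bool) -> list:
    # actual_seq_lengths_q is cumulative: entry i is the index one past
    # the last query token of request i. After graph-mode padding, the
    # fused kernel iterates over num_padded_token_size tokens, so the
    # final entry must cover the padding tokens appended to the batch.
    if is_spec_decoding and actual_seq_lengths_q[-1] != num_padded_token_size:
        actual_seq_lengths_q[-1] = num_padded_token_size
    return actual_seq_lengths_q

# Example: 3 requests, 2 tokens each (1 verified + 1 MTP draft), batch
# padded from 6 to 8 tokens -> the last entry is stretched to 8.
assert pad_cumulative_query_lens([2, 4, 6], 8, True) == [2, 4, 8]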

vllm_ascend/models/deepseek_mtp.py
Lines changed: 6 additions & 3 deletions

@@ -24,6 +24,7 @@
 from transformers import PretrainedConfig
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -98,9 +99,11 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
-        inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
-                                    torch.zeros_like(inputs_embeds),
-                                    inputs_embeds)
+        forward_context = get_forward_context()
+        if forward_context.with_prefill:
+            inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
+                                        torch.zeros_like(inputs_embeds),
+                                        inputs_embeds)
         inputs_embeds = self.enorm(inputs_embeds)
         previous_hidden_states = self.hnorm(previous_hidden_states)
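For context, the MTP layer zeroes the embedding wherever position == 0, since the draft head predicts token t+1 from token t and the first position has no predecessor; gating it on with_prefill reflects that decode-only (graph-captured) batches should not hit that branch. A self-contained sketch of the masking path, with toy shapes and a hand-set with_prefill flag standing in for get_forward_context():

import torch

# Standalone sketch of the position-0 masking done in the MTP forward
# above (shapes are illustrative; this is not the vllm-ascend module).
positions = torch.tensor([0, 1, 2, 0, 1])   # two requests, flattened
inputs_embeds = torch.randn(5, 8)           # [num_tokens, hidden_size]

with_prefill = True                         # decode-only graphs skip this
if with_prefill:
    # Zero the embedding wherever position == 0: MTP predicts token t+1
    # from token t, so the first position has no valid predecessor.
    inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
                                torch.zeros_like(inputs_embeds),
                                inputs_embeds)

assert inputs_embeds[0].abs().sum() == 0 and inputs_embeds[3].abs().sum() == 0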

vllm_ascend/models/deepseek_v2.py
Lines changed: 1 addition & 2 deletions

@@ -485,8 +485,7 @@ def forward(
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
-        is_mtp_model = attn_metadata is not None and attn_metadata.is_mtp_model
-        if self.torchair_graph_enabled and not is_mtp_model:
+        if self.torchair_graph_enabled:
             if envs.VLLM_USE_V1:
                 output_shape = hidden_states.shape
                 output = torch.empty(output_shape,
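With the per-batch is_mtp_model escape hatch removed, output preallocation now keys solely off torchair_graph_enabled. A reduced sketch of that gating pattern, using illustrative tensors rather than the real attention module:

import torch

# Sketch of the simplified gating after this change: the output buffer is
# preallocated whenever torchair graphs are enabled, with no per-batch
# is_mtp_model exception. Names and shapes are illustrative only.
torchair_graph_enabled = True
hidden_states = torch.randn(4, 16)

if torchair_graph_enabled:
    # Graph capture wants a stable output tensor, so allocate it up front
    # instead of letting the attention op create one per call.
    output = torch.empty(hidden_states.shape,
                         dtype=hidden_states.dtype,
                         device=hidden_states.device)
else:
    output = None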

vllm_ascend/ops/vocab_parallel_embedding.py
Lines changed: 16 additions & 9 deletions

@@ -30,15 +30,22 @@ def get_masked_input_and_mask(
         added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # torch.compile will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
-    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
-                                                          org_vocab_end_index)
-    added_vocab_mask = (input_ >= added_vocab_start_index) & (
-        input_ < added_vocab_end_index)
-    added_offset = added_vocab_start_index - (
-        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
-    valid_offset = (org_vocab_start_index *
-                    org_vocab_mask) + (added_offset * added_vocab_mask)
-    vocab_mask = org_vocab_mask | added_vocab_mask
+    org_vocab_mask = (input_ >= org_vocab_start_index) & (
+        input_ < org_vocab_end_index)
+    # Adapt: avoid creating added_vocab_mask when added_vocab_start_index == added_vocab_end_index.
+    if added_vocab_start_index == added_vocab_end_index:
+        valid_offset = (org_vocab_start_index * org_vocab_mask)
+        vocab_mask = org_vocab_mask
+    else:
+        added_vocab_mask = (input_ >= added_vocab_start_index) & (
+            input_ < added_vocab_end_index)
+        added_offset = added_vocab_start_index - (
+            org_vocab_end_index -
+            org_vocab_start_index) - num_org_vocab_padding
+        valid_offset = (org_vocab_start_index *
+                        org_vocab_mask) + (added_offset * added_vocab_mask)
+        vocab_mask = org_vocab_mask | added_vocab_mask
+    # Adapt end.
     input_ = vocab_mask * (input_ - valid_offset)
     return input_, ~vocab_mask
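The new branch skips building added_vocab_mask, added_offset, and the mask OR whenever the added-vocab range is empty (start == end), the common case for models with no added tokens. A quick standalone equivalence check of the fast path against the general formula, using toy ids and ranges (all values illustrative):

import torch

# When there is no added vocabulary, the added mask is all-False, so
# skipping it must not change the result of the general formula.
input_ = torch.tensor([3, 7, 100, 12])
org_start, org_end = 0, 16                  # this rank owns ids [0, 16)

org_mask = (input_ >= org_start) & (input_ < org_end)

# Fast path: no added vocab, so valid_offset only shifts by org_start.
valid_offset = org_start * org_mask
masked_fast = org_mask * (input_ - valid_offset)

# General path with an empty added range [16, 16) for comparison.
added_mask = (input_ >= 16) & (input_ < 16)  # all False
assert not added_mask.any()
masked_general = (org_mask | added_mask) * (input_ - valid_offset)

assert torch.equal(masked_fast, masked_general)  # out-of-range id 100 -> 0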

vllm_ascend/worker/model_runner_v1.py
Lines changed: 27 additions & 11 deletions

@@ -212,7 +212,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # Set up speculative decoding.
         self.use_spec_decode = False
         self.spec_attn_mask = None
-        self.actual_seq_q_lens = []
+        self.actual_seq_lengths_q = []
         self.spec_token_num = 0
         self.decode_token_per_req = 1
         if self.speculative_config:
@@ -232,7 +232,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             elif self.speculative_config.method == 'deepseek_mtp':
                 self.drafter = MtpProposer(self.vllm_config, self)
                 self.decode_token_per_req = 1 + self.spec_token_num
-                self.actual_seq_q_lens = [
+                self.actual_seq_lengths_q = [
                     len for len in
                     range(self.decode_token_per_req, self.max_num_tokens +
                           1, self.decode_token_per_req)
@@ -1009,6 +1009,7 @@ def _process_reqs(
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc,
             seq_lens=self.seq_lens_cpu[:num_reqs])
+        self.common_attn_metadata = common_attn_metadata
         self.seq_lens_list = self.seq_lens_np.tolist()[:num_input_tokens]
         with_prefill = attn_state not in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
@@ -1040,7 +1041,9 @@ def _process_reqs(
             extra_builder_kwargs['num_token_pad_size'] = num_token_pad_size
             extra_builder_kwargs['num_reqs_pad_size'] = num_reqs_pad_size
             self.num_reqs_pad_size = num_reqs_pad_size
+            self.num_token_pad_size = num_token_pad_size
         self.extra_builder_kwargs = extra_builder_kwargs
+        self.num_tokens_across_dp = num_tokens_across_dp

         if self.vllm_config.model_config.use_mla:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore
@@ -1740,7 +1743,12 @@ def _dummy_run(
                     **model_kwargs)
             if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                 assert isinstance(self.drafter, MtpProposer)
-                self.drafter.dummy_run(num_reqs, with_prefill=with_prefill)
+                self.drafter.dummy_run(
+                    num_tokens=num_tokens,
+                    with_prefill=with_prefill,
+                    skip_attn=skip_attn,
+                    num_reqs=num_reqs,
+                    num_tokens_across_dp=num_tokens_across_dp)
             return hidden_states

     @contextmanager
@@ -2123,7 +2131,7 @@ def _generate_mtp_token_ids(
         next_token_ids = torch.tensor(next_token_ids,
                                       dtype=torch.int32,
                                       device=self.device)
-
+        token_indices = None
         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
             target_token_ids = self.input_ids[:num_scheduled_tokens]
@@ -2146,12 +2154,20 @@ def _generate_mtp_token_ids(
             cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                 attn_metadata.query_start_loc,
                 num_rejected_tokens,
-                force_one_token=True,
-            )
-            target_token_ids = self.input_ids[token_indices]
-            target_positions = positions[token_indices]
-            target_hidden_states = hidden_states[token_indices]
-            target_slot_mapping = attn_metadata.slot_mapping[token_indices]
+                force_one_token=False,
+                is_torchair_graph=self.torchair_graph_enabled)
+            if self.torchair_graph_enabled:
+                # the seq len of each batch is padded to 2, thus the input is the same as the main model
+                target_token_ids = self.input_ids[:num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                target_hidden_states = hidden_states[:num_scheduled_tokens]
+                target_slot_mapping = attn_metadata.slot_mapping[:
+                                                                 num_scheduled_tokens]
+            else:
+                target_token_ids = self.input_ids[token_indices]
+                target_positions = positions[token_indices]
+                target_hidden_states = hidden_states[token_indices]
+                target_slot_mapping = attn_metadata.slot_mapping[token_indices]

         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
@@ -2162,7 +2178,7 @@ def _generate_mtp_token_ids(
             cu_num_tokens=cu_num_tokens,
             block_table=attn_metadata.block_tables,
             sampling_metadata=sampling_metadata,
-            )
+            token_indices=token_indices)
         spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
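The renamed actual_seq_lengths_q list precomputes the fully padded decode layout: with deepseek_mtp each decode request carries decode_token_per_req = 1 + spec_token_num tokens, so the cumulative query end offsets are just multiples of that stride, which is also why the torchair path above can slice inputs with [:num_scheduled_tokens] instead of gathering by token_indices. A toy reconstruction of how the builder splices real and padded entries (values illustrative):

# Sketch of the precomputed cumulative query lengths used to pad
# graph-mode batches (names mirror the diff; values are illustrative).
spec_token_num = 1                     # one MTP draft token per request
decode_token_per_req = 1 + spec_token_num
max_num_tokens = 8

# Entry i is the cumulative token count after request i when every
# decode request carries exactly decode_token_per_req tokens.
actual_seq_lengths_q = list(
    range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))
assert actual_seq_lengths_q == [2, 4, 6, 8]

# When a real batch of num_reqs requests is padded with num_reqs_pad_size
# dummy requests, the builder appends the precomputed tail:
num_reqs, num_reqs_pad_size = 2, 2
query_start_loc_tail = [2, 4]          # real requests' cumulative lengths
padded = query_start_loc_tail + actual_seq_lengths_q[num_reqs:num_reqs +
                                                     num_reqs_pad_size]
assert padded == [2, 4, 6, 8]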
