Skip to content

Commit c201ddb

Browse files
committed
Fix MTP (multi-token prediction) proposer under torchair graph mode
1 parent 001e6a8 commit c201ddb

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

vllm_ascend/ops/vocab_parallel_embedding.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,23 @@ def get_masked_input_and_mask(
3030
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
3131
# torch.compile will fuse all of the pointwise ops below
3232
# into a single kernel, making it very fast
33-
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
34-
org_vocab_end_index)
35-
added_vocab_mask = (input_ >= added_vocab_start_index) & (
36-
input_ < added_vocab_end_index)
37-
added_offset = added_vocab_start_index - (
38-
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
39-
valid_offset = (org_vocab_start_index *
40-
org_vocab_mask) + (added_offset * added_vocab_mask)
41-
vocab_mask = org_vocab_mask | added_vocab_mask
33+
org_vocab_mask = (input_ >= org_vocab_start_index) & (
34+
input_ < org_vocab_end_index)
35+
# Adapt: avoid creating added_vocab_mask when added_vocab_start_index == added_vocab_end_index.
36+
if added_vocab_start_index == added_vocab_end_index:
37+
valid_offset = (org_vocab_start_index * org_vocab_mask)
38+
vocab_mask = org_vocab_mask
39+
else:
40+
added_vocab_mask = (input_ >= added_vocab_start_index) & (
41+
input_ < added_vocab_end_index)
42+
added_offset = added_vocab_start_index - (
43+
org_vocab_end_index -
44+
org_vocab_start_index) - num_org_vocab_padding
45+
valid_offset = (org_vocab_start_index *
46+
org_vocab_mask) + (added_offset * added_vocab_mask)
47+
vocab_mask = org_vocab_mask | added_vocab_mask
48+
# Adapt end.
4249
input_ = vocab_mask * (input_ - valid_offset)
43-
#FIXME(xyx) refactor this
44-
torch._dynamo.mark_static(vocab_mask)
4550
return input_, ~vocab_mask
4651

4752

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,7 @@ def prepare_inputs(
100100
# [a, b, c] -> [a - n1, b - n2, c - n3]
101101
num_tokens_per_req = query_len_per_req - num_rejected_tokens
102102
if is_torchair_graph:
103-
cu_num_tokens = torch.empty_like(cu_target_query_lens)
104-
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
105-
cu_num_tokens[0] = 0
103+
cu_num_tokens = cu_target_query_lens
106104
relative_index = query_len_per_req - num_rejected_tokens - 1
107105
token_indices = cu_num_tokens[:-1] + relative_index
108106
elif force_one_token:
@@ -239,7 +237,8 @@ def propose(
239237
input_ids=self.input_ids[:num_input_tokens],
240238
positions=self.positions[:num_input_tokens],
241239
previous_hidden_states=self.
242-
hidden_states[:num_input_tokens])
240+
hidden_states[:num_input_tokens],
241+
kv_caches=self.runner.kv_caches[-1:])
243242
sample_hidden_states = hidden_states[last_token_indices]
244243
logits = self.model.compute_logits(sample_hidden_states, None)
245244
draft_token_ids = logits.argmax(dim=-1)

0 commit comments

Comments (0)