Commit bb5f16d
[BugFix] Fix Qwen3-next break (#3428)
### What this PR does / why we need it?
Fix `Qwen3NextGatedDeltaNet`, which was broken by vllm-project/vllm#26437.

### How was this patch tested?
```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "窗前明月光,",
        "The president of the United States is Mr.",
        "The capital of France is",
        "The future of AI is",
        "感时花溅泪,",
        "家书抵万金啥意思?",
        "plz tell me a story: ",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100,
                                     temperature=0.6,
                                     top_k=40,
                                     top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="/root/.cache/modelscope/hub/models/Qwen/Qwen3-Next-80B-A3B-Instruct",
        tensor_parallel_size=4,
        enforce_eager=True,
        trust_remote_code=True,
        max_model_len=256,
        gpu_memory_utilization=0.7,
        block_size=64,
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Icey <[email protected]>
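The essence of the fix is a version gate around how speculative-token positions are read from the GDN attention metadata: vLLM 0.11.0 exposes a boolean `spec_token_masks`, while newer vLLM main exposes `spec_token_indx` / `non_spec_token_indx` index tensors instead. Below is a minimal sketch of that pattern; only the field names and the `vllm_version_is` helper from `vllm_ascend.utils` come from the actual change, and the helper function name and bare `attn_metadata` argument are illustrative.

```python
# Minimal sketch of the version-gated access pattern used throughout the patch.
# `attn_metadata` stands in for vLLM's GDNAttentionMetadata; only the field
# names and vllm_version_is are taken from the diff below.
from vllm_ascend.utils import vllm_version_is


def read_spec_token_selection(attn_metadata):
    """Return whichever token-selection fields the running vLLM provides."""
    if vllm_version_is("0.11.0"):
        # vLLM 0.11.0: a boolean mask over all scheduled tokens.
        return {"spec_token_masks": attn_metadata.spec_token_masks}
    # Newer vLLM main: precomputed index tensors instead of a mask.
    return {
        "spec_token_indx": attn_metadata.spec_token_indx,
        "non_spec_token_indx": attn_metadata.non_spec_token_indx,
    }
```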
1 parent 7572939 · commit bb5f16d

File tree: 1 file changed (+35 −11 lines)
vllm_ascend/models/qwen3_next.py

Lines changed: 35 additions & 11 deletions
```diff
@@ -51,6 +51,8 @@
 from vllm.transformers_utils.configs import Qwen3NextConfig
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 
+from vllm_ascend.utils import vllm_version_is
+
 from vllm.model_executor.models.qwen3_next import (  # isort: skip
     Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
     Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
@@ -201,7 +203,11 @@ def _forward(
         spec_query_start_loc = attn_metadata.spec_query_start_loc
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         spec_sequence_masks = attn_metadata.spec_sequence_masks
-        spec_token_masks = attn_metadata.spec_token_masks
+        if vllm_version_is("0.11.0"):
+            spec_token_masks = attn_metadata.spec_token_masks
+        else:
+            spec_token_indx = attn_metadata.spec_token_indx
+            non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
@@ -216,8 +222,9 @@ def _forward(
 
         # 1. Set up dimensions for reshapes later
         projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
-        if spec_token_masks is not None:
-            spec_token_masks = spec_token_masks[:num_actual_tokens]
+        if vllm_version_is("0.11.0"):
+            if spec_token_masks is not None:
+                spec_token_masks = spec_token_masks[:num_actual_tokens]
         projected_states_qkvz, projected_states_ba = torch.split(
             projected_states,
             [
@@ -242,8 +249,13 @@ def _forward(
                 mixed_qkv_spec = mixed_qkv
                 mixed_qkv_non_spec = None
             else:
-                mixed_qkv_spec = mixed_qkv[spec_token_masks]
-                mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
+                if vllm_version_is("0.11.0"):
+                    mixed_qkv_spec = mixed_qkv[spec_token_masks]
+                    mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
+                else:
+                    mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
+                    mixed_qkv_non_spec = mixed_qkv.index_select(
+                        0, non_spec_token_indx)
         else:
             mixed_qkv_spec = None
             mixed_qkv_non_spec = mixed_qkv
@@ -293,10 +305,16 @@ def _forward(
                 g_non_spec = None
                 beta_non_spec = None
             else:
-                g_spec = g[:, spec_token_masks]
-                beta_spec = beta[:, spec_token_masks]
-                g_non_spec = g[:, ~spec_token_masks]
-                beta_non_spec = beta[:, ~spec_token_masks]
+                if vllm_version_is("0.11.0"):
+                    g_spec = g[:, spec_token_masks]
+                    beta_spec = beta[:, spec_token_masks]
+                    g_non_spec = g[:, ~spec_token_masks]
+                    beta_non_spec = beta[:, ~spec_token_masks]
+                else:
+                    g_spec = g.index_select(1, spec_token_indx)
+                    beta_spec = beta.index_select(1, spec_token_indx)
+                    g_non_spec = g.index_select(1, non_spec_token_indx)
+                    beta_non_spec = beta.index_select(1, non_spec_token_indx)
         else:
             g_spec = None
             beta_spec = None
@@ -404,8 +422,14 @@ def _forward(
                     dtype=core_attn_out_non_spec.dtype,
                     device=core_attn_out_non_spec.device,
                 )
-                core_attn_out[:, spec_token_masks] = core_attn_out_spec
-                core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
+                if vllm_version_is("0.11.0"):
+                    core_attn_out[:, spec_token_masks] = core_attn_out_spec
+                    core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
+                else:
+                    core_attn_out.index_copy_(1, spec_token_indx,
+                                              core_attn_out_spec)
+                    core_attn_out.index_copy_(1, non_spec_token_indx,
+                                              core_attn_out_non_spec)
         elif spec_sequence_masks is not None:
             core_attn_out = core_attn_out_spec
         else:
```
