@@ -203,13 +203,6 @@ def _forward(
203203 spec_query_start_loc = attn_metadata .spec_query_start_loc
204204 non_spec_query_start_loc = attn_metadata .non_spec_query_start_loc
205205 spec_sequence_masks = attn_metadata .spec_sequence_masks
206- if vllm_version_is ("0.11.0" ):
207- spec_token_masks = attn_metadata .spec_token_masks
208- if spec_token_masks is not None :
209- spec_token_masks = spec_token_masks [:num_actual_tokens ]
210- else :
211- spec_token_indx = attn_metadata .spec_token_indx
212- non_spec_token_indx = attn_metadata .non_spec_token_indx
213206 spec_state_indices_tensor = attn_metadata .spec_state_indices_tensor # noqa: E501
214207 non_spec_state_indices_tensor = attn_metadata .non_spec_state_indices_tensor # noqa: E501
215208 self_kv_cache = self .kv_cache [forward_context .virtual_engine ]
@@ -222,6 +215,14 @@ def _forward(
222215 attn_metadata .num_spec_decode_tokens )
223216 num_accepted_tokens = attn_metadata .num_accepted_tokens
224217
218+ if vllm_version_is ("0.11.0" ):
219+ spec_token_masks = attn_metadata .spec_token_masks
220+ if spec_token_masks is not None :
221+ spec_token_masks = spec_token_masks [:num_actual_tokens ]
222+ else :
223+ spec_token_indx = attn_metadata .spec_token_indx
224+ non_spec_token_indx = attn_metadata .non_spec_token_indx
225+
225226 # 1. Set up dimensions for reshapes later
226227 projected_states , _ = self .in_proj (hidden_states [:num_actual_tokens ])
227228 projected_states_qkvz , projected_states_ba = torch .split (
0 commit comments