From c410f9e4a44e0988260dacafb6de9d03c17a8a24 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 28 Jul 2025 19:24:16 +0800 Subject: [PATCH 1/5] + v1 Signed-off-by: wang.yuqi --- .../models/language/pooling/test_classification.py | 8 -------- tests/models/language/pooling/test_gte.py | 10 ---------- tests/models/language/pooling/test_jina.py | 13 ------------- vllm/config.py | 5 +++++ vllm/model_executor/models/bert_with_rope.py | 6 ++---- vllm/model_executor/models/modernbert.py | 2 -- 6 files changed, 7 insertions(+), 37 deletions(-) diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index 77df6d16a367..c71fa9627533 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -6,14 +6,6 @@ from vllm.platforms import current_platform -# TODO: enable when float32 is supported by V1 -# @pytest.fixture(autouse=True) -# def v1(run_with_both_engines): -# # Simple autouse wrapper to run both engines for each test -# # This can be promoted up to conftest.py to run for every -# # test in a package -# pass - @pytest.mark.parametrize( "model", diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 0ad54785308e..c4fc3711e6c5 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -56,17 +56,10 @@ enable_test=False), ] -V1FlashAttentionImpNotSupported = [ - "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-modernbert-base" -] - @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo, monkeypatch) -> None: - if model_info.name in V1FlashAttentionImpNotSupported: - monkeypatch.setenv("VLLM_USE_V1", "0") - vllm_extra_kwargs: dict[str, Any] = {} if model_info.architecture == "GteNewModel": vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} @@ -79,9 +72,6 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo, def test_embed_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts, monkeypatch) -> None: - if model_info.name in V1FlashAttentionImpNotSupported: - monkeypatch.setenv("VLLM_USE_V1", "0") - vllm_extra_kwargs: dict[str, Any] = {} if model_info.architecture == "GteNewModel": vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 2ae431de1683..59b634428cef 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -4,7 +4,6 @@ import pytest -import vllm.envs as envs from vllm import PoolingParams from ...utils import EmbedModelInfo, RerankModelInfo @@ -24,14 +23,6 @@ ] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: @@ -63,10 +54,6 @@ def hf_model_callback(model): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(hf_runner, vllm_runner, model_info: RerankModelInfo) -> None: - if (model_info.architecture == "XLMRobertaForSequenceClassification" - and envs.VLLM_USE_V1): - pytest.skip("Not supported yet") - mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/vllm/config.py b/vllm/config.py index aa3c20756064..baf8bd3245b6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4885,6 +4885,11 @@ def try_verify_and_update_config(self): if self.model_config is None: return + # Avoid running try_verify_and_update_config multiple times + if getattr(self.model_config, "config_updated", False): + return + self.model_config.config_updated = True + architecture = self.model_config.architecture if architecture is None: return diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 0b7350f07d3f..2824a39563f8 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -8,7 +8,6 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionType -from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import (get_act_and_mul_fn, @@ -364,7 +363,6 @@ def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor): return hidden_states -@support_torch_compile class BertWithRopeEncoder(nn.Module): def __init__(self, @@ -398,7 +396,7 @@ def forward( return hidden_states -class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): +class BertWithRope(nn.Module, SupportsQuant): hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -468,7 +466,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class NomicBertModel(BertWithRope): +class NomicBertModel(BertWithRope, SupportsV0Only): # for https://huggingface.co/nomic-ai/nomic-bert-2048 hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index fc2b0c1f5182..4967032a244e 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -8,7 +8,6 @@ from transformers import ModernBertConfig from vllm.attention import Attention, AttentionType -from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -200,7 +199,6 @@ def forward( return hidden_states -@support_torch_compile class ModernBertModel(nn.Module): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"layers.": "encoder_layer.layers."}) From be1d139a30b97b3e47290ba25f2f946ab31379c9 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 28 Jul 2025 20:15:42 +0800 Subject: [PATCH 2/5] + config_updated Signed-off-by: wang.yuqi --- vllm/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index baf8bd3245b6..51ffedb9a5eb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -771,6 +771,9 @@ def _task_to_convert(task: TaskOption) -> ConvertType: raise ValueError( "`override_neuron_config` is only supported on Neuron.") + # Avoid running try_verify_and_update_config multiple times + self.config_updated = False + self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() From d30f5075be925c16247d29164e8c1b4d8b5d63b4 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 29 Jul 2025 08:41:30 +0800 Subject: [PATCH 3/5] + fix NomicBertModel Signed-off-by: wang.yuqi --- vllm/model_executor/models/bert_with_rope.py | 3 +-- vllm/model_executor/models/config.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 2824a39563f8..4a418b10e1b1 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -22,7 +22,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models import SupportsV0Only from vllm.model_executor.models.interfaces import SupportsQuant from vllm.model_executor.models.utils import WeightsMapper from vllm.sequence import IntermediateTensors @@ -466,7 +465,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class NomicBertModel(BertWithRope, SupportsV0Only): +class NomicBertModel(BertWithRope): # for https://huggingface.co/nomic-ai/nomic-bert-2048 hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 6f50b1753098..9030ff307bee 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -93,7 +93,7 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: config.num_hidden_layers = config.n_layer head_dim = config.hidden_size // config.num_attention_heads - rotary_emb_dim = head_dim * config.rotary_emb_fraction + rotary_emb_dim = int(head_dim * config.rotary_emb_fraction) max_trained_positions = getattr(config, "max_trained_positions", 2048) config.rotary_kwargs = { "head_size": head_dim, From bcd706a6422ab4a9ec1dfe4a70b45580bdf803f4 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 29 Jul 2025 10:42:51 +0800 Subject: [PATCH 4/5] - max_num_seqs Signed-off-by: wang.yuqi --- tests/models/language/pooling/test_qwen3_reranker.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index 9c6a833b4138..68e96f32700c 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -83,9 +83,6 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: } } - if model_info.name == "Qwen/Qwen3-Reranker-4B": - vllm_extra_kwargs["max_num_seqs"] = 1 - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, vllm_extra_kwargs) @@ -106,9 +103,6 @@ def test_rerank_models_mteb_tp(vllm_runner, "tensor_parallel_size": 2, } - if model_info.name == "Qwen/Qwen3-Reranker-4B": - vllm_extra_kwargs["max_num_seqs"] = 1 - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, From 2c4d9324364060a97864dc364efd79e64c97444e Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 31 Jul 2025 01:25:45 +0800 Subject: [PATCH 5/5] fix Signed-off-by: wang.yuqi --- tests/models/language/pooling/test_gte.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index c4fc3711e6c5..6d2eff709961 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -58,8 +58,8 @@ @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo, - monkeypatch) -> None: +def test_embed_models_mteb(hf_runner, vllm_runner, + model_info: EmbedModelInfo) -> None: vllm_extra_kwargs: dict[str, Any] = {} if model_info.architecture == "GteNewModel": vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} @@ -70,8 +70,8 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo, @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, example_prompts, - monkeypatch) -> None: + model_info: EmbedModelInfo, + example_prompts) -> None: vllm_extra_kwargs: dict[str, Any] = {} if model_info.architecture == "GteNewModel": vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}