From e9214f9004b9d08b94fa2ffe88c3da74f0a88fc5 Mon Sep 17 00:00:00 2001
From: Akram Ben Aissi
Date: Sat, 4 Oct 2025 00:17:53 +0200
Subject: [PATCH 1/3] feat: Add allow_listing_models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Add allow_listing_models configuration flag to VLLM provider to control model listing behavior
• Implement allow_listing_models() method across all providers with default implementations in base classes
• Prevent HTTP requests to /v1/models endpoint when allow_listing_models=false for improved security and performance
• Fix unit tests to include allow_listing_models method in test classes and mock objects
---
 docs/docs/providers/inference/remote_vllm.mdx |  2 +
 llama_stack/core/routing_tables/models.py     |  6 ++
 llama_stack/distributions/ci-tests/run.yaml   |  1 +
 .../distributions/postgres-demo/run.yaml      |  1 +
 .../distributions/starter-gpu/run.yaml        |  1 +
 llama_stack/distributions/starter/run.yaml    |  1 +
 .../inference/meta_reference/inference.py     |  3 +
 .../sentence_transformers.py                  |  3 +
 .../providers/remote/inference/vllm/config.py |  5 ++
 .../providers/remote/inference/vllm/vllm.py   | 49 ++++++++---
 .../utils/inference/model_registry.py         |  3 +
 .../providers/utils/inference/openai_mixin.py |  3 +
 .../routers/test_routing_tables.py            |  3 +
 .../providers/inference/test_remote_vllm.py   | 86 ++++++++++++++++---
 tests/unit/server/test_access_control.py      |  1 +
 15 files changed, 143 insertions(+), 25 deletions(-)

diff --git a/docs/docs/providers/inference/remote_vllm.mdx b/docs/docs/providers/inference/remote_vllm.mdx
index 598f97b198..efa863016e 100644
--- a/docs/docs/providers/inference/remote_vllm.mdx
+++ b/docs/docs/providers/inference/remote_vllm.mdx
@@ -20,6 +20,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
 | `refresh_models` | `` | No | False | Whether to refresh models periodically |
+| `allow_listing_models` | `` | No | True | Whether to allow listing models from the vLLM server |
 
 ## Sample Configuration
 
@@ -28,4 +29,5 @@ url: ${env.VLLM_URL:=}
 max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
 api_token: ${env.VLLM_API_TOKEN:=fake}
 tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
 ```
diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py
index 641c73c167..0a98506742 100644
--- a/llama_stack/core/routing_tables/models.py
+++ b/llama_stack/core/routing_tables/models.py
@@ -43,6 +43,12 @@ async def refresh(self) -> None:
         await self.update_registered_models(provider_id, models)
 
     async def list_models(self) -> ListModelsResponse:
+        # Check if providers allow listing models before returning models
+        for provider_id, provider in self.impls_by_provider_id.items():
+            allow_listing_models = await provider.allow_listing_models()
+            logger.debug(f"Provider {provider_id}: allow_listing_models={allow_listing_models}")
+            if not allow_listing_models:
+                logger.debug(f"Provider {provider_id} has allow_listing_models disabled")
         return ListModelsResponse(data=await self.get_all_with_type("model"))
 
     async def openai_list_models(self) -> OpenAIListModelsResponse:
diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml
index b14477a9ad..e70c11100a 100644
--- a/llama_stack/distributions/ci-tests/run.yaml
+++ b/llama_stack/distributions/ci-tests/run.yaml
@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
diff --git a/llama_stack/distributions/postgres-demo/run.yaml b/llama_stack/distributions/postgres-demo/run.yaml
index 0cf0e82e6a..67691e5cff 100644
--- a/llama_stack/distributions/postgres-demo/run.yaml
+++ b/llama_stack/distributions/postgres-demo/run.yaml
@@ -16,6 +16,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml
index de5fe56811..fb29f7407d 100644
--- a/llama_stack/distributions/starter-gpu/run.yaml
+++ b/llama_stack/distributions/starter-gpu/run.yaml
@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml
index c440e4e4b6..d338944bb4 100644
--- a/llama_stack/distributions/starter/run.yaml
+++ b/llama_stack/distributions/starter/run.yaml
@@ -31,6 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
+      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index fd65fa10d3..3c003bbcaf 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -71,6 +71,9 @@ async def openai_completion(self, *args, **kwargs):
     async def should_refresh_models(self) -> bool:
         return False
 
+    async def allow_listing_models(self) -> bool:
+        return True
+
     async def list_models(self) -> list[Model] | None:
         return None
 
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index b984d97bf1..542c4bcebe 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -52,6 +52,9 @@ async def shutdown(self) -> None:
     async def should_refresh_models(self) -> bool:
         return False
 
+    async def allow_listing_models(self) -> bool:
+        return True
+
     async def list_models(self) -> list[Model] | None:
         return [
             Model(
diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py
index 86ef3fe268..3277188006 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -34,6 +34,10 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         default=False,
         description="Whether to refresh models periodically",
     )
+    allow_listing_models: bool = Field(
+        default=True,
+        description="Whether to allow listing models from the vLLM server",
+    )
 
     @field_validator("tls_verify")
     @classmethod
@@ -59,4 +63,5 @@ def sample_run_config(
             "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
             "api_token": "${env.VLLM_API_TOKEN:=fake}",
             "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
+            "allow_listing_models": "${env.VLLM_ALLOW_LISTING_MODELS:=true}",
         }
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 54ac8e1dc0..87a78c0f11 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -282,7 +282,18 @@ async def should_refresh_models(self) -> bool:
         # Strictly respecting the refresh_models directive
         return self.config.refresh_models
 
+    async def allow_listing_models(self) -> bool:
+        # Respecting the allow_listing_models directive
+        result = self.config.allow_listing_models
+        log.debug(f"VLLM allow_listing_models: {result}")
+        return result
+
     async def list_models(self) -> list[Model] | None:
+        log.debug(f"VLLM list_models called, allow_listing_models={self.config.allow_listing_models}")
+        if not self.config.allow_listing_models:
+            log.debug("VLLM list_models returning None due to allow_listing_models=False")
+            return None
+
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -332,24 +343,34 @@ async def _get_model(self, model_id: str) -> Model:
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
-    async def register_model(self, model: Model) -> Model:
-        try:
-            model = await self.register_helper.register_model(model)
-        except ValueError:
-            pass  # Ignore statically unknown model, will check live listing
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the vLLM server.
+
+        This method respects the allow_listing_models configuration flag.
+        If allow_listing_models is False, it returns True to allow model registration
+        without making HTTP requests (trusting that the model exists).
+
+        :param model: The model identifier to check.
+        :return: True if the model is available or if allow_listing_models is False, False otherwise.
+        """
+        # Check if provider allows listing models before making HTTP request
+        if not self.config.allow_listing_models:
+            log.debug(
+                "VLLM check_model_availability returning True due to allow_listing_models=False (trusting model exists)"
+            )
+            return True
+
         try:
             res = self.client.models.list()
         except APIConnectionError as e:
-            raise ValueError(
-                f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
-            ) from e
+            log.warning(f"Failed to connect to vLLM at {self.config.url}: {e}")
+            return False
+
         available_models = [m.id async for m in res]
-        if model.provider_resource_id not in available_models:
-            raise ValueError(
-                f"Model {model.provider_resource_id} is not being served by vLLM. "
-                f"Available models: {', '.join(available_models)}"
-            )
-        return model
+        is_available = model in available_models
+        log.debug(f"VLLM model {model} availability: {is_available}")
+        return is_available
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         options = get_sampling_options(request.sampling_params)
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 4913c2e1fb..4a4fa7adfc 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -100,6 +100,9 @@ async def list_models(self) -> list[Model] | None:
     async def should_refresh_models(self) -> bool:
         return False
 
+    async def allow_listing_models(self) -> bool:
+        return True
+
     def get_provider_model_id(self, identifier: str) -> str | None:
         return self.alias_to_provider_id_map.get(identifier, None)
 
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index 4354b067e6..4164edfa37 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -425,3 +425,6 @@ async def check_model_availability(self, model: str) -> bool:
 
     async def should_refresh_models(self) -> bool:
         return False
+
+    async def allow_listing_models(self) -> bool:
+        return True
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index 54a9dd72e2..524d650da6 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -52,6 +52,9 @@ async def unregister_model(self, model_id: str):
     async def should_refresh_models(self):
         return False
 
+    async def allow_listing_models(self):
+        return True
+
     async def list_models(self):
         return [
             Model(
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index cd31e4943d..6675a59014 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -636,27 +636,75 @@ async def test_should_refresh_models():
     Test the should_refresh_models method with different refresh_models configurations.
 
     This test verifies that:
-    1. When refresh_models is True, should_refresh_models returns True regardless of api_token
-    2. When refresh_models is False, should_refresh_models returns False regardless of api_token
+    1. When refresh_models is True, should_refresh_models returns True
+    2. When refresh_models is False, should_refresh_models returns False
     """
 
-    # Test case 1: refresh_models is True, api_token is None
-    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
+    # Test case 1: refresh_models is True
+    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=True)
     adapter1 = VLLMInferenceAdapter(config1)
     result1 = await adapter1.should_refresh_models()
     assert result1 is True, "should_refresh_models should return True when refresh_models is True"
 
-    # Test case 2: refresh_models is True, api_token is empty string
-    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
+    # Test case 2: refresh_models is False
+    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=False)
     adapter2 = VLLMInferenceAdapter(config2)
     result2 = await adapter2.should_refresh_models()
-    assert result2 is True, "should_refresh_models should return True when refresh_models is True"
+    assert result2 is False, "should_refresh_models should return False when refresh_models is False"
 
-    # Test case 3: refresh_models is True, api_token is "fake" (default)
-    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
+
+async def test_allow_listing_models_flag():
+    """
+    Test the allow_listing_models flag functionality.
+
+    This test verifies that:
+    1. When allow_listing_models is True (default), list_models returns models from the server
+    2. When allow_listing_models is False, list_models returns None without calling the server
+    """
+
+    # Test case 1: allow_listing_models is True (default)
+    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=True)
+    adapter1 = VLLMInferenceAdapter(config1)
+    adapter1.__provider_id__ = "test-vllm"
+
+    # Mock the client.models.list() method
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
+        mock_client = MagicMock()
+
+        async def mock_models_list():
+            yield OpenAIModel(id="test-model-1", created=1, object="model", owned_by="test")
+            yield OpenAIModel(id="test-model-2", created=2, object="model", owned_by="test")
+
+        mock_client.models.list.return_value = mock_models_list()
+        mock_client_property.return_value = mock_client
+
+        models = await adapter1.list_models()
+        assert models is not None, "list_models should return models when allow_listing_models is True"
+        assert len(models) == 2, "Should return 2 models"
+        assert models[0].identifier == "test-model-1"
+        assert models[1].identifier == "test-model-2"
+        mock_client.models.list.assert_called_once()
+
+    # Test case 2: allow_listing_models is False
+    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=False)
+    adapter2 = VLLMInferenceAdapter(config2)
+    adapter2.__provider_id__ = "test-vllm"
+
+    # Mock the client.models.list() method
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
+        mock_client = MagicMock()
+        mock_client_property.return_value = mock_client
+
+        models = await adapter2.list_models()
+        assert models is None, "list_models should return None when allow_listing_models is False"
+        mock_client.models.list.assert_not_called()
+
+    # Test case 3: allow_listing_models defaults to True
+    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost")
     adapter3 = VLLMInferenceAdapter(config3)
-    result3 = await adapter3.should_refresh_models()
-    assert result3 is True, "should_refresh_models should return True when refresh_models is True"
+    adapter3.__provider_id__ = "test-vllm"
+    result3 = await adapter3.allow_listing_models()
+    assert result3 is True, "allow_listing_models should return True by default"
 
     # Test case 4: refresh_models is True, api_token is real token
     config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
@@ -670,6 +718,22 @@ async def test_should_refresh_models():
     result5 = await adapter5.should_refresh_models()
     assert result5 is False, "should_refresh_models should return False when refresh_models is False"
 
+    # Mock the client.models.list() method
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
+        mock_client = MagicMock()
+
+        async def mock_models_list():
+            yield OpenAIModel(id="default-model", created=1, object="model", owned_by="test")
+
+        mock_client.models.list.return_value = mock_models_list()
+        mock_client_property.return_value = mock_client
+
+        models = await adapter3.list_models()
+        assert models is not None, "list_models should return models when allow_listing_models defaults to True"
+        assert len(models) == 1, "Should return 1 model"
+        assert models[0].identifier == "default-model"
+        mock_client.models.list.assert_called_once()
+
 
 async def test_provider_data_var_context_propagation(vllm_inference_adapter):
     """
diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py
index 55449804a0..8752dfc282 100644
--- a/tests/unit/server/test_access_control.py
+++ b/tests/unit/server/test_access_control.py
@@ -32,6 +32,7 @@ async def test_setup(cached_disk_dist_registry):
     mock_inference.__provider_spec__ = MagicMock()
     mock_inference.__provider_spec__.api = Api.inference
     mock_inference.register_model = AsyncMock(side_effect=_return_model)
+    mock_inference.allow_listing_models = AsyncMock(return_value=True)
     routing_table = ModelsRoutingTable(
         impls_by_provider_id={"test_provider": mock_inference},
         dist_registry=cached_disk_dist_registry,

From e28bc936351772135719e16ba8f5e6fd5506a85a Mon Sep 17 00:00:00 2001
From: Akram Ben Aissi
Date: Mon, 6 Oct 2025 12:56:05 +0200
Subject: [PATCH 2/3] Improve VLLM model discovery error handling
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Add comprehensive error handling in check_model_availability method
• Provide helpful error messages with actionable solutions for 404 errors
• Warn when API token is set but model discovery is disabled
---
 docs/docs/providers/inference/remote_vllm.mdx |  4 +-
 llama_stack/core/routing_tables/models.py     | 10 ++--
 llama_stack/distributions/ci-tests/run.yaml   |  2 +-
 .../distributions/postgres-demo/run.yaml      |  2 +-
 .../distributions/starter-gpu/run.yaml        |  2 +-
 llama_stack/distributions/starter/run.yaml    |  2 +-
 .../inference/meta_reference/inference.py     |  2 +-
 .../sentence_transformers.py                  |  2 +-
 .../providers/remote/inference/vllm/config.py |  6 +--
 .../providers/remote/inference/vllm/vllm.py   | 51 +++++++++++++------
 .../utils/inference/model_registry.py         |  2 +-
 .../providers/utils/inference/openai_mixin.py |  2 +-
 .../routers/test_routing_tables.py            |  2 +-
 .../providers/inference/test_remote_vllm.py   | 28 +++++-----
 tests/unit/server/test_access_control.py      |  2 +-
 15 files changed, 69 insertions(+), 50 deletions(-)

diff --git a/docs/docs/providers/inference/remote_vllm.mdx b/docs/docs/providers/inference/remote_vllm.mdx
index efa863016e..884ca8922e 100644
--- a/docs/docs/providers/inference/remote_vllm.mdx
+++ b/docs/docs/providers/inference/remote_vllm.mdx
@@ -20,7 +20,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
 | `refresh_models` | `` | No | False | Whether to refresh models periodically |
-| `allow_listing_models` | `` | No | True | Whether to allow listing models from the vLLM server |
+| `enable_model_discovery` | `` | No | True | Whether to enable model discovery from the vLLM server |
 
 ## Sample Configuration
 
@@ -29,5 +29,5 @@ url: ${env.VLLM_URL:=}
 max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
 api_token: ${env.VLLM_API_TOKEN:=fake}
 tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
+enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
 ```
diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py
index 0a98506742..e751551582 100644
--- a/llama_stack/core/routing_tables/models.py
+++ b/llama_stack/core/routing_tables/models.py
@@ -43,12 +43,12 @@ async def refresh(self) -> None:
         await self.update_registered_models(provider_id, models)
 
     async def list_models(self) -> ListModelsResponse:
-        # Check if providers allow listing models before returning models
+        # Check if providers enable model discovery before returning models
         for provider_id, provider in self.impls_by_provider_id.items():
-            allow_listing_models = await provider.allow_listing_models()
-            logger.debug(f"Provider {provider_id}: allow_listing_models={allow_listing_models}")
-            if not allow_listing_models:
-                logger.debug(f"Provider {provider_id} has allow_listing_models disabled")
+            enable_model_discovery = await provider.enable_model_discovery()
+            logger.debug(f"Provider {provider_id}: enable_model_discovery={enable_model_discovery}")
+            if not enable_model_discovery:
+                logger.debug(f"Provider {provider_id} has enable_model_discovery disabled")
         return ListModelsResponse(data=await self.get_all_with_type("model"))
 
     async def openai_list_models(self) -> OpenAIListModelsResponse:
diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml
index e70c11100a..81c947f569 100644
--- a/llama_stack/distributions/ci-tests/run.yaml
+++ b/llama_stack/distributions/ci-tests/run.yaml
@@ -31,7 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
      api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
diff --git a/llama_stack/distributions/postgres-demo/run.yaml b/llama_stack/distributions/postgres-demo/run.yaml
index 67691e5cff..98e784e764 100644
--- a/llama_stack/distributions/postgres-demo/run.yaml
+++ b/llama_stack/distributions/postgres-demo/run.yaml
@@ -16,7 +16,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml
index fb29f7407d..187e3ccde4 100644
--- a/llama_stack/distributions/starter-gpu/run.yaml
+++ b/llama_stack/distributions/starter-gpu/run.yaml
@@ -31,7 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml
index d338944bb4..d02bd439d9 100644
--- a/llama_stack/distributions/starter/run.yaml
+++ b/llama_stack/distributions/starter/run.yaml
@@ -31,7 +31,7 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
+      enable_model_discovery: ${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 3c003bbcaf..f272040c00 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -71,7 +71,7 @@ async def openai_completion(self, *args, **kwargs):
     async def should_refresh_models(self) -> bool:
         return False
 
-    async def allow_listing_models(self) -> bool:
+    async def enable_model_discovery(self) -> bool:
         return True
 
     async def list_models(self) -> list[Model] | None:
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 542c4bcebe..3dd5d2b899 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -52,7 +52,7 @@ async def shutdown(self) -> None:
     async def should_refresh_models(self) -> bool:
         return False
 
-    async def allow_listing_models(self) -> bool:
+    async def enable_model_discovery(self) -> bool:
         return True
 
     async def list_models(self) -> list[Model] | None:
diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py
index 3277188006..3887107ddb 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -34,9 +34,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         default=False,
         description="Whether to refresh models periodically",
     )
-    allow_listing_models: bool = Field(
+    enable_model_discovery: bool = Field(
         default=True,
-        description="Whether to allow listing models from the vLLM server",
+        description="Whether to enable model discovery from the vLLM server",
     )
 
     @field_validator("tls_verify")
@@ -63,5 +63,5 @@ def sample_run_config(
             "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
             "api_token": "${env.VLLM_API_TOKEN:=fake}",
             "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
-            "allow_listing_models": "${env.VLLM_ALLOW_LISTING_MODELS:=true}",
+            "enable_model_discovery": "${env.VLLM_ENABLE_MODEL_DISCOVERY:=true}",
         }
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 87a78c0f11..305990f061 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -282,16 +282,16 @@ async def should_refresh_models(self) -> bool:
         # Strictly respecting the refresh_models directive
         return self.config.refresh_models
 
-    async def allow_listing_models(self) -> bool:
-        # Respecting the allow_listing_models directive
-        result = self.config.allow_listing_models
-        log.debug(f"VLLM allow_listing_models: {result}")
+    async def enable_model_discovery(self) -> bool:
+        # Respecting the enable_model_discovery directive
+        result = self.config.enable_model_discovery
+        log.debug(f"VLLM enable_model_discovery: {result}")
         return result
 
     async def list_models(self) -> list[Model] | None:
-        log.debug(f"VLLM list_models called, allow_listing_models={self.config.allow_listing_models}")
-        if not self.config.allow_listing_models:
-            log.debug("VLLM list_models returning None due to allow_listing_models=False")
+        log.debug(f"VLLM list_models called, enable_model_discovery={self.config.enable_model_discovery}")
+        if not self.config.enable_model_discovery:
+            log.debug("VLLM list_models returning None due to enable_model_discovery=False")
             return None
 
         models = []
@@ -347,18 +347,23 @@ async def check_model_availability(self, model: str) -> bool:
         """
         Check if a specific model is available from the vLLM server.
 
-        This method respects the allow_listing_models configuration flag.
-        If allow_listing_models is False, it returns True to allow model registration
+        This method respects the enable_model_discovery configuration flag.
+        If enable_model_discovery is False, it returns True to allow model registration
         without making HTTP requests (trusting that the model exists).
 
         :param model: The model identifier to check.
-        :return: True if the model is available or if allow_listing_models is False, False otherwise.
+        :return: True if the model is available or if enable_model_discovery is False, False otherwise.
         """
-        # Check if provider allows listing models before making HTTP request
-        if not self.config.allow_listing_models:
-            log.debug(
-                "VLLM check_model_availability returning True due to allow_listing_models=False (trusting model exists)"
-            )
+        # Check if provider enables model discovery before making HTTP request
+        if not self.config.enable_model_discovery:
+            log.debug("Model discovery disabled for vLLM: Trusting model exists")
+            # Warn if API key is set but model discovery is disabled
+            if self.config.api_token:
+                log.warning(
+                    "Model discovery is disabled but VLLM_API_TOKEN is set. "
+                    "If you're not using model discovery, you may not need to set the API token. "
+                    "Consider removing VLLM_API_TOKEN from your configuration or setting enable_model_discovery=true."
+                )
             return True
 
         try:
@@ -367,7 +372,21 @@ async def check_model_availability(self, model: str) -> bool:
             log.warning(f"Failed to connect to vLLM at {self.config.url}: {e}")
             return False
 
-        available_models = [m.id async for m in res]
+        try:
+            available_models = [m.id async for m in res]
+        except Exception as e:
+            # Provide helpful error message for model discovery failures
+            log.error(f"Model discovery failed with the following output from vLLM server: {e}.\n")
+            log.error(
+                f"Model discovery failed: This typically occurs when a provider (like vLLM) is configured "
+                f"with model discovery enabled but the provider server doesn't support the /models endpoint. "
+                f"To resolve this, either:\n"
+                f"1. Check that {self.config.url} correctly points to the vLLM server, or\n"
+                f"2. Ensure your provider server supports the /v1/models endpoint and if authenticated that VLLM_API_TOKEN is set, or\n"
+                f"3. Set enable_model_discovery=false for the problematic provider in your configuration\n"
+            )
+            return False
+
         is_available = model in available_models
         log.debug(f"VLLM model {model} availability: {is_available}")
         return is_available
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 4a4fa7adfc..17c43fe388 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -100,7 +100,7 @@ async def list_models(self) -> list[Model] | None:
     async def should_refresh_models(self) -> bool:
         return False
 
-    async def allow_listing_models(self) -> bool:
+    async def enable_model_discovery(self) -> bool:
         return True
 
     def get_provider_model_id(self, identifier: str) -> str | None:
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index 4164edfa37..a00a45963e 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -426,5 +426,5 @@ async def check_model_availability(self, model: str) -> bool:
     async def should_refresh_models(self) -> bool:
         return False
 
-    async def allow_listing_models(self) -> bool:
+    async def enable_model_discovery(self) -> bool:
         return True
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index 524d650da6..ec7aca27e4 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -52,7 +52,7 @@ async def unregister_model(self, model_id: str):
     async def should_refresh_models(self):
         return False
 
-    async def allow_listing_models(self):
+    async def enable_model_discovery(self):
         return True
 
     async def list_models(self):
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 6675a59014..701282179f 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -653,17 +653,17 @@ async def test_should_refresh_models():
     assert result2 is False, "should_refresh_models should return False when refresh_models is False"
 
 
-async def test_allow_listing_models_flag():
+async def test_enable_model_discovery_flag():
     """
-    Test the allow_listing_models flag functionality.
+    Test the enable_model_discovery flag functionality.
 
     This test verifies that:
-    1. When allow_listing_models is True (default), list_models returns models from the server
-    2. When allow_listing_models is False, list_models returns None without calling the server
+    1. When enable_model_discovery is True (default), list_models returns models from the server
+    2. When enable_model_discovery is False, list_models returns None without calling the server
     """
 
-    # Test case 1: allow_listing_models is True (default)
-    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=True)
+    # Test case 1: enable_model_discovery is True (default)
+    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=True)
     adapter1 = VLLMInferenceAdapter(config1)
     adapter1.__provider_id__ = "test-vllm"
 
@@ -679,14 +679,14 @@ async def mock_models_list():
         mock_client_property.return_value = mock_client
 
         models = await adapter1.list_models()
-        assert models is not None, "list_models should return models when allow_listing_models is True"
+        assert models is not None, "list_models should return models when enable_model_discovery is True"
         assert len(models) == 2, "Should return 2 models"
         assert models[0].identifier == "test-model-1"
         assert models[1].identifier == "test-model-2"
         mock_client.models.list.assert_called_once()
 
-    # Test case 2: allow_listing_models is False
-    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=False)
+    # Test case 2: enable_model_discovery is False
+    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", enable_model_discovery=False)
     adapter2 = VLLMInferenceAdapter(config2)
     adapter2.__provider_id__ = "test-vllm"
 
@@ -696,15 +696,15 @@ async def mock_models_list():
         mock_client_property.return_value = mock_client
 
         models = await adapter2.list_models()
-        assert models is None, "list_models should return None when allow_listing_models is False"
+        assert models is None, "list_models should return None when enable_model_discovery is False"
         mock_client.models.list.assert_not_called()
 
-    # Test case 3: allow_listing_models defaults to True
+    # Test case 3: enable_model_discovery defaults to True
     config3 = VLLMInferenceAdapterConfig(url="http://test.localhost")
     adapter3 = VLLMInferenceAdapter(config3)
     adapter3.__provider_id__ = "test-vllm"
-    result3 = await adapter3.allow_listing_models()
-    assert result3 is True, "allow_listing_models should return True by default"
+    result3 = await adapter3.enable_model_discovery()
+    assert result3 is True, "enable_model_discovery should return True by default"
 
     # Test case 4: refresh_models is True, api_token is real token
     config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
@@ -729,7 +729,7 @@ async def mock_models_list():
         mock_client_property.return_value = mock_client
 
         models = await adapter3.list_models()
-        assert models is not None, "list_models should return models when allow_listing_models defaults to True"
+        assert models is not None, "list_models should return models when enable_model_discovery defaults to True"
         assert len(models) == 1, "Should return 1 model"
         assert models[0].identifier == "default-model"
         mock_client.models.list.assert_called_once()
diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py
index 8752dfc282..3cb393e0e9 100644
--- a/tests/unit/server/test_access_control.py
+++ b/tests/unit/server/test_access_control.py
@@ -32,7 +32,7 @@ async def test_setup(cached_disk_dist_registry):
     mock_inference.__provider_spec__ = MagicMock()
     mock_inference.__provider_spec__.api = Api.inference
     mock_inference.register_model = AsyncMock(side_effect=_return_model)
-    mock_inference.allow_listing_models = AsyncMock(return_value=True)
+    mock_inference.enable_model_discovery = AsyncMock(return_value=True)
     routing_table = ModelsRoutingTable(
         impls_by_provider_id={"test_provider": mock_inference},
         dist_registry=cached_disk_dist_registry,

From 055179ad4551b13ed894ff1055eabb1c5e0c101e Mon Sep 17 00:00:00 2001
From: Akram Ben Aissi
Date: Mon, 6 Oct 2025 16:39:15 +0200
Subject: [PATCH 3/3] Review changes

---
 llama_stack/core/routing_tables/models.py     |  6 ---
 .../providers/remote/inference/vllm/vllm.py   | 41 ++++++++-----------
 2 files changed, 18 insertions(+), 29 deletions(-)

diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py
index e751551582..641c73c167 100644
--- a/llama_stack/core/routing_tables/models.py
+++ b/llama_stack/core/routing_tables/models.py
@@ -43,12 +43,6 @@ async def refresh(self) -> None:
         await self.update_registered_models(provider_id, models)
 
     async def list_models(self) -> ListModelsResponse:
-        # Check if providers enable model discovery before returning models
-        for provider_id, provider in self.impls_by_provider_id.items():
-            enable_model_discovery = await provider.enable_model_discovery()
-            logger.debug(f"Provider {provider_id}: enable_model_discovery={enable_model_discovery}")
-            if not enable_model_discovery:
-                logger.debug(f"Provider {provider_id} has enable_model_discovery disabled")
         return ListModelsResponse(data=await self.get_all_with_type("model"))
 
     async def openai_list_models(self) -> OpenAIListModelsResponse:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 305990f061..3b3edb5933 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -282,12 +282,6 @@ async def should_refresh_models(self) -> bool:
         # Strictly respecting the refresh_models directive
         return self.config.refresh_models
 
-    async def enable_model_discovery(self) -> bool:
-        # Respecting the enable_model_discovery directive
-        result = self.config.enable_model_discovery
-        log.debug(f"VLLM enable_model_discovery: {result}")
-        return result
-
     async def list_models(self) -> list[Model] | None:
         log.debug(f"VLLM list_models called, enable_model_discovery={self.config.enable_model_discovery}")
         if not self.config.enable_model_discovery:
@@ -343,17 +337,12 @@ async def _get_model(self, model_id: str) -> Model:
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
-    async def check_model_availability(self, model: str) -> bool:
-        """
-        Check if a specific model is available from the vLLM server.
-
-        This method respects the enable_model_discovery configuration flag.
-        If enable_model_discovery is False, it returns True to allow model registration
+    async def register_model(self, model: Model) -> Model:
+        try:
+            model = await self.register_helper.register_model(model)
+        except ValueError:
+            pass  # Ignore statically unknown model, will check live listing
-        without making HTTP requests (trusting that the model exists).
 
-        :param model: The model identifier to check.
-        :return: True if the model is available or if enable_model_discovery is False, False otherwise.
-        """
         # Check if provider enables model discovery before making HTTP request
         if not self.config.enable_model_discovery:
             log.debug("Model discovery disabled for vLLM: Trusting model exists")
@@ -364,13 +353,14 @@ async def check_model_availability(self, model: str) -> bool:
                     "If you're not using model discovery, you may not need to set the API token. "
                     "Consider removing VLLM_API_TOKEN from your configuration or setting enable_model_discovery=true."
                 )
-            return True
+            return model
 
         try:
             res = self.client.models.list()
         except APIConnectionError as e:
-            log.warning(f"Failed to connect to vLLM at {self.config.url}: {e}")
-            return False
+            raise ValueError(
+                f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
+            ) from e
 
         try:
             available_models = [m.id async for m in res]
@@ -385,11 +375,16 @@ async def check_model_availability(self, model: str) -> bool:
                 f"2. Ensure your provider server supports the /v1/models endpoint and if authenticated that VLLM_API_TOKEN is set, or\n"
                 f"3. Set enable_model_discovery=false for the problematic provider in your configuration\n"
             )
-            return False
+            raise ValueError(
+                f"Model discovery failed for vLLM at {self.config.url}. Please check the server configuration and logs."
+            ) from e
 
-        is_available = model in available_models
-        log.debug(f"VLLM model {model} availability: {is_available}")
-        return is_available
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model {model.provider_resource_id} is not being served by vLLM. "
+                f"Available models: {', '.join(available_models)}"
+            )
+        return model
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         options = get_sampling_options(request.sampling_params)