2 changes: 2 additions & 0 deletions docs/docs/providers/inference/remote_vllm.mdx
@@ -20,6 +20,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
| `api_token` | `str \| None` | No | fake | The API token |
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
| `allow_listing_models` | `<class 'bool'>` | No | True | Whether to allow listing models from the vLLM server |

## Sample Configuration

@@ -28,4 +29,5 @@ url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
```
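
Not part of the diff: a small Python sketch of the `${env.VLLM_ALLOW_LISTING_MODELS:=true}` placeholder semantics assumed above, namely that the value falls back to `true` when the environment variable is unset; the helper below is illustrative, not the stack's actual resolver.

```python
import os


def resolve_env_placeholder(name: str, default: str) -> str:
    # Illustrative stand-in for the ${env.NAME:=default} fallback used in the sample config.
    return os.environ.get(name, default)


# Unset -> defaults to "true"; export VLLM_ALLOW_LISTING_MODELS=false to disable listing.
print(resolve_env_placeholder("VLLM_ALLOW_LISTING_MODELS", "true"))
```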
6 changes: 6 additions & 0 deletions llama_stack/core/routing_tables/models.py
@@ -43,6 +43,12 @@ async def refresh(self) -> None:
await self.update_registered_models(provider_id, models)

async def list_models(self) -> ListModelsResponse:
# Check if providers allow listing models before returning models
for provider_id, provider in self.impls_by_provider_id.items():
allow_listing_models = await provider.allow_listing_models()
logger.debug(f"Provider {provider_id}: allow_listing_models={allow_listing_models}")
if not allow_listing_models:
logger.debug(f"Provider {provider_id} has allow_listing_models disabled")
return ListModelsResponse(data=await self.get_all_with_type("model"))

async def openai_list_models(self) -> OpenAIListModelsResponse:
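
Not part of the diff: a minimal stub of the per-provider contract the loop above consults, modeled on the test doubles added later in this PR; the class name is hypothetical.

```python
class StubInferenceProvider:
    """Minimal provider double exposing the async allow_listing_models() hook."""

    def __init__(self, allow_listing: bool = True) -> None:
        self._allow_listing = allow_listing

    async def allow_listing_models(self) -> bool:
        # Consulted by the routing table before it decides how to surface models.
        return self._allow_listing

    async def list_models(self):
        # Providers that disallow listing return None instead of querying a backend.
        return [] if self._allow_listing else None
```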
1 change: 1 addition & 0 deletions llama_stack/distributions/ci-tests/run.yaml
@@ -31,6 +31,7 @@ providers:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
1 change: 1 addition & 0 deletions llama_stack/distributions/postgres-demo/run.yaml
@@ -16,6 +16,7 @@ providers:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
1 change: 1 addition & 0 deletions llama_stack/distributions/starter-gpu/run.yaml
@@ -31,6 +31,7 @@ providers:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
1 change: 1 addition & 0 deletions llama_stack/distributions/starter/run.yaml
@@ -31,6 +31,7 @@ providers:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
allow_listing_models: ${env.VLLM_ALLOW_LISTING_MODELS:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
@@ -71,6 +71,9 @@ async def openai_completion(self, *args, **kwargs):
async def should_refresh_models(self) -> bool:
return False

async def allow_listing_models(self) -> bool:
return True

async def list_models(self) -> list[Model] | None:
return None

@@ -52,6 +52,9 @@ async def shutdown(self) -> None:
async def should_refresh_models(self) -> bool:
return False

async def allow_listing_models(self) -> bool:
return True

async def list_models(self) -> list[Model] | None:
return [
Model(
5 changes: 5 additions & 0 deletions llama_stack/providers/remote/inference/vllm/config.py
@@ -34,6 +34,10 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
default=False,
description="Whether to refresh models periodically",
)
allow_listing_models: bool = Field(
default=True,
description="Whether to allow listing models from the vLLM server",
)

@field_validator("tls_verify")
@classmethod
@@ -59,4 +63,5 @@ def sample_run_config(
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
"api_token": "${env.VLLM_API_TOKEN:=fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
"allow_listing_models": "${env.VLLM_ALLOW_LISTING_MODELS:=true}",
}
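
For reference, a minimal construction sketch grounded in the defaults declared above; the URL is a placeholder.

```python
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig

# Only url is supplied, so the declared defaults apply.
cfg = VLLMInferenceAdapterConfig(url="http://localhost:8000/v1")
assert cfg.allow_listing_models is True   # new field defaults to True
assert cfg.refresh_models is False        # existing default unchanged

# Opting out of model discovery:
cfg_locked = VLLMInferenceAdapterConfig(url="http://localhost:8000/v1", allow_listing_models=False)
assert cfg_locked.allow_listing_models is False
```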
49 changes: 35 additions & 14 deletions llama_stack/providers/remote/inference/vllm/vllm.py
@@ -282,7 +282,18 @@ async def should_refresh_models(self) -> bool:
# Strictly respecting the refresh_models directive
return self.config.refresh_models

async def allow_listing_models(self) -> bool:
# Respecting the allow_listing_models directive
result = self.config.allow_listing_models
log.debug(f"VLLM allow_listing_models: {result}")
return result

async def list_models(self) -> list[Model] | None:
log.debug(f"VLLM list_models called, allow_listing_models={self.config.allow_listing_models}")
if not self.config.allow_listing_models:
log.debug("VLLM list_models returning None due to allow_listing_models=False")
return None

models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
@@ -332,24 +343,34 @@ async def _get_model(self, model_id: str) -> Model:
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}

async def register_model(self, model: Model) -> Model:
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from the vLLM server.

This method respects the allow_listing_models configuration flag.
If allow_listing_models is False, it returns True to allow model registration
without making HTTP requests (trusting that the model exists).

:param model: The model identifier to check.
:return: True if the model is available or if allow_listing_models is False, False otherwise.
"""
# Check if provider allows listing models before making HTTP request
if not self.config.allow_listing_models:
log.debug(
"VLLM check_model_availability returning True due to allow_listing_models=False (trusting model exists)"
)
return True

try:
res = self.client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
) from e
log.warning(f"Failed to connect to vLLM at {self.config.url}: {e}")
return False

available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model {model.provider_resource_id} is not being served by vLLM. "
f"Available models: {', '.join(available_models)}"
)
return model
is_available = model in available_models
log.debug(f"VLLM model {model} availability: {is_available}")
return is_available

async def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request.sampling_params)
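
A usage sketch, not taken from the PR: construction mirrors the unit tests further down, the URL and model name are placeholders, and the import paths follow the file layout shown in this diff. With listing disabled, neither call contacts the vLLM server.

```python
import asyncio

from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter


async def main() -> None:
    config = VLLMInferenceAdapterConfig(url="http://localhost:8000/v1", allow_listing_models=False)
    adapter = VLLMInferenceAdapter(config)

    assert await adapter.allow_listing_models() is False
    assert await adapter.list_models() is None                  # skips the /v1/models call entirely
    assert await adapter.check_model_availability("my-model")   # trusted without verification


asyncio.run(main())
```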
3 changes: 3 additions & 0 deletions llama_stack/providers/utils/inference/model_registry.py
@@ -100,6 +100,9 @@ async def list_models(self) -> list[Model] | None:
async def should_refresh_models(self) -> bool:
return False

async def allow_listing_models(self) -> bool:
return True

def get_provider_model_id(self, identifier: str) -> str | None:
return self.alias_to_provider_id_map.get(identifier, None)

3 changes: 3 additions & 0 deletions llama_stack/providers/utils/inference/openai_mixin.py
@@ -425,3 +425,6 @@ async def check_model_availability(self, model: str) -> bool:

async def should_refresh_models(self) -> bool:
return False

async def allow_listing_models(self) -> bool:
return True
3 changes: 3 additions & 0 deletions tests/unit/distribution/routers/test_routing_tables.py
@@ -52,6 +52,9 @@ async def unregister_model(self, model_id: str):
async def should_refresh_models(self):
return False

async def allow_listing_models(self):
return True

async def list_models(self):
return [
Model(
86 changes: 75 additions & 11 deletions tests/unit/providers/inference/test_remote_vllm.py
@@ -636,27 +636,75 @@ async def test_should_refresh_models():
Test the should_refresh_models method with different refresh_models configurations.

This test verifies that:
1. When refresh_models is True, should_refresh_models returns True regardless of api_token
2. When refresh_models is False, should_refresh_models returns False regardless of api_token
1. When refresh_models is True, should_refresh_models returns True
2. When refresh_models is False, should_refresh_models returns False
"""

# Test case 1: refresh_models is True, api_token is None
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
# Test case 1: refresh_models is True
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=True)
adapter1 = VLLMInferenceAdapter(config1)
result1 = await adapter1.should_refresh_models()
assert result1 is True, "should_refresh_models should return True when refresh_models is True"

# Test case 2: refresh_models is True, api_token is empty string
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
# Test case 2: refresh_models is False
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=False)
adapter2 = VLLMInferenceAdapter(config2)
result2 = await adapter2.should_refresh_models()
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
assert result2 is False, "should_refresh_models should return False when refresh_models is False"

# Test case 3: refresh_models is True, api_token is "fake" (default)
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)

async def test_allow_listing_models_flag():
"""
Test the allow_listing_models flag functionality.

This test verifies that:
1. When allow_listing_models is True (default), list_models returns models from the server
2. When allow_listing_models is False, list_models returns None without calling the server
"""

# Test case 1: allow_listing_models is True (default)
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=True)
adapter1 = VLLMInferenceAdapter(config1)
adapter1.__provider_id__ = "test-vllm"

# Mock the client.models.list() method
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()

async def mock_models_list():
yield OpenAIModel(id="test-model-1", created=1, object="model", owned_by="test")
yield OpenAIModel(id="test-model-2", created=2, object="model", owned_by="test")

mock_client.models.list.return_value = mock_models_list()
mock_client_property.return_value = mock_client

models = await adapter1.list_models()
assert models is not None, "list_models should return models when allow_listing_models is True"
assert len(models) == 2, "Should return 2 models"
assert models[0].identifier == "test-model-1"
assert models[1].identifier == "test-model-2"
mock_client.models.list.assert_called_once()

# Test case 2: allow_listing_models is False
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", allow_listing_models=False)
adapter2 = VLLMInferenceAdapter(config2)
adapter2.__provider_id__ = "test-vllm"

# Mock the client.models.list() method
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client_property.return_value = mock_client

models = await adapter2.list_models()
assert models is None, "list_models should return None when allow_listing_models is False"
mock_client.models.list.assert_not_called()

# Test case 3: allow_listing_models defaults to True
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost")
adapter3 = VLLMInferenceAdapter(config3)
result3 = await adapter3.should_refresh_models()
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
adapter3.__provider_id__ = "test-vllm"
result3 = await adapter3.allow_listing_models()
assert result3 is True, "allow_listing_models should return True by default"

# Test case 4: refresh_models is True, api_token is real token
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
@@ -670,6 +718,22 @@ async def test_should_refresh_models():
result5 = await adapter5.should_refresh_models()
assert result5 is False, "should_refresh_models should return False when refresh_models is False"

# Mock the client.models.list() method
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()

async def mock_models_list():
yield OpenAIModel(id="default-model", created=1, object="model", owned_by="test")

mock_client.models.list.return_value = mock_models_list()
mock_client_property.return_value = mock_client

models = await adapter3.list_models()
assert models is not None, "list_models should return models when allow_listing_models defaults to True"
assert len(models) == 1, "Should return 1 model"
assert models[0].identifier == "default-model"
mock_client.models.list.assert_called_once()


async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""
1 change: 1 addition & 0 deletions tests/unit/server/test_access_control.py
@@ -32,6 +32,7 @@ async def test_setup(cached_disk_dist_registry):
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
mock_inference.allow_listing_models = AsyncMock(return_value=True)
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,