
Commit 724dac4

chore: give OpenAIMixin subclasses a chance to list models without leaking _model_cache details (#3682)

# What does this PR do?

Close the `_model_cache` abstraction leak.

## Test Plan

CI with new tests.
1 parent f00bcd9 commit 724dac4
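
At a glance, the PR replaces ad-hoc `list_models()` overrides that wrote directly into the mixin's private `_model_cache` with a single override point, `list_provider_model_ids()`. An abridged sketch of the new contract, taken from the diffs below:

```python
# Abridged sketch of the new hook (full version in openai_mixin.py below).
# Subclasses return plain provider model ID strings; OpenAIMixin.list_models()
# keeps _model_cache private and handles filtering and Model construction.
from collections.abc import Iterable


class OpenAIMixin:  # other members elided
    async def list_provider_model_ids(self) -> Iterable[str]:
        # Default: ask the OpenAI-compatible endpoint for its models.
        return [m.id async for m in self.client.models.list()]
```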

File tree

3 files changed: +164 −39 lines


llama_stack/providers/remote/inference/databricks/databricks.py

Lines changed: 8 additions & 27 deletions
```diff
@@ -4,16 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import Iterable
 from typing import Any
 
 from databricks.sdk import WorkspaceClient
 
 from llama_stack.apis.inference import (
     Inference,
-    Model,
     OpenAICompletion,
 )
-from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
@@ -72,31 +71,13 @@ async def openai_completion(
     ) -> OpenAICompletion:
         raise NotImplementedError()
 
-    async def list_models(self) -> list[Model] | None:
-        self._model_cache = {}  # from OpenAIMixin
-        ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key())  # TODO: this is not async
-        endpoints = ws_client.serving_endpoints.list()
-        for endpoint in endpoints:
-            model = Model(
-                provider_id=self.__provider_id__,
-                provider_resource_id=endpoint.name,
-                identifier=endpoint.name,
-            )
-            if endpoint.task == "llm/v1/chat":
-                model.model_type = ModelType.llm  # this is redundant, but informative
-            elif endpoint.task == "llm/v1/embeddings":
-                if endpoint.name not in self.embedding_model_metadata:
-                    logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
-                    continue
-                model.model_type = ModelType.embedding
-                model.metadata = self.embedding_model_metadata[endpoint.name]
-            else:
-                logger.warning(f"Unknown model type, skipping: {endpoint}")
-                continue
-
-            self._model_cache[endpoint.name] = model
-
-        return list(self._model_cache.values())
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        return [
+            endpoint.name
+            for endpoint in WorkspaceClient(
+                host=self.config.url, token=self.get_api_key()
+            ).serving_endpoints.list()  # TODO: this is not async
+        ]
 
     async def should_refresh_models(self) -> bool:
         return False
```
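
The `TODO: this is not async` comment survives the rewrite: `WorkspaceClient` is a synchronous SDK, so the listing call blocks the event loop. One possible follow-up, sketched here under the assumption that offloading to a thread is acceptable (this is not part of the commit):

```python
# Hypothetical follow-up sketch, not part of this commit: push the blocking
# Databricks SDK call onto a worker thread so the coroutine does not stall
# the event loop.
import asyncio
from collections.abc import Iterable

from databricks.sdk import WorkspaceClient


async def list_provider_model_ids(self) -> Iterable[str]:
    def _list_endpoint_names() -> list[str]:
        client = WorkspaceClient(host=self.config.url, token=self.get_api_key())
        return [endpoint.name for endpoint in client.serving_endpoints.list()]

    # asyncio.to_thread (Python 3.9+) runs the callable in the default executor.
    return await asyncio.to_thread(_list_endpoint_names)
```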

llama_stack/providers/utils/inference/openai_mixin.py

Lines changed: 32 additions & 12 deletions
```diff
@@ -7,7 +7,7 @@
 import base64
 import uuid
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
 from typing import Any
 
 from openai import NOT_GIVEN, AsyncOpenAI
@@ -111,6 +111,18 @@ def get_extra_client_params(self) -> dict[str, Any]:
         """
         return {}
 
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        List available models from the provider.
+
+        Child classes can override this method to provide a custom implementation
+        for listing models. The default implementation uses the AsyncOpenAI client
+        to list models from the OpenAI-compatible endpoint.
+
+        :return: An iterable of model IDs or None if not implemented
+        """
+        return [m.id async for m in self.client.models.list()]
+
     @property
     def client(self) -> AsyncOpenAI:
         """
@@ -387,28 +399,36 @@ async def list_models(self) -> list[Model] | None:
         """
         self._model_cache = {}
 
-        async for m in self.client.models.list():
-            if self.allowed_models and m.id not in self.allowed_models:
-                logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
+        # give subclasses a chance to provide custom model listing
+        iterable = await self.list_provider_model_ids()
+        if not hasattr(iterable, "__iter__"):
+            raise TypeError(
+                f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
+                f"strings or None, but returned {type(iterable).__name__}"
+            )
+        provider_models_ids = list(iterable)
+        logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models")
+
+        for provider_model_id in provider_models_ids:
+            if self.allowed_models and provider_model_id not in self.allowed_models:
+                logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
                 continue
-            if metadata := self.embedding_model_metadata.get(m.id):
-                # This is an embedding model - augment with metadata
+            if metadata := self.embedding_model_metadata.get(provider_model_id):
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m.id,
-                    identifier=m.id,
+                    provider_resource_id=provider_model_id,
+                    identifier=provider_model_id,
                     model_type=ModelType.embedding,
                     metadata=metadata,
                 )
             else:
-                # This is an LLM
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m.id,
-                    identifier=m.id,
+                    provider_resource_id=provider_model_id,
+                    identifier=provider_model_id,
                     model_type=ModelType.llm,
                 )
-            self._model_cache[m.id] = model
+            self._model_cache[provider_model_id] = model
 
         return list(self._model_cache.values())
 
```
tests/unit/providers/utils/inference/test_openai_mixin.py

Lines changed: 124 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import json
+from collections.abc import Iterable
 from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
 
 import pytest
@@ -498,6 +499,129 @@ def get_base_url(self):
         return "default-base-url"
 
 
+class CustomListProviderModelIdsImplementation(OpenAIMixinImpl):
+    """Test implementation with custom list_provider_model_ids override"""
+
+    def __init__(self, custom_model_ids):
+        self._custom_model_ids = custom_model_ids
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """Return custom model IDs list"""
+        return self._custom_model_ids
+
+
+class TestOpenAIMixinCustomListProviderModelIds:
+    """Test cases for custom list_provider_model_ids() implementation functionality"""
+
+    @pytest.fixture
+    def custom_model_ids_list(self):
+        """Create a list of custom model ID strings"""
+        return ["custom-model-1", "custom-model-2", "custom-embedding"]
+
+    @pytest.fixture
+    def adapter(self, custom_model_ids_list):
+        """Create mixin instance with custom list_provider_model_ids implementation"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=custom_model_ids_list)
+        mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}}
+        return mixin
+
+    async def test_is_used(self, adapter, custom_model_ids_list):
+        """Test that custom list_provider_model_ids() implementation is used instead of client.models.list()"""
+        result = await adapter.list_models()
+
+        assert result is not None
+        assert len(result) == 3
+
+        assert set(custom_model_ids_list) == {m.identifier for m in result}
+
+    async def test_populates_cache(self, adapter, custom_model_ids_list):
+        """Test that custom list_provider_model_ids() results are cached"""
+        assert len(adapter._model_cache) == 0
+
+        await adapter.list_models()
+
+        assert set(custom_model_ids_list) == set(adapter._model_cache.keys())
+
+    async def test_respects_allowed_models(self):
+        """Test that custom list_provider_model_ids() respects allowed_models filtering"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=["model-1", "model-2", "model-3"])
+        mixin.allowed_models = ["model-1"]
+
+        result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 1
+        assert result[0].identifier == "model-1"
+
+    async def test_with_empty_list(self):
+        """Test that custom list_provider_model_ids() handles empty list correctly"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[])
+
+        result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 0
+        assert len(mixin._model_cache) == 0
+
+    async def test_wrong_type_raises_error(self):
+        """Test that list_provider_model_ids() returning unhashable items results in an error"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[["nested", "list"], {"key": "value"}])
+
+        with pytest.raises(TypeError, match="unhashable type"):
+            await mixin.list_models()
+
+    async def test_non_iterable_raises_error(self):
+        """Test that list_provider_model_ids() returning non-iterable type raises error"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=42)
+
+        with pytest.raises(
+            TypeError,
+            match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int",
+        ):
+            await mixin.list_models()
+
+    async def test_with_none_items_raises_error(self):
+        """Test that list_provider_model_ids() returning list with None items causes error"""
+        mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[None, "valid-model", None])
+
+        with pytest.raises(Exception, match="Input should be a valid string"):
+            await mixin.list_models()
+
+    async def test_accepts_various_iterables(self):
+        """Test that list_provider_model_ids() accepts tuples, sets, generators, etc."""
+
+        class TupleAdapter(OpenAIMixinImpl):
+            async def list_provider_model_ids(self) -> Iterable[str] | None:
+                return ("model-1", "model-2", "model-3")
+
+        mixin = TupleAdapter()
+        result = await mixin.list_models()
+        assert result is not None
+        assert len(result) == 3
+
+        class GeneratorAdapter(OpenAIMixinImpl):
+            async def list_provider_model_ids(self) -> Iterable[str] | None:
+                def gen():
+                    yield "gen-model-1"
+                    yield "gen-model-2"
+
+                return gen()
+
+        mixin = GeneratorAdapter()
+        result = await mixin.list_models()
+        assert result is not None
+        assert len(result) == 2
+
+        class SetAdapter(OpenAIMixinImpl):
+            async def list_provider_model_ids(self) -> Iterable[str] | None:
+                return {"set-model-1", "set-model-2"}
+
+        mixin = SetAdapter()
+        result = await mixin.list_models()
+        assert result is not None
+        assert len(result) == 2
+
+
 class TestOpenAIMixinProviderDataApiKey:
     """Test cases for provider_data_api_key_field functionality"""
 
```
