
Commit dc97160

chore: make OpenAIMixin maintainable, turn OpenAIMixin into a pydantic.BaseModel
- implement get_api_key instead of relying on LiteLLMOpenAIMixin.get_api_key
- remove use of LiteLLMOpenAIMixin
- add default initialize/shutdown methods to OpenAIMixin
- remove __init__s to allow proper pydantic construction
- remove dead code from vllm adapter and associated / duplicate unit tests
- update vllm adapter to use OpenAIMixin for model registration
- remove ModelRegistryHelper from fireworks & together adapters
- remove Inference from nvidia adapter
- complete type hints on embedding_model_metadata
- allow extra fields on OpenAIMixin, for model_store, __provider_id__, etc.
- new recordings for ollama
- enhance the list models error handling w/ new tests
- update cerebras (remove cerebras-cloud-sdk) and anthropic (custom model listing) inference adapters
1 parent 351c4b9 commit dc97160


59 files changed (+12888, −1335 lines)
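Note: the diffs below repeat one pattern — each adapter drops LiteLLMOpenAIMixin, deletes its __init__, declares config as a pydantic field on an OpenAIMixin subclass, and overrides a small set of hooks. A minimal sketch of that shape, inferred from the hunks below; OpenAIMixin's exact field and hook signatures are assumptions here, and ExampleConfig / ExampleInferenceAdapter are hypothetical:

# Sketch only: assumes OpenAIMixin is now a pydantic.BaseModel providing
# default initialize()/shutdown() and model listing, per this commit.
from collections.abc import Iterable

from pydantic import BaseModel, SecretStr

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleConfig(BaseModel):
    api_key: SecretStr
    base_url: str


class ExampleInferenceAdapter(OpenAIMixin):
    config: ExampleConfig  # declared as a pydantic field; no __init__ needed

    # header field consulted when per-request provider data supplies the key
    provider_data_api_key_field: str = "example_api_key"

    def get_api_key(self) -> str:
        return self.config.api_key.get_secret_value()

    def get_base_url(self) -> str:
        return self.config.base_url

    async def get_models(self) -> Iterable[str] | None:
        # Optional hook: return provider model IDs for registration, or None
        # to fall back to the OpenAI-compatible /v1/models listing.
        return None


# pydantic models are constructed with keyword arguments:
# impl = ExampleInferenceAdapter(config=ExampleConfig(api_key=SecretStr("..."), base_url="https://example.invalid/v1"))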

llama_stack/providers/registry/inference.py

Lines changed: 2 additions & 4 deletions
@@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="cerebras",
             provider_type="remote::cerebras",
-            pip_packages=[
-                "cerebras_cloud_sdk",
-            ],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.cerebras",
             config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
             description="Cerebras inference provider for running models on Cerebras Cloud platform.",
@@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="anthropic",
             provider_type="remote::anthropic",
-            pip_packages=["litellm"],
+            pip_packages=["litellm", "anthropic"],
             module="llama_stack.providers.remote.inference.anthropic",
             config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
             provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",

llama_stack/providers/remote/inference/anthropic/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -10,6 +10,6 @@
 async def get_adapter_impl(config: AnthropicConfig, _deps):
     from .anthropic import AnthropicInferenceAdapter

-    impl = AnthropicInferenceAdapter(config)
+    impl = AnthropicInferenceAdapter(config=config)
     await impl.initialize()
     return impl
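Note: the config → config=config change is forced by the pydantic conversion — BaseModel subclasses accept field values as keyword arguments only, so positional construction raises a TypeError. A tiny illustration with a hypothetical model:

from pydantic import BaseModel


class Adapter(BaseModel):
    config: dict


Adapter(config={"api_key": "..."})  # ok: pydantic fields are keyword-only
# Adapter({"api_key": "..."})      # TypeError: takes 1 positional argument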

llama_stack/providers/remote/inference/anthropic/anthropic.py

Lines changed: 13 additions & 18 deletions
@@ -4,13 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from collections.abc import Iterable
+
+from anthropic import AsyncAnthropic
+
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

 from .config import AnthropicConfig


-class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
+class AnthropicInferenceAdapter(OpenAIMixin):
+    config: AnthropicConfig
+
+    provider_data_api_key_field: str = "anthropic_api_key"
     # source: https://docs.claude.com/en/docs/build-with-claude/embeddings
     # TODO: add support for voyageai, which is where these models are hosted
     # embedding_model_metadata = {
@@ -23,22 +29,11 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     #     "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
     # }

-    def __init__(self, config: AnthropicConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(
-            self,
-            litellm_provider_name="anthropic",
-            api_key_from_config=config.api_key,
-            provider_data_api_key_field="anthropic_api_key",
-        )
-        self.config = config
-
-    async def initialize(self) -> None:
-        await super().initialize()
-
-    async def shutdown(self) -> None:
-        await super().shutdown()
-
-    get_api_key = LiteLLMOpenAIMixin.get_api_key
+    def get_api_key(self) -> str:
+        return self.config.api_key or ""

     def get_base_url(self):
         return "https://api.anthropic.com/v1"
+
+    async def get_models(self) -> Iterable[str] | None:
+        return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
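Note: the new get_models drains AsyncAnthropic's paginated model listing with an async comprehension, returning a flat list of model IDs. A hedged usage sketch — it assumes AnthropicConfig takes an api_key field, as get_api_key above implies, and that no other fields are required to construct the adapter:

import asyncio

from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig


async def main() -> None:
    # Hypothetical construction; a real key is needed for the API call.
    adapter = AnthropicInferenceAdapter(config=AnthropicConfig(api_key="sk-ant-..."))
    print(await adapter.get_models())  # e.g. ["claude-3-5-sonnet-...", ...]


asyncio.run(main())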

llama_stack/providers/remote/inference/azure/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -10,6 +10,6 @@
 async def get_adapter_impl(config: AzureConfig, _deps):
     from .azure import AzureInferenceAdapter

-    impl = AzureInferenceAdapter(config)
+    impl = AzureInferenceAdapter(config=config)
     await impl.initialize()
     return impl

llama_stack/providers/remote/inference/azure/azure.py

Lines changed: 28 additions & 39 deletions
@@ -4,31 +4,20 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any
 from urllib.parse import urljoin

-from llama_stack.apis.inference import ChatCompletionRequest
-from llama_stack.providers.utils.inference.litellm_openai_mixin import (
-    LiteLLMOpenAIMixin,
-)
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

 from .config import AzureConfig


-class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
-    def __init__(self, config: AzureConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(
-            self,
-            litellm_provider_name="azure",
-            api_key_from_config=config.api_key.get_secret_value(),
-            provider_data_api_key_field="azure_api_key",
-            openai_compat_api_base=str(config.api_base),
-        )
-        self.config = config
+class AzureInferenceAdapter(OpenAIMixin):
+    config: AzureConfig

-    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
-    get_api_key = LiteLLMOpenAIMixin.get_api_key
+    provider_data_api_key_field: str = "azure_api_key"
+
+    def get_api_key(self) -> str:
+        return self.config.api_key.get_secret_value()

     def get_base_url(self) -> str:
         """
@@ -38,25 +27,25 @@ def get_base_url(self) -> str:
         """
         return urljoin(str(self.config.api_base), "/openai/v1")

-    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
-        # Get base parameters from parent
-        params = await super()._get_params(request)
-
-        # Add Azure specific parameters
-        provider_data = self.get_request_provider_data()
-        if provider_data:
-            if getattr(provider_data, "azure_api_key", None):
-                params["api_key"] = provider_data.azure_api_key
-            if getattr(provider_data, "azure_api_base", None):
-                params["api_base"] = provider_data.azure_api_base
-            if getattr(provider_data, "azure_api_version", None):
-                params["api_version"] = provider_data.azure_api_version
-            if getattr(provider_data, "azure_api_type", None):
-                params["api_type"] = provider_data.azure_api_type
-        else:
-            params["api_key"] = self.config.api_key.get_secret_value()
-            params["api_base"] = str(self.config.api_base)
-            params["api_version"] = self.config.api_version
-            params["api_type"] = self.config.api_type
-
-        return params
+    # async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
+    #     # Get base parameters from parent
+    #     params = await super()._get_params(request)
+
+    #     # Add Azure specific parameters
+    #     provider_data = self.get_request_provider_data()
+    #     if provider_data:
+    #         if getattr(provider_data, "azure_api_key", None):
+    #             params["api_key"] = provider_data.azure_api_key
+    #         if getattr(provider_data, "azure_api_base", None):
+    #             params["api_base"] = provider_data.azure_api_base
+    #         if getattr(provider_data, "azure_api_version", None):
+    #             params["api_version"] = provider_data.azure_api_version
+    #         if getattr(provider_data, "azure_api_type", None):
+    #             params["api_type"] = provider_data.azure_api_type
+    #     else:
+    #         params["api_key"] = self.config.api_key.get_secret_value()
+    #         params["api_base"] = str(self.config.api_base)
+    #         params["api_version"] = self.config.api_version
+    #         params["api_type"] = self.config.api_type

+    #     return params
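Note: one subtlety get_base_url relies on — because the second argument to urljoin starts with "/", it replaces the entire path of api_base rather than appending to it. A quick standard-library demonstration (the Azure hostname is illustrative):

from urllib.parse import urljoin

# A second argument with a leading "/" replaces the base URL's whole path.
print(urljoin("https://my-resource.openai.azure.com/anything/here", "/openai/v1"))
# -> https://my-resource.openai.azure.com/openai/v1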

llama_stack/providers/remote/inference/cerebras/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):

     assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"

-    impl = CerebrasInferenceAdapter(config)
+    impl = CerebrasInferenceAdapter(config=config)

     await impl.initialize()


llama_stack/providers/remote/inference/cerebras/cerebras.py

Lines changed: 3 additions & 53 deletions
@@ -6,71 +6,21 @@

 from urllib.parse import urljoin

-from cerebras.cloud.sdk import AsyncCerebras
-
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    CompletionRequest,
-    Inference,
-    OpenAIEmbeddingsResponse,
-    TopKSamplingStrategy,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
-)
+from llama_stack.apis.inference import OpenAIEmbeddingsResponse
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-    completion_request_to_prompt,
-)

 from .config import CerebrasImplConfig


-class CerebrasInferenceAdapter(
-    OpenAIMixin,
-    Inference,
-):
-    def __init__(self, config: CerebrasImplConfig) -> None:
-        self.config = config
-
-        # TODO: make this use provider data, etc. like other providers
-        self._cerebras_client = AsyncCerebras(
-            base_url=self.config.base_url,
-            api_key=self.config.api_key.get_secret_value(),
-        )
+class CerebrasInferenceAdapter(OpenAIMixin):
+    config: CerebrasImplConfig

     def get_api_key(self) -> str:
         return self.config.api_key.get_secret_value()

     def get_base_url(self) -> str:
         return urljoin(self.config.base_url, "v1")

-    async def initialize(self) -> None:
-        return
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-            raise ValueError("`top_k` not supported by Cerebras")
-
-        prompt = ""
-        if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-        elif isinstance(request, CompletionRequest):
-            prompt = await completion_request_to_prompt(request)
-        else:
-            raise ValueError(f"Unknown request type {type(request)}")
-
-        return {
-            "model": request.model,
-            "prompt": prompt,
-            "stream": request.stream,
-            **get_sampling_options(request.sampling_params),
-        }
-
     async def openai_embeddings(
         self,
         model: str,
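Note: unlike Azure's absolute "/openai/v1", Cerebras joins a relative "v1", where urljoin's result depends on whether base_url ends in a slash. A short demonstration (the base URLs are illustrative):

from urllib.parse import urljoin

print(urljoin("https://api.cerebras.ai/", "v1"))    # https://api.cerebras.ai/v1
print(urljoin("https://api.cerebras.ai/x/", "v1"))  # https://api.cerebras.ai/x/v1
print(urljoin("https://api.cerebras.ai/x", "v1"))   # https://api.cerebras.ai/v1 (replaces "x")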

llama_stack/providers/remote/inference/databricks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
     from .databricks import DatabricksInferenceAdapter

     assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
-    impl = DatabricksInferenceAdapter(config)
+    impl = DatabricksInferenceAdapter(config=config)
     await impl.initialize()
     return impl

llama_stack/providers/remote/inference/databricks/databricks.py

Lines changed: 11 additions & 41 deletions
@@ -9,11 +9,9 @@
 from databricks.sdk import WorkspaceClient

 from llama_stack.apis.inference import (
-    Inference,
     Model,
     OpenAICompletion,
 )
-from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -22,31 +20,21 @@
 logger = get_logger(name=__name__, category="inference::databricks")


-class DatabricksInferenceAdapter(
-    OpenAIMixin,
-    Inference,
-):
+class DatabricksInferenceAdapter(OpenAIMixin):
+    config: DatabricksImplConfig
+
     # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
-    embedding_model_metadata = {
+    embedding_model_metadata: dict[str, dict[str, int]] = {
         "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
         "databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
     }

-    def __init__(self, config: DatabricksImplConfig) -> None:
-        self.config = config
-
     def get_api_key(self) -> str:
         return self.config.api_token.get_secret_value()

     def get_base_url(self) -> str:
         return f"{self.config.url}/serving-endpoints"

-    async def initialize(self) -> None:
-        return
-
-    async def shutdown(self) -> None:
-        pass
-
     async def openai_completion(
         self,
         model: str,
@@ -72,31 +60,13 @@ async def openai_completion(
     ) -> OpenAICompletion:
         raise NotImplementedError()

-    async def list_models(self) -> list[Model] | None:
-        self._model_cache = {}  # from OpenAIMixin
-        ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key())  # TODO: this is not async
-        endpoints = ws_client.serving_endpoints.list()
-        for endpoint in endpoints:
-            model = Model(
-                provider_id=self.__provider_id__,
-                provider_resource_id=endpoint.name,
-                identifier=endpoint.name,
-            )
-            if endpoint.task == "llm/v1/chat":
-                model.model_type = ModelType.llm  # this is redundant, but informative
-            elif endpoint.task == "llm/v1/embeddings":
-                if endpoint.name not in self.embedding_model_metadata:
-                    logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
-                    continue
-                model.model_type = ModelType.embedding
-                model.metadata = self.embedding_model_metadata[endpoint.name]
-            else:
-                logger.warning(f"Unknown model type, skipping: {endpoint}")
-                continue
-
-            self._model_cache[endpoint.name] = model
-
-        return list(self._model_cache.values())
+    async def get_models(self) -> list[Model] | None:
+        return [
+            endpoint.name
+            for endpoint in WorkspaceClient(
+                host=self.config.url, token=self.get_api_key()
+            ).serving_endpoints.list()  # TODO: this is not async
+        ]

     async def should_refresh_models(self) -> bool:
         return False
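Note: the retained TODO flags that the WorkspaceClient call blocks inside an async method. One conventional workaround — shown here only as a hedged sketch, not something this commit does — is to push the blocking SDK call onto a worker thread:

import asyncio

from databricks.sdk import WorkspaceClient


async def list_endpoint_names(host: str, token: str) -> list[str]:
    def _list() -> list[str]:
        # The blocking SDK call runs in a worker thread, keeping the event loop free.
        client = WorkspaceClient(host=host, token=token)
        return [endpoint.name for endpoint in client.serving_endpoints.list()]

    return await asyncio.to_thread(_list)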

llama_stack/providers/remote/inference/fireworks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -17,6 +17,6 @@ async def get_adapter_impl(config: FireworksImplConfig, _deps):
     from .fireworks import FireworksInferenceAdapter

     assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
-    impl = FireworksInferenceAdapter(config)
+    impl = FireworksInferenceAdapter(config=config)
     await impl.initialize()
     return impl
