Skip to content

Commit fd06717

Browse files
committed
chore: make OpenAIMixin maintainable, turn OpenAIMixin into a pydantic.BaseModel
- implement get_api_key instead of relying on LiteLLMOpenAIMixin.get_api_key - remove use of LiteLLMOpenAIMixin - add default initialize/shutdown methods to OpenAIMixin - remove __init__s to allow proper pydantic construction - remove dead code from vllm adapter and associated / duplicate unit tests - update vllm adapter to use openaimixin for model registration - remove ModelRegistryHelper from fireworks & together adapters - remove Inference from nvidia adapter - complete type hints on embedding_model_metadata - allow extra fields on OpenAIMixin, for model_store, __provider_id__, etc - new recordings for ollama - enhance the list models error handling w/ new tests - update cerebras (remove cerebras-cloud-sdk) and anthropic (custom model listing) inference adapters - parametrized test_inference_client_caching - remove cerebras, databricks, fireworks, together from blanket mypy exclude
1 parent 351c4b9 commit fd06717

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+12900
-1733
lines changed

docs/docs/providers/inference/remote_databricks.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Databricks inference provider for running models on Databricks' unified analytic
1515
| Field | Type | Required | Default | Description |
1616
|-------|------|----------|---------|-------------|
1717
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
18-
| `url` | `<class 'str'>` | No | | The URL for the Databricks model serving endpoint |
18+
| `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
1919
| `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |
2020

2121
## Sample Configuration

llama_stack/providers/registry/inference.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
5252
api=Api.inference,
5353
adapter_type="cerebras",
5454
provider_type="remote::cerebras",
55-
pip_packages=[
56-
"cerebras_cloud_sdk",
57-
],
55+
pip_packages=[],
5856
module="llama_stack.providers.remote.inference.cerebras",
5957
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
6058
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
@@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
179177
api=Api.inference,
180178
adapter_type="anthropic",
181179
provider_type="remote::anthropic",
182-
pip_packages=["litellm"],
180+
pip_packages=["litellm", "anthropic"],
183181
module="llama_stack.providers.remote.inference.anthropic",
184182
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
185183
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",

llama_stack/providers/remote/inference/anthropic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@
1010
async def get_adapter_impl(config: AnthropicConfig, _deps):
1111
from .anthropic import AnthropicInferenceAdapter
1212

13-
impl = AnthropicInferenceAdapter(config)
13+
impl = AnthropicInferenceAdapter(config=config)
1414
await impl.initialize()
1515
return impl

llama_stack/providers/remote/inference/anthropic/anthropic.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,19 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
7+
from collections.abc import Iterable
8+
9+
from anthropic import AsyncAnthropic
10+
811
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
912

1013
from .config import AnthropicConfig
1114

1215

13-
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
16+
class AnthropicInferenceAdapter(OpenAIMixin):
17+
config: AnthropicConfig
18+
19+
provider_data_api_key_field: str = "anthropic_api_key"
1420
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
1521
# TODO: add support for voyageai, which is where these models are hosted
1622
# embedding_model_metadata = {
@@ -23,22 +29,11 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
2329
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
2430
# }
2531

26-
def __init__(self, config: AnthropicConfig) -> None:
27-
LiteLLMOpenAIMixin.__init__(
28-
self,
29-
litellm_provider_name="anthropic",
30-
api_key_from_config=config.api_key,
31-
provider_data_api_key_field="anthropic_api_key",
32-
)
33-
self.config = config
34-
35-
async def initialize(self) -> None:
36-
await super().initialize()
37-
38-
async def shutdown(self) -> None:
39-
await super().shutdown()
40-
41-
get_api_key = LiteLLMOpenAIMixin.get_api_key
32+
def get_api_key(self) -> str:
33+
return self.config.api_key or ""
4234

4335
def get_base_url(self):
4436
return "https://api.anthropic.com/v1"
37+
38+
async def get_models(self) -> Iterable[str] | None:
39+
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]

llama_stack/providers/remote/inference/azure/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@
1010
async def get_adapter_impl(config: AzureConfig, _deps):
1111
from .azure import AzureInferenceAdapter
1212

13-
impl = AzureInferenceAdapter(config)
13+
impl = AzureInferenceAdapter(config=config)
1414
await impl.initialize()
1515
return impl

llama_stack/providers/remote/inference/azure/azure.py

Lines changed: 6 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,20 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
from typing import Any
87
from urllib.parse import urljoin
98

10-
from llama_stack.apis.inference import ChatCompletionRequest
11-
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
12-
LiteLLMOpenAIMixin,
13-
)
149
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
1510

1611
from .config import AzureConfig
1712

1813

19-
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
20-
def __init__(self, config: AzureConfig) -> None:
21-
LiteLLMOpenAIMixin.__init__(
22-
self,
23-
litellm_provider_name="azure",
24-
api_key_from_config=config.api_key.get_secret_value(),
25-
provider_data_api_key_field="azure_api_key",
26-
openai_compat_api_base=str(config.api_base),
27-
)
28-
self.config = config
14+
class AzureInferenceAdapter(OpenAIMixin):
15+
config: AzureConfig
2916

30-
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
31-
get_api_key = LiteLLMOpenAIMixin.get_api_key
17+
provider_data_api_key_field: str = "azure_api_key"
18+
19+
def get_api_key(self) -> str:
20+
return self.config.api_key.get_secret_value()
3221

3322
def get_base_url(self) -> str:
3423
"""
@@ -37,26 +26,3 @@ def get_base_url(self) -> str:
3726
Returns the Azure API base URL from the configuration.
3827
"""
3928
return urljoin(str(self.config.api_base), "/openai/v1")
40-
41-
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
42-
# Get base parameters from parent
43-
params = await super()._get_params(request)
44-
45-
# Add Azure specific parameters
46-
provider_data = self.get_request_provider_data()
47-
if provider_data:
48-
if getattr(provider_data, "azure_api_key", None):
49-
params["api_key"] = provider_data.azure_api_key
50-
if getattr(provider_data, "azure_api_base", None):
51-
params["api_base"] = provider_data.azure_api_base
52-
if getattr(provider_data, "azure_api_version", None):
53-
params["api_version"] = provider_data.azure_api_version
54-
if getattr(provider_data, "azure_api_type", None):
55-
params["api_type"] = provider_data.azure_api_type
56-
else:
57-
params["api_key"] = self.config.api_key.get_secret_value()
58-
params["api_base"] = str(self.config.api_base)
59-
params["api_version"] = self.config.api_version
60-
params["api_type"] = self.config.api_type
61-
62-
return params

llama_stack/providers/remote/inference/cerebras/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
1212

1313
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
1414

15-
impl = CerebrasInferenceAdapter(config)
15+
impl = CerebrasInferenceAdapter(config=config)
1616

1717
await impl.initialize()
1818

llama_stack/providers/remote/inference/cerebras/cerebras.py

Lines changed: 3 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,71 +6,21 @@
66

77
from urllib.parse import urljoin
88

9-
from cerebras.cloud.sdk import AsyncCerebras
10-
11-
from llama_stack.apis.inference import (
12-
ChatCompletionRequest,
13-
CompletionRequest,
14-
Inference,
15-
OpenAIEmbeddingsResponse,
16-
TopKSamplingStrategy,
17-
)
18-
from llama_stack.providers.utils.inference.openai_compat import (
19-
get_sampling_options,
20-
)
9+
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
2110
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
22-
from llama_stack.providers.utils.inference.prompt_adapter import (
23-
chat_completion_request_to_prompt,
24-
completion_request_to_prompt,
25-
)
2611

2712
from .config import CerebrasImplConfig
2813

2914

30-
class CerebrasInferenceAdapter(
31-
OpenAIMixin,
32-
Inference,
33-
):
34-
def __init__(self, config: CerebrasImplConfig) -> None:
35-
self.config = config
36-
37-
# TODO: make this use provider data, etc. like other providers
38-
self._cerebras_client = AsyncCerebras(
39-
base_url=self.config.base_url,
40-
api_key=self.config.api_key.get_secret_value(),
41-
)
15+
class CerebrasInferenceAdapter(OpenAIMixin):
16+
config: CerebrasImplConfig
4217

4318
def get_api_key(self) -> str:
4419
return self.config.api_key.get_secret_value()
4520

4621
def get_base_url(self) -> str:
4722
return urljoin(self.config.base_url, "v1")
4823

49-
async def initialize(self) -> None:
50-
return
51-
52-
async def shutdown(self) -> None:
53-
pass
54-
55-
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
56-
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
57-
raise ValueError("`top_k` not supported by Cerebras")
58-
59-
prompt = ""
60-
if isinstance(request, ChatCompletionRequest):
61-
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
62-
elif isinstance(request, CompletionRequest):
63-
prompt = await completion_request_to_prompt(request)
64-
else:
65-
raise ValueError(f"Unknown request type {type(request)}")
66-
67-
return {
68-
"model": request.model,
69-
"prompt": prompt,
70-
"stream": request.stream,
71-
**get_sampling_options(request.sampling_params),
72-
}
73-
7424
async def openai_embeddings(
7525
self,
7626
model: str,

llama_stack/providers/remote/inference/cerebras/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
2222
description="Base URL for the Cerebras API",
2323
)
2424
api_key: SecretStr = Field(
25-
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
25+
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")), # type: ignore[arg-type]
2626
description="Cerebras API Key",
2727
)
2828

llama_stack/providers/remote/inference/databricks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
1111
from .databricks import DatabricksInferenceAdapter
1212

1313
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
14-
impl = DatabricksInferenceAdapter(config)
14+
impl = DatabricksInferenceAdapter(config=config)
1515
await impl.initialize()
1616
return impl

0 commit comments

Comments
 (0)