Commit d23ed26

chore: turn OpenAIMixin into a pydantic.BaseModel (#3671)
# What does this PR do?

- implement get_api_key instead of relying on LiteLLMOpenAIMixin.get_api_key
- remove use of LiteLLMOpenAIMixin
- add default initialize/shutdown methods to OpenAIMixin
- remove __init__s to allow proper pydantic construction
- remove dead code from the vllm adapter and associated/duplicate unit tests
- update the vllm adapter to use OpenAIMixin for model registration
- remove ModelRegistryHelper from the fireworks & together adapters
- remove Inference from the nvidia adapter
- complete type hints on embedding_model_metadata
- allow extra fields on OpenAIMixin, for model_store, __provider_id__, etc.
- new recordings for ollama
- enhance list-models error handling
- update the cerebras (remove cerebras-cloud-sdk) and anthropic (custom model listing) inference adapters
- parametrize test_inference_client_caching
- remove cerebras, databricks, fireworks, together from the blanket mypy exclude
- remove unnecessary litellm deps

## Test Plan

ci
1 parent 724dac4 commit d23ed26
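In essence (a sketch, not the upstream source): with OpenAIMixin as a pydantic.BaseModel, each adapter declares a `config` field plus small hooks like `get_api_key`, and pydantic generates construction, which is why the call sites below switch to keyword arguments. The names `OpenAIMixinSketch`, `ExampleConfig`, and `ExampleAdapter` are illustrative, not from the repo:

```python
# Sketch only: the real OpenAIMixin lives in
# llama_stack/providers/utils/inference/openai_mixin.py and differs in detail.
from pydantic import BaseModel, ConfigDict


class OpenAIMixinSketch(BaseModel):
    # extra="allow" mirrors "allow extra fields on OpenAIMixin, for
    # model_store, __provider_id__, etc." from the PR description.
    model_config = ConfigDict(extra="allow")

    # Adapters point this at the provider-data field holding a per-request key.
    provider_data_api_key_field: str | None = None

    def get_api_key(self) -> str:
        raise NotImplementedError  # each adapter implements this

    async def initialize(self) -> None:
        pass  # default lifecycle hooks, so adapters need not define them

    async def shutdown(self) -> None:
        pass


class ExampleConfig(BaseModel):
    api_key: str | None = None


class ExampleAdapter(OpenAIMixinSketch):
    config: ExampleConfig
    provider_data_api_key_field: str | None = "example_api_key"

    def get_api_key(self) -> str:
        return self.config.api_key or ""


# pydantic construction is keyword-only, hence the config=config changes below.
adapter = ExampleAdapter(config=ExampleConfig(api_key="test"))
assert adapter.get_api_key() == "test"
```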

131 files changed (+83632 additions, −1758 deletions)


docs/docs/providers/inference/remote_databricks.mdx

Lines changed: 1 addition & 1 deletion
```diff
@@ -15,7 +15,7 @@ Databricks inference provider for running models on Databricks' unified analytic
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
-| `url` | `<class 'str'>` | No | | The URL for the Databricks model serving endpoint |
+| `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
 | `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |
 
 ## Sample Configuration
```

llama_stack/providers/registry/inference.py

Lines changed: 8 additions & 17 deletions
```diff
@@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="cerebras",
             provider_type="remote::cerebras",
-            pip_packages=[
-                "cerebras_cloud_sdk",
-            ],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.cerebras",
             config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
             description="Cerebras inference provider for running models on Cerebras Cloud platform.",
@@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="openai",
             provider_type="remote::openai",
-            pip_packages=["litellm"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.openai",
             config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
             provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
@@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="anthropic",
             provider_type="remote::anthropic",
-            pip_packages=["litellm"],
+            pip_packages=["anthropic"],
             module="llama_stack.providers.remote.inference.anthropic",
             config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
             provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
@@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="gemini",
             provider_type="remote::gemini",
-            pip_packages=[
-                "litellm",
-            ],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.gemini",
             config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
             provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
@@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
             adapter_type="vertexai",
             provider_type="remote::vertexai",
             pip_packages=[
-                "litellm",
                 "google-cloud-aiplatform",
             ],
             module="llama_stack.providers.remote.inference.vertexai",
@@ -233,9 +228,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="groq",
             provider_type="remote::groq",
-            pip_packages=[
-                "litellm",
-            ],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.groq",
             config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
             provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
@@ -245,7 +238,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="llama-openai-compat",
             provider_type="remote::llama-openai-compat",
-            pip_packages=["litellm"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.llama_openai_compat",
             config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
             provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
@@ -255,9 +248,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="sambanova",
             provider_type="remote::sambanova",
-            pip_packages=[
-                "litellm",
-            ],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.sambanova",
             config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
             provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
@@ -287,7 +278,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             provider_type="remote::azure",
             adapter_type="azure",
-            pip_packages=["litellm"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.azure",
             config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
             provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
```

llama_stack/providers/remote/inference/anthropic/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,6 +10,6 @@
 async def get_adapter_impl(config: AnthropicConfig, _deps):
     from .anthropic import AnthropicInferenceAdapter
 
-    impl = AnthropicInferenceAdapter(config)
+    impl = AnthropicInferenceAdapter(config=config)
     await impl.initialize()
     return impl
```

llama_stack/providers/remote/inference/anthropic/anthropic.py

Lines changed: 13 additions & 18 deletions
```diff
@@ -4,13 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from collections.abc import Iterable
+
+from anthropic import AsyncAnthropic
+
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import AnthropicConfig
 
 
-class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
+class AnthropicInferenceAdapter(OpenAIMixin):
+    config: AnthropicConfig
+
+    provider_data_api_key_field: str = "anthropic_api_key"
     # source: https://docs.claude.com/en/docs/build-with-claude/embeddings
     # TODO: add support for voyageai, which is where these models are hosted
     # embedding_model_metadata = {
@@ -23,22 +29,11 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     #     "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
     # }
 
-    def __init__(self, config: AnthropicConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(
-            self,
-            litellm_provider_name="anthropic",
-            api_key_from_config=config.api_key,
-            provider_data_api_key_field="anthropic_api_key",
-        )
-        self.config = config
-
-    async def initialize(self) -> None:
-        await super().initialize()
-
-    async def shutdown(self) -> None:
-        await super().shutdown()
-
-    get_api_key = LiteLLMOpenAIMixin.get_api_key
+    def get_api_key(self) -> str:
+        return self.config.api_key or ""
 
     def get_base_url(self):
         return "https://api.anthropic.com/v1"
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
```
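A hypothetical call site for the new hooks (assuming `AnthropicConfig` accepts an `api_key` field, as `get_api_key` above implies; a real run needs a valid key and network access):

```python
# Hypothetical usage, not part of this commit.
import asyncio

from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig


async def main() -> None:
    # Keyword construction, as in the updated get_adapter_impl above.
    adapter = AnthropicInferenceAdapter(config=AnthropicConfig(api_key="sk-ant-..."))
    await adapter.initialize()  # default no-op inherited from OpenAIMixin
    for model_id in await adapter.list_provider_model_ids():
        print(model_id)  # model IDs fetched straight from the Anthropic SDK


asyncio.run(main())
```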

llama_stack/providers/remote/inference/azure/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,6 +10,6 @@
 async def get_adapter_impl(config: AzureConfig, _deps):
     from .azure import AzureInferenceAdapter
 
-    impl = AzureInferenceAdapter(config)
+    impl = AzureInferenceAdapter(config=config)
     await impl.initialize()
     return impl
```

llama_stack/providers/remote/inference/azure/azure.py

Lines changed: 6 additions & 40 deletions
```diff
@@ -4,31 +4,20 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any
 from urllib.parse import urljoin
 
-from llama_stack.apis.inference import ChatCompletionRequest
-from llama_stack.providers.utils.inference.litellm_openai_mixin import (
-    LiteLLMOpenAIMixin,
-)
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import AzureConfig
 
 
-class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
-    def __init__(self, config: AzureConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(
-            self,
-            litellm_provider_name="azure",
-            api_key_from_config=config.api_key.get_secret_value(),
-            provider_data_api_key_field="azure_api_key",
-            openai_compat_api_base=str(config.api_base),
-        )
-        self.config = config
+class AzureInferenceAdapter(OpenAIMixin):
+    config: AzureConfig
 
-    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
-    get_api_key = LiteLLMOpenAIMixin.get_api_key
+    provider_data_api_key_field: str = "azure_api_key"
+
+    def get_api_key(self) -> str:
+        return self.config.api_key.get_secret_value()
 
     def get_base_url(self) -> str:
         """
@@ -37,26 +26,3 @@ def get_base_url(self) -> str:
         Returns the Azure API base URL from the configuration.
         """
         return urljoin(str(self.config.api_base), "/openai/v1")
-
-    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
-        # Get base parameters from parent
-        params = await super()._get_params(request)
-
-        # Add Azure specific parameters
-        provider_data = self.get_request_provider_data()
-        if provider_data:
-            if getattr(provider_data, "azure_api_key", None):
-                params["api_key"] = provider_data.azure_api_key
-            if getattr(provider_data, "azure_api_base", None):
-                params["api_base"] = provider_data.azure_api_base
-            if getattr(provider_data, "azure_api_version", None):
-                params["api_version"] = provider_data.azure_api_version
-            if getattr(provider_data, "azure_api_type", None):
-                params["api_type"] = provider_data.azure_api_type
-        else:
-            params["api_key"] = self.config.api_key.get_secret_value()
-            params["api_base"] = str(self.config.api_base)
-            params["api_version"] = self.config.api_version
-            params["api_type"] = self.config.api_type
-
-        return params
```
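The deleted `_get_params` plumbing above is subsumed by the declarative `provider_data_api_key_field`. A minimal, self-contained sketch of that lookup pattern, using a stand-in for llama_stack's per-request provider data (an assumption about the mechanism, not the actual OpenAIMixin code):

```python
# Sketch: per-request provider-data key wins, config key is the fallback.
from types import SimpleNamespace


class ProviderDataKeySketch:
    provider_data_api_key_field: str | None = "azure_api_key"
    config_api_key: str = "key-from-config"

    def get_request_provider_data(self):
        # Stand-in for llama_stack's per-request provider-data lookup.
        return SimpleNamespace(azure_api_key="key-from-request-header")

    def resolve_api_key(self) -> str:
        provider_data = self.get_request_provider_data()
        if provider_data is not None and self.provider_data_api_key_field:
            key = getattr(provider_data, self.provider_data_api_key_field, None)
            if key:
                return key  # a per-request key wins
        return self.config_api_key  # fall back to the static config key


print(ProviderDataKeySketch().resolve_api_key())  # key-from-request-header
```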

llama_stack/providers/remote/inference/cerebras/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
 
     assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
 
-    impl = CerebrasInferenceAdapter(config)
+    impl = CerebrasInferenceAdapter(config=config)
 
     await impl.initialize()
 
```
llama_stack/providers/remote/inference/cerebras/cerebras.py

Lines changed: 3 additions & 53 deletions
```diff
@@ -6,71 +6,21 @@
 
 from urllib.parse import urljoin
 
-from cerebras.cloud.sdk import AsyncCerebras
-
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    CompletionRequest,
-    Inference,
-    OpenAIEmbeddingsResponse,
-    TopKSamplingStrategy,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
-)
+from llama_stack.apis.inference import OpenAIEmbeddingsResponse
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-    completion_request_to_prompt,
-)
 
 from .config import CerebrasImplConfig
 
 
-class CerebrasInferenceAdapter(
-    OpenAIMixin,
-    Inference,
-):
-    def __init__(self, config: CerebrasImplConfig) -> None:
-        self.config = config
-
-        # TODO: make this use provider data, etc. like other providers
-        self._cerebras_client = AsyncCerebras(
-            base_url=self.config.base_url,
-            api_key=self.config.api_key.get_secret_value(),
-        )
+class CerebrasInferenceAdapter(OpenAIMixin):
+    config: CerebrasImplConfig
 
     def get_api_key(self) -> str:
         return self.config.api_key.get_secret_value()
 
     def get_base_url(self) -> str:
         return urljoin(self.config.base_url, "v1")
 
-    async def initialize(self) -> None:
-        return
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-            raise ValueError("`top_k` not supported by Cerebras")
-
-        prompt = ""
-        if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-        elif isinstance(request, CompletionRequest):
-            prompt = await completion_request_to_prompt(request)
-        else:
-            raise ValueError(f"Unknown request type {type(request)}")
-
-        return {
-            "model": request.model,
-            "prompt": prompt,
-            "stream": request.stream,
-            **get_sampling_options(request.sampling_params),
-        }
-
     async def openai_embeddings(
         self,
         model: str,
```

llama_stack/providers/remote/inference/cerebras/config.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
         description="Base URL for the Cerebras API",
     )
     api_key: SecretStr = Field(
-        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
+        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),  # type: ignore[arg-type]
         description="Cerebras API Key",
     )
 
```
llama_stack/providers/remote/inference/databricks/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
     from .databricks import DatabricksInferenceAdapter
 
     assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
-    impl = DatabricksInferenceAdapter(config)
+    impl = DatabricksInferenceAdapter(config=config)
     await impl.initialize()
     return impl
```
