2 changes: 1 addition & 1 deletion docs/docs/providers/inference/remote_databricks.mdx
@@ -15,7 +15,7 @@ Databricks inference provider for running models on Databricks' unified analytic
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
-| `url` | `<class 'str'>` | No | | The URL for the Databricks model serving endpoint |
+| `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
| `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |

## Sample Configuration
25 changes: 8 additions & 17 deletions llama_stack/providers/registry/inference.py
@@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.inference,
            adapter_type="cerebras",
            provider_type="remote::cerebras",
-            pip_packages=[
-                "cerebras_cloud_sdk",
-            ],
+            pip_packages=[],
            module="llama_stack.providers.remote.inference.cerebras",
            config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
            description="Cerebras inference provider for running models on Cerebras Cloud platform.",
@@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.inference,
            adapter_type="openai",
            provider_type="remote::openai",
-            pip_packages=["litellm"],
+            pip_packages=[],
            module="llama_stack.providers.remote.inference.openai",
            config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
            provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
@@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.inference,
            adapter_type="anthropic",
            provider_type="remote::anthropic",
-            pip_packages=["litellm"],
+            pip_packages=["anthropic"],
            module="llama_stack.providers.remote.inference.anthropic",
            config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
            provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
@@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.inference,
            adapter_type="gemini",
            provider_type="remote::gemini",
-            pip_packages=[
-                "litellm",
-            ],
+            pip_packages=[],
            module="llama_stack.providers.remote.inference.gemini",
            config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
            provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
@@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
            adapter_type="vertexai",
            provider_type="remote::vertexai",
            pip_packages=[
-                "litellm",
                "google-cloud-aiplatform",
            ],
            module="llama_stack.providers.remote.inference.vertexai",
@@ -233,9 +228,7 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.inference,
            adapter_type="groq",
            provider_type="remote::groq",
-            pip_packages=[
-                "litellm",
-            ],
+            pip_packages=[],
            module="llama_stack.providers.remote.inference.groq",
            config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
            provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
@@ -245,7 +238,7 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.inference,
            adapter_type="llama-openai-compat",
            provider_type="remote::llama-openai-compat",
-            pip_packages=["litellm"],
+            pip_packages=[],
            module="llama_stack.providers.remote.inference.llama_openai_compat",
            config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
            provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
@@ -255,9 +248,7 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.inference,
            adapter_type="sambanova",
            provider_type="remote::sambanova",
-            pip_packages=[
-                "litellm",
-            ],
+            pip_packages=[],
            module="llama_stack.providers.remote.inference.sambanova",
            config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
            provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
@@ -287,7 +278,7 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.inference,
            provider_type="remote::azure",
            adapter_type="azure",
-            pip_packages=["litellm"],
+            pip_packages=[],
            module="llama_stack.providers.remote.inference.azure",
            config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
            provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/inference/anthropic/__init__.py
@@ -10,6 +10,6 @@
async def get_adapter_impl(config: AnthropicConfig, _deps):
    from .anthropic import AnthropicInferenceAdapter

-    impl = AnthropicInferenceAdapter(config)
+    impl = AnthropicInferenceAdapter(config=config)
    await impl.initialize()
    return impl
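A side note on this one-line change, which repeats in the other `get_adapter_impl` functions below: the adapters now declare `config` as a class-level field, and if (as those annotations suggest, though this diff does not show it directly) the adapters are Pydantic models, construction has to use keyword arguments. A minimal sketch of that behavior under that assumption:

```python
# Hypothetical sketch: Pydantic models accept their fields as keyword
# arguments only, so an adapter declaring `config` as a field must be
# constructed with config=config rather than positionally.
from pydantic import BaseModel


class DemoConfig(BaseModel):
    api_key: str | None = None


class DemoAdapter(BaseModel):
    config: DemoConfig


DemoAdapter(config=DemoConfig())  # OK
# DemoAdapter(DemoConfig())  # TypeError: takes 1 positional argument
```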
31 changes: 13 additions & 18 deletions llama_stack/providers/remote/inference/anthropic/anthropic.py
@@ -4,13 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from collections.abc import Iterable
+
+from anthropic import AsyncAnthropic

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .config import AnthropicConfig


-class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
+class AnthropicInferenceAdapter(OpenAIMixin):
+    config: AnthropicConfig
+
+    provider_data_api_key_field: str = "anthropic_api_key"
    # source: https://docs.claude.com/en/docs/build-with-claude/embeddings
    # TODO: add support for voyageai, which is where these models are hosted
    # embedding_model_metadata = {
@@ -23,22 +29,11 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
    #     "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
    # }

-    def __init__(self, config: AnthropicConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(
-            self,
-            litellm_provider_name="anthropic",
-            api_key_from_config=config.api_key,
-            provider_data_api_key_field="anthropic_api_key",
-        )
-        self.config = config
-
-    async def initialize(self) -> None:
-        await super().initialize()
-
-    async def shutdown(self) -> None:
-        await super().shutdown()
-
-    get_api_key = LiteLLMOpenAIMixin.get_api_key
+    def get_api_key(self) -> str:
+        return self.config.api_key or ""

    def get_base_url(self):
        return "https://api.anthropic.com/v1"
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
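For orientation, a minimal usage sketch of the rewritten adapter. The wiring is hypothetical: it assumes `AnthropicConfig` accepts an `api_key` field (implied by `get_api_key` above) and that a real key is available, since `list_provider_model_ids` calls the live Anthropic API.

```python
# Hypothetical usage sketch (names taken from this diff; not a tested example).
import asyncio

from llama_stack.providers.remote.inference.anthropic import AnthropicConfig, get_adapter_impl


async def main() -> None:
    # get_adapter_impl constructs the adapter and awaits initialize() for us.
    adapter = await get_adapter_impl(AnthropicConfig(api_key="sk-ant-..."), {})
    # Pages through AsyncAnthropic's model listing and returns the ids.
    print(await adapter.list_provider_model_ids())


asyncio.run(main())
```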
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/inference/azure/__init__.py
@@ -10,6 +10,6 @@
async def get_adapter_impl(config: AzureConfig, _deps):
    from .azure import AzureInferenceAdapter

-    impl = AzureInferenceAdapter(config)
+    impl = AzureInferenceAdapter(config=config)
    await impl.initialize()
    return impl
46 changes: 6 additions & 40 deletions llama_stack/providers/remote/inference/azure/azure.py
@@ -4,31 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import Any
from urllib.parse import urljoin

-from llama_stack.apis.inference import ChatCompletionRequest
-from llama_stack.providers.utils.inference.litellm_openai_mixin import (
-    LiteLLMOpenAIMixin,
-)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

[Review thread on the removed LiteLLMOpenAIMixin import]
Collaborator: can we remove litellm from RemoteProviderSpec in the registry?
Author: done

from .config import AzureConfig


-class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
-    def __init__(self, config: AzureConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(
-            self,
-            litellm_provider_name="azure",
-            api_key_from_config=config.api_key.get_secret_value(),
-            provider_data_api_key_field="azure_api_key",
-            openai_compat_api_base=str(config.api_base),
-        )
-        self.config = config
+class AzureInferenceAdapter(OpenAIMixin):
+    config: AzureConfig

-    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
-    get_api_key = LiteLLMOpenAIMixin.get_api_key
+    provider_data_api_key_field: str = "azure_api_key"
+
+    def get_api_key(self) -> str:
+        return self.config.api_key.get_secret_value()

    def get_base_url(self) -> str:
        """
@@ -37,26 +26,3 @@ def get_base_url(self) -> str:
        Returns the Azure API base URL from the configuration.
        """
        return urljoin(str(self.config.api_base), "/openai/v1")
-
-    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
-        # Get base parameters from parent
-        params = await super()._get_params(request)
-
-        # Add Azure specific parameters
-        provider_data = self.get_request_provider_data()
-        if provider_data:
-            if getattr(provider_data, "azure_api_key", None):
-                params["api_key"] = provider_data.azure_api_key
-            if getattr(provider_data, "azure_api_base", None):
-                params["api_base"] = provider_data.azure_api_base
-            if getattr(provider_data, "azure_api_version", None):
-                params["api_version"] = provider_data.azure_api_version
-            if getattr(provider_data, "azure_api_type", None):
-                params["api_type"] = provider_data.azure_api_type
-        else:
-            params["api_key"] = self.config.api_key.get_secret_value()
-            params["api_base"] = str(self.config.api_base)
-            params["api_version"] = self.config.api_version
-            params["api_type"] = self.config.api_type
-
-        return params
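One behavior worth knowing about `get_base_url` here: because the join target starts with "/", `urljoin` keeps only the scheme and host from `api_base` and discards any path it carried. A standalone illustration (the hostname is made up):

```python
from urllib.parse import urljoin

# A leading "/" makes the join absolute: any path on the base URL is replaced.
print(urljoin("https://example.openai.azure.com/some/path", "/openai/v1"))
# -> https://example.openai.azure.com/openai/v1
```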
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/inference/cerebras/__init__.py
@@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):

    assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"

-    impl = CerebrasInferenceAdapter(config)
+    impl = CerebrasInferenceAdapter(config=config)

    await impl.initialize()

56 changes: 3 additions & 53 deletions llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -6,71 +6,21 @@

from urllib.parse import urljoin

-from cerebras.cloud.sdk import AsyncCerebras
-
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    CompletionRequest,
-    Inference,
-    OpenAIEmbeddingsResponse,
-    TopKSamplingStrategy,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
-)
+from llama_stack.apis.inference import OpenAIEmbeddingsResponse
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-    completion_request_to_prompt,
-)

from .config import CerebrasImplConfig


-class CerebrasInferenceAdapter(
-    OpenAIMixin,
-    Inference,
-):
-    def __init__(self, config: CerebrasImplConfig) -> None:
-        self.config = config
-
-        # TODO: make this use provider data, etc. like other providers
-        self._cerebras_client = AsyncCerebras(
-            base_url=self.config.base_url,
-            api_key=self.config.api_key.get_secret_value(),
-        )
+class CerebrasInferenceAdapter(OpenAIMixin):
+    config: CerebrasImplConfig

    def get_api_key(self) -> str:
        return self.config.api_key.get_secret_value()

    def get_base_url(self) -> str:
        return urljoin(self.config.base_url, "v1")

-    async def initialize(self) -> None:
-        return
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-            raise ValueError("`top_k` not supported by Cerebras")
-
-        prompt = ""
-        if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-        elif isinstance(request, CompletionRequest):
-            prompt = await completion_request_to_prompt(request)
-        else:
-            raise ValueError(f"Unknown request type {type(request)}")
-
-        return {
-            "model": request.model,
-            "prompt": prompt,
-            "stream": request.stream,
-            **get_sampling_options(request.sampling_params),
-        }
-
    async def openai_embeddings(
        self,
        model: str,
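Contrast with the Azure adapter above: here the join target has no leading slash, so `urljoin` resolves it relative to the base URL. A bare host gains the `v1` path as intended, but a base that already ends in a path segment (without a trailing slash) would have that segment swapped out; illustration only:

```python
from urllib.parse import urljoin

print(urljoin("https://api.cerebras.ai", "v1"))
# -> https://api.cerebras.ai/v1
print(urljoin("https://api.cerebras.ai/extra", "v1"))
# -> https://api.cerebras.ai/v1  ("extra" is replaced, not appended)
```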
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/inference/cerebras/config.py
@@ -22,7 +22,7 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
        description="Base URL for the Cerebras API",
    )
    api_key: SecretStr = Field(
-        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
+        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),  # type: ignore[arg-type]

[Review thread on the new type-ignore]
Collaborator: Suggested change:
-        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),  # type: ignore[arg-type]
+        default_factory=lambda: SecretStr(os.getenv("CEREBRAS_API_KEY", ""))
And mypy will be happy.
Author: i removed the blanket excludes. i want to avoid rolling more things into this PR.

        description="Cerebras API Key",
    )
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/inference/databricks/__init__.py
@@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
    from .databricks import DatabricksInferenceAdapter

    assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
-    impl = DatabricksInferenceAdapter(config)
+    impl = DatabricksInferenceAdapter(config=config)
    await impl.initialize()
    return impl
4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/databricks/config.py
@@ -14,12 +14,12 @@

@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
-    url: str = Field(
+    url: str | None = Field(
        default=None,
        description="The URL for the Databricks model serving endpoint",
    )
    api_token: SecretStr = Field(
-        default=SecretStr(None),
+        default=SecretStr(None),  # type: ignore[arg-type]

[Review thread on the new type-ignore]
Collaborator: Why not set default=SecretStr("") and avoid the type-ignore comment? (yes I know, that same discussion again)
Author: i hope to handle this later. i removed the blanket exclude so we can see the specific instances.

        description="The Databricks API token",
    )
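For reference, the empty-string/`default_factory` shape suggested in the Cerebras thread above would fit this field too and would satisfy mypy without the ignore. A sketch of that alternative, not what this PR merges:

```python
from pydantic import BaseModel, Field, SecretStr


class DatabricksImplConfigSketch(BaseModel):
    # Hypothetical alternative to default=SecretStr(None): an empty secret
    # type-checks cleanly and still reads as "unset".
    api_token: SecretStr = Field(
        default=SecretStr(""),
        description="The Databricks API token",
    )
```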