204 changes: 111 additions & 93 deletions llama_stack/providers/remote/inference/runpod/runpod.py
@@ -3,125 +3,131 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator

from openai import OpenAI
from typing import Any

from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse

# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
from llama_stack.apis.inference import (
Inference,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .config import RunpodImplConfig

# https://docs.runpod.io/serverless/vllm/overview#compatible-models
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
RUNPOD_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
}

SAFETY_MODELS_ENTRIES = []

# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
MODEL_ENTRIES = [
build_hf_repo_model_entry(provider_model_id, model_descriptor)
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES
MODEL_ENTRIES = []


class RunpodInferenceAdapter(
OpenAIMixin,
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
):
"""
Adapter for RunPod's OpenAI-compatible API endpoints.
Supports VLLM for serverless endpoint self-hosted or public endpoints.
Can work with any runpod endpoints that support OpenAI-compatible API
"""

def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
OpenAIMixin.__init__(self)
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config

def get_api_key(self) -> str:
"""Get API key for OpenAI client."""
return self.config.api_token

def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return self.config.url

async def initialize(self) -> None:
return
pass

async def shutdown(self) -> None:
pass
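
A minimal construction sketch, not part of this diff: it assumes RunpodImplConfig is a plain config model accepting the url and api_token fields read by get_base_url() and get_api_key() above, and the endpoint URL and token shown are placeholders.

# Illustrative sketch only -- not part of this PR. Assumes RunpodImplConfig
# accepts `url` and `api_token`; both values below are placeholders.
from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig
from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter

config = RunpodImplConfig(
    url="https://api.runpod.ai/v2/<endpoint-id>/openai/v1",  # placeholder endpoint URL
    api_token="rp_...",  # placeholder API token
)
adapter = RunpodInferenceAdapter(config)

# The OpenAI-compatible client built by OpenAIMixin is pointed at exactly
# these two values.
assert adapter.get_base_url() == config.url
assert adapter.get_api_key() == config.api_token
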

async def chat_completion(
async def openai_chat_completion(
self,
model: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
):
"""Override to add RunPod-specific stream_options requirement."""
if stream and not stream_options:
stream_options = {"include_usage": True}

return await super().openai_chat_completion(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
tool_config=tool_config,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)

client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return await self._nonstream_chat_completion(request, client)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, request)

async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)

async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk

stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk

def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
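
The only RunPod-specific logic in the new openai_chat_completion above is the stream_options default. A self-contained sketch of that rule, using a hypothetical helper name that is not part of this PR:

# Sketch of the stream_options defaulting applied in openai_chat_completion;
# the helper name is hypothetical.
from typing import Any


def default_stream_options(
    stream: bool | None,
    stream_options: dict[str, Any] | None,
) -> dict[str, Any] | None:
    # Inject include_usage for streaming requests when the caller did not
    # set stream_options; leave everything else untouched.
    if stream and not stream_options:
        return {"include_usage": True}
    return stream_options


assert default_stream_options(True, None) == {"include_usage": True}
assert default_stream_options(True, {"include_usage": False}) == {"include_usage": False}
assert default_stream_options(False, None) is None
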
async def register_model(self, model: Model) -> Model:
"""
Register a model and verify it's available on the RunPod endpoint.
In the .yaml file the model: can be defined as example
models:
- metadata: {}
model_id: qwen3-32b-awq
model_type: llm
provider_id: runpod
provider_model_id: Qwen/Qwen3-32B-AWQ
"""
provider_model_id = model.provider_resource_id or model.identifier
is_available = await self.check_model_availability(provider_model_id)

if not is_available:
raise ValueError(
f"Model {provider_model_id} is not available on RunPod endpoint. "
f"Check your RunPod endpoint configuration."
)

return model
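
register_model relies on check_model_availability from OpenAIMixin, which presumably consults the endpoint's OpenAI-compatible model listing. For illustration only, the same check done directly with the openai client (URL, token, and model id are placeholders):

# Illustrative only: what "available on the RunPod endpoint" amounts to,
# assuming availability is determined by the endpoint's /v1/models listing.
# URL, token, and model id are placeholders.
from openai import OpenAI

client = OpenAI(
    base_url="https://api.runpod.ai/v2/<endpoint-id>/openai/v1",  # placeholder
    api_key="rp_...",  # placeholder
)
available_ids = {m.id for m in client.models.list()}
if "Qwen/Qwen3-32B-AWQ" not in available_ids:
    raise ValueError("Model Qwen/Qwen3-32B-AWQ is not available on the RunPod endpoint.")
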

async def openai_embeddings(
self,
@@ -131,4 +137,16 @@ async def openai_embeddings(
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
# Resolve model_id to provider_resource_id
model_obj = await self.model_store.get_model(model)
provider_model_id = model_obj.provider_resource_id or model

response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)

return response
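
For reference, a sketch of the raw OpenAI-compatible calls the adapter forwards, issued directly against a RunPod endpoint; the URL, token, and model ids are placeholders and assume the endpoint actually serves such models:

# Sketch of the underlying OpenAI-compatible requests; URL, token, and model
# ids are placeholders and must match what the endpoint serves.
from openai import OpenAI

client = OpenAI(
    base_url="https://api.runpod.ai/v2/<endpoint-id>/openai/v1",  # placeholder
    api_key="rp_...",  # placeholder
)

# Streaming chat completion with the stream_options the adapter injects.
stream = client.chat.completions.create(
    model="Qwen/Qwen3-32B-AWQ",  # placeholder chat model
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.choices:  # the final usage-only chunk has no choices
        print(chunk.choices[0].delta.content or "", end="")

# Embeddings, mirroring the parameters openai_embeddings passes through.
response = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",  # placeholder embedding model
    input=["hello world"],
)
print(len(response.data[0].embedding))
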