5 changes: 4 additions & 1 deletion docs/static/llama-stack-spec.yaml
@@ -299,11 +299,14 @@ paths:
post:
responses:
'200':
description: An OpenAICompletion.
description: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAICompletion'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAICompletion'
'400':
description: Bad Request
$ref: '#/components/responses/BadRequest400'
5 changes: 4 additions & 1 deletion docs/static/stainless-llama-stack-spec.yaml
@@ -301,11 +301,14 @@ paths:
post:
responses:
'200':
description: An OpenAICompletion.
description: An OpenAICompletion or an async iterator of OpenAICompletion chunks when streaming.
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAICompletion'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAICompletion'
'400':
description: Bad Request
$ref: '#/components/responses/BadRequest400'
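
Both spec files receive the same change: the `200` response for `/v1/completions` now advertises an `application/json` body for non-streaming calls and a `text/event-stream` body for streaming calls. Below is a minimal sketch of how a client might consume the streamed variant; the base URL, API key, and model id are placeholders, not values from this PR.

```python
import asyncio

from openai import AsyncOpenAI

# Placeholder endpoint and model; adjust to your deployment.
client = AsyncOpenAI(base_url="http://localhost:8321/v1", api_key="none")


async def main() -> None:
    # stream=True selects the text/event-stream response described in the spec;
    # the server emits OpenAICompletion chunks as Server-Sent Events.
    stream = await client.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        prompt="Write a haiku about streaming APIs.",
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].text, end="", flush=True)


asyncio.run(main())
```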
@@ -246,7 +246,7 @@ async def shutdown(self) -> None:
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
raise NotImplementedError("OpenAI completion not supported by meta reference provider")

async def should_refresh_models(self) -> bool:
@@ -69,7 +69,7 @@ async def openai_embeddings(
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
raise NotImplementedError(
"Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Iterable
from collections.abc import AsyncIterator, Iterable

from databricks.sdk import WorkspaceClient

@@ -50,5 +50,5 @@ async def list_provider_model_ids(self) -> Iterable[str]:
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
raise NotImplementedError()
@@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator

from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -36,7 +38,7 @@ def get_base_url(self) -> str:
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
raise NotImplementedError()

async def openai_embeddings(
@@ -9,6 +9,7 @@
from openai import AsyncOpenAI

from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
from llama_stack_api import (
Inference,
Model,
@@ -107,12 +108,16 @@ def _get_passthrough_api_key(self) -> str:
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""Forward completion request to downstream using OpenAI client."""
client = self._get_openai_client()
request_params = params.model_dump(exclude_none=True)
response = await client.completions.create(**request_params)
return response # type: ignore

if params.stream:
return wrap_async_stream(response)

return response # type: ignore[return-value]

async def openai_chat_completion(
self,
@@ -122,7 +127,11 @@ async def openai_chat_completion(
client = self._get_openai_client()
request_params = params.model_dump(exclude_none=True)
response = await client.chat.completions.create(**request_params)
return response # type: ignore

if params.stream:
return wrap_async_stream(response)

return response # type: ignore[return-value]

async def openai_embeddings(
self,
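
In the passthrough provider above, `openai_completion` and `openai_chat_completion` now branch on `params.stream`: when streaming is requested, the downstream OpenAI client already returns an async stream, and `wrap_async_stream` re-exposes it as a plain async iterator matching the widened return type. A minimal sketch of how a caller could handle both return shapes follows; `provider` and `params` are placeholders for a real provider instance and request object.

```python
from collections.abc import AsyncIterator


async def collect_text(provider, params) -> str:
    # The provider returns either a single completion or an async iterator of chunks.
    result = await provider.openai_completion(params)
    if isinstance(result, AsyncIterator):
        # Streaming: accumulate the text deltas from each chunk.
        return "".join([chunk.choices[0].text async for chunk in result])
    # Non-streaming: a single OpenAICompletion object.
    return result.choices[0].text
```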
10 changes: 8 additions & 2 deletions src/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -14,6 +14,7 @@
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
from llama_stack_api import (
Model,
ModelType,
@@ -177,7 +178,7 @@ async def _normalize_stream(
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""
Override parent method to add watsonx-specific parameters.
"""
@@ -210,7 +211,12 @@ async def openai_completion(
timeout=self.config.timeout,
project_id=self.config.project_id,
)
return await litellm.atext_completion(**request_params)
result = await litellm.atext_completion(**request_params)

if params.stream:
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types

return result # type: ignore[return-value] # external lib lacks type stubs

async def openai_embeddings(
self,
@@ -16,6 +16,7 @@
from llama_stack.providers.utils.inference.openai_compat import (
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
from llama_stack_api import (
InferenceProvider,
OpenAIChatCompletion,
@@ -178,7 +179,7 @@ async def openai_embeddings(
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
if not self.model_store:
raise ValueError("Model store is not initialized")

@@ -210,7 +211,12 @@ async def openai_completion(
api_base=self.api_base,
)
# LiteLLM returns compatible type but mypy can't verify external library
return await litellm.atext_completion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
result = await litellm.atext_completion(**request_params)

if params.stream:
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types

return result # type: ignore[return-value] # external lib lacks type stubs

async def openai_chat_completion(
self,
@@ -261,7 +267,12 @@ async def openai_chat_completion(
api_base=self.api_base,
)
# LiteLLM returns compatible type but mypy can't verify external library
return await litellm.acompletion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
result = await litellm.acompletion(**request_params)

if params.stream:
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types

return result # type: ignore[return-value] # external lib lacks type stubs

async def check_model_availability(self, model: str) -> bool:
"""
14 changes: 6 additions & 8 deletions src/llama_stack/providers/utils/inference/openai_mixin.py
@@ -248,30 +248,28 @@ async def _get_provider_model_id(self, model: str) -> str:
return model_obj.provider_resource_id

async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
if not self.overwrite_completion_id:
return resp

new_id = f"cltsd-{uuid.uuid4()}"
if stream:
new_id = f"cltsd-{uuid.uuid4()}" if self.overwrite_completion_id else None

async def _gen():
async for chunk in resp:
chunk.id = new_id
if new_id:
chunk.id = new_id
yield chunk

return _gen()
else:
resp.id = new_id
if self.overwrite_completion_id:
resp.id = f"cltsd-{uuid.uuid4()}"
return resp

async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""
Direct OpenAI completion API call.
"""
# TODO: fix openai_completion to return type compatible with OpenAI's API response
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)

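
The removed and added lines in the `_maybe_overwrite_id` hunk above are interleaved, which makes the rework hard to follow. Reading them together, the helper now appears to always wrap streaming responses in a local generator (so callers reliably receive an async iterator) and to rewrite ids only when `overwrite_completion_id` is set. A hedged reconstruction of how the method seems to read after this change:

```python
import uuid
from typing import Any


# Method excerpt from OpenAIMixin; `self.overwrite_completion_id` is the mixin's flag.
async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
    if stream:
        # Only generate a replacement id when the mixin asks for it.
        new_id = f"cltsd-{uuid.uuid4()}" if self.overwrite_completion_id else None

        async def _gen():
            async for chunk in resp:
                if new_id:
                    chunk.id = new_id
                yield chunk

        # The stream is always wrapped, so callers get a plain async iterator
        # whether or not ids are rewritten.
        return _gen()
    else:
        if self.overwrite_completion_id:
            resp.id = f"cltsd-{uuid.uuid4()}"
        return resp
```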
23 changes: 23 additions & 0 deletions src/llama_stack/providers/utils/inference/stream_utils.py
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator

from llama_stack.log import get_logger

log = get_logger(name=__name__, category="providers::utils")


async def wrap_async_stream[T](stream: AsyncIterator[T]) -> AsyncIterator[T]:
"""
Wrap an async stream to ensure it returns a proper AsyncIterator.
"""
try:
async for item in stream:
yield item
except Exception as e:
log.error(f"Error in wrapped async stream: {e}")
raise
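
The new `wrap_async_stream` helper simply re-yields items, but because it is an async generator function, whatever it wraps comes back as a plain async generator, presumably so library-specific stream objects (for example the OpenAI client's `AsyncStream` or LiteLLM wrappers) are normalized into a uniform `AsyncIterator`, with errors raised mid-stream logged before being re-raised. A small usage sketch:

```python
import asyncio
from collections.abc import AsyncIterator

from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream


async def numbers() -> AsyncIterator[int]:
    for i in range(3):
        yield i


async def main() -> None:
    wrapped = wrap_async_stream(numbers())
    # The wrapper is itself an async generator, so it can be iterated
    # exactly like the original stream.
    async for item in wrapped:
        print(item)


asyncio.run(main())
```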
6 changes: 3 additions & 3 deletions src/llama_stack_api/inference.py
@@ -1020,11 +1020,11 @@ async def rerank(
async def openai_completion(
self,
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""Create completion.

Generate an OpenAI-compatible completion for the given prompt using the specified model.
:returns: An OpenAICompletion.
:returns: An OpenAICompletion. When streaming, returns Server-Sent Events (SSE) with OpenAICompletion chunks.
"""
...

@@ -1036,7 +1036,7 @@ async def openai_chat_completion(
"""Create chat completions.

Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:returns: An OpenAIChatCompletion.
:returns: An OpenAIChatCompletion. When streaming, returns Server-Sent Events (SSE) with OpenAIChatCompletionChunk objects.
"""
...
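
With the protocol signature widened to `OpenAICompletion | AsyncIterator[OpenAICompletion]`, a provider implementation can return either shape depending on `params.stream`. Below is an illustrative sketch of a conforming implementation, not code from this PR: the `client` object stands in for whatever SDK a provider wraps, and the `llama_stack_api` imports are assumed to be available as named here.

```python
from collections.abc import AsyncIterator

from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
from llama_stack_api import OpenAICompletion, OpenAICompletionRequestWithExtraBody


class ExampleProvider:
    """Illustrative only; `client` is a placeholder for a provider SDK client."""

    def __init__(self, client) -> None:
        self.client = client

    async def openai_completion(
        self,
        params: OpenAICompletionRequestWithExtraBody,
    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
        response = await self.client.completions.create(**params.model_dump(exclude_none=True))
        if params.stream:
            # Streaming: hand back a normalized async iterator of chunks.
            return wrap_async_stream(response)
        # Non-streaming: a single OpenAICompletion object.
        return response
```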
