client-sdks/stainless/openapi.yml (5 additions, 2 deletions)
@@ -227,7 +227,7 @@ paths:
post:
responses:
'200':
description: An OpenAIChatCompletion.
description: An OpenAIChatCompletion. When streaming, returns Server-Sent Events (SSE) with OpenAIChatCompletionChunk objects.
content:
application/json:
schema:
@@ -301,11 +301,14 @@ paths:
post:
responses:
'200':
description: An OpenAICompletion.
description: An OpenAICompletion. When streaming, returns Server-Sent Events (SSE) with OpenAICompletion chunks.
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAICompletion'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAICompletion'
'400':
description: Bad Request
$ref: '#/components/responses/BadRequest400'
docs/static/llama-stack-spec.yaml (5 additions, 2 deletions)
@@ -225,7 +225,7 @@ paths:
post:
responses:
'200':
description: An OpenAIChatCompletion.
description: An OpenAIChatCompletion. When streaming, returns Server-Sent Events (SSE) with OpenAIChatCompletionChunk objects.
content:
application/json:
schema:
@@ -299,11 +299,14 @@ paths:
post:
responses:
'200':
description: An OpenAICompletion.
description: An OpenAICompletion. When streaming, returns Server-Sent Events (SSE) with OpenAICompletion chunks.
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAICompletion'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAICompletion'
'400':
description: Bad Request
$ref: '#/components/responses/BadRequest400'
docs/static/stainless-llama-stack-spec.yaml (5 additions, 2 deletions)
@@ -227,7 +227,7 @@ paths:
post:
responses:
'200':
description: An OpenAIChatCompletion.
description: An OpenAIChatCompletion. When streaming, returns Server-Sent Events (SSE) with OpenAIChatCompletionChunk objects.
content:
application/json:
schema:
@@ -301,11 +301,14 @@ paths:
post:
responses:
'200':
description: An OpenAICompletion.
description: An OpenAICompletion. When streaming, returns Server-Sent Events (SSE) with OpenAICompletion chunks.
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAICompletion'
text/event-stream:
schema:
$ref: '#/components/schemas/OpenAICompletion'
'400':
description: Bad Request
$ref: '#/components/responses/BadRequest400'
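All three specs add the same `text/event-stream` response to the completions endpoint. As a rough illustration of what that enables (a sketch only — the base URL, API key, and model id below are placeholders, not values from this PR), an OpenAI-compatible client would consume the streamed endpoint like this:

```python
# Illustrative sketch: base_url, api_key, and model are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# With stream=True the server replies with text/event-stream and the SDK
# yields one completion chunk per SSE event.
stream = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    prompt="Say hello",
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].text, end="", flush=True)
```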
@@ -246,7 +246,7 @@ async def shutdown(self) -> None:
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
raise NotImplementedError("OpenAI completion not supported by meta reference provider")

async def should_refresh_models(self) -> bool:
@@ -69,7 +69,7 @@ async def openai_embeddings(
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
raise NotImplementedError(
"Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Iterable
from collections.abc import AsyncIterator, Iterable

from databricks.sdk import WorkspaceClient

@@ -50,5 +50,5 @@ async def list_provider_model_ids(self) -> Iterable[str]:
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
raise NotImplementedError()
@@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator

from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -36,7 +38,7 @@ def get_base_url(self) -> str:
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
raise NotImplementedError()

async def openai_embeddings(
@@ -9,6 +9,7 @@
from openai import AsyncOpenAI

from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
from llama_stack_api import (
Inference,
Model,
@@ -107,12 +108,16 @@ def _get_passthrough_api_key(self) -> str:
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""Forward completion request to downstream using OpenAI client."""
client = self._get_openai_client()
request_params = params.model_dump(exclude_none=True)
response = await client.completions.create(**request_params)
return response # type: ignore

if params.stream:
return wrap_async_stream(response)

return response # type: ignore[return-value]

async def openai_chat_completion(
self,
@@ -122,7 +127,11 @@ async def openai_chat_completion(
client = self._get_openai_client()
request_params = params.model_dump(exclude_none=True)
response = await client.chat.completions.create(**request_params)
return response # type: ignore

if params.stream:
return wrap_async_stream(response)

return response # type: ignore[return-value]

async def openai_embeddings(
self,
src/llama_stack/providers/remote/inference/watsonx/watsonx.py (8 additions, 2 deletions)
@@ -14,6 +14,7 @@
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
from llama_stack_api import (
Model,
ModelType,
@@ -177,7 +178,7 @@ async def _normalize_stream(
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""
Override parent method to add watsonx-specific parameters.
"""
@@ -210,7 +211,12 @@ async def openai_completion(
timeout=self.config.timeout,
project_id=self.config.project_id,
)
return await litellm.atext_completion(**request_params)
result = await litellm.atext_completion(**request_params)

if params.stream:
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types

return result # type: ignore[return-value] # external lib lacks type stubs

async def openai_embeddings(
self,
@@ -16,6 +16,7 @@
from llama_stack.providers.utils.inference.openai_compat import (
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream
from llama_stack_api import (
InferenceProvider,
OpenAIChatCompletion,
@@ -178,7 +179,7 @@ async def openai_embeddings(
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
if not self.model_store:
raise ValueError("Model store is not initialized")

@@ -210,7 +211,12 @@ async def openai_completion(
api_base=self.api_base,
)
# LiteLLM returns compatible type but mypy can't verify external library
return await litellm.atext_completion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
result = await litellm.atext_completion(**request_params)

if params.stream:
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types

return result # type: ignore[return-value] # external lib lacks type stubs

async def openai_chat_completion(
self,
@@ -261,7 +267,12 @@ async def openai_chat_completion(
api_base=self.api_base,
)
# LiteLLM returns compatible type but mypy can't verify external library
return await litellm.acompletion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
result = await litellm.acompletion(**request_params)

if params.stream:
return wrap_async_stream(result) # type: ignore[arg-type] # LiteLLM streaming types

return result # type: ignore[return-value] # external lib lacks type stubs

async def check_model_availability(self, model: str) -> bool:
"""
src/llama_stack/providers/utils/inference/openai_mixin.py (6 additions, 8 deletions)
@@ -248,30 +248,28 @@ async def _get_provider_model_id(self, model: str) -> str:
return model_obj.provider_resource_id

async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
if not self.overwrite_completion_id:
return resp

new_id = f"cltsd-{uuid.uuid4()}"
if stream:
new_id = f"cltsd-{uuid.uuid4()}" if self.overwrite_completion_id else None

async def _gen():
async for chunk in resp:
chunk.id = new_id
if new_id:
chunk.id = new_id
yield chunk

return _gen()
else:
resp.id = new_id
if self.overwrite_completion_id:
resp.id = f"cltsd-{uuid.uuid4()}"
return resp

async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""
Direct OpenAI completion API call.
"""
# TODO: fix openai_completion to return type compatible with OpenAI's API response
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)

src/llama_stack/providers/utils/inference/stream_utils.py (new file, 23 additions)
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator

from llama_stack.log import get_logger

log = get_logger(name=__name__, category="providers::utils")


async def wrap_async_stream[T](stream: AsyncIterator[T]) -> AsyncIterator[T]:
"""
Wrap an async stream to ensure it returns a proper AsyncIterator.
"""
try:
async for item in stream:
yield item
except Exception as e:
log.error(f"Error in wrapped async stream: {e}")
raise
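A minimal usage sketch for the new helper, mirroring the provider changes above (`fake_upstream_stream` is invented for illustration and simply stands in for an SDK call that returns an async iterator of chunks):

```python
# Hypothetical example; only wrap_async_stream comes from this PR.
import asyncio
from collections.abc import AsyncIterator

from llama_stack.providers.utils.inference.stream_utils import wrap_async_stream


async def fake_upstream_stream() -> AsyncIterator[str]:
    # Stand-in for a streaming response from an upstream SDK.
    for piece in ("Hel", "lo"):
        yield piece


async def main() -> None:
    # Wrapping guarantees callers get a plain AsyncIterator and that any
    # error raised mid-stream is logged before it propagates.
    async for piece in wrap_async_stream(fake_upstream_stream()):
        print(piece)


asyncio.run(main())
```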
src/llama_stack_api/inference.py (3 additions, 3 deletions)
@@ -1020,11 +1020,11 @@ async def rerank(
async def openai_completion(
self,
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
"""Create completion.

Generate an OpenAI-compatible completion for the given prompt using the specified model.
:returns: An OpenAICompletion.
:returns: An OpenAICompletion. When streaming, returns Server-Sent Events (SSE) with OpenAICompletion chunks.
"""
...

@@ -1036,7 +1036,7 @@ async def openai_chat_completion(
"""Create chat completions.

Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:returns: An OpenAIChatCompletion.
:returns: An OpenAIChatCompletion. When streaming, returns Server-Sent Events (SSE) with OpenAIChatCompletionChunk objects.
"""
...

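Since `openai_completion` (and `openai_chat_completion`) can now return either a single object or an async iterator of chunks, callers branch on `params.stream`. A hedged caller-side sketch — the exact import paths and the `choices[0].text` chunk field are assumptions, not verified against this PR:

```python
# Sketch only; import locations and chunk field names are assumed.
from collections.abc import AsyncIterator

from llama_stack_api import (
    Inference,
    OpenAICompletion,
    OpenAICompletionRequestWithExtraBody,
)


async def collect_completion(
    inference: Inference,
    params: OpenAICompletionRequestWithExtraBody,
) -> str:
    """Return the full completion text whether or not streaming was requested."""
    result = await inference.openai_completion(params)
    if params.stream:
        # Streaming path: result is an AsyncIterator of completion chunks.
        assert isinstance(result, AsyncIterator)
        return "".join([chunk.choices[0].text async for chunk in result])
    # Non-streaming path: result is a single OpenAICompletion.
    assert isinstance(result, OpenAICompletion)
    return result.choices[0].text
```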