60 changes: 46 additions & 14 deletions src/llama_stack/core/routers/inference.py
@@ -176,7 +176,7 @@ async def rerank(
async def openai_completion(
self,
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
) -> OpenAICompletion:
) -> OpenAICompletion | AsyncIterator[Any]:
logger.debug(
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
)
@@ -185,9 +185,12 @@ async def openai_completion(
params.model = provider_resource_id

if params.stream:
return await provider.openai_completion(params)
# TODO: Metrics do NOT work with openai_completion stream=True because we do not
# return an AsyncIterator; our tests expect a stream of chunks that we cannot currently intercept.
response_stream = await provider.openai_completion(params)
return self.wrap_completion_stream_with_metrics(
response=response_stream,
fully_qualified_model_id=request_model_id,
provider_id=provider.__provider_id__,
)

response = await provider.openai_completion(params)
response.model = request_model_id
@@ -412,16 +415,17 @@ async def stream_tokens_and_compute_metrics_openai_chat(
completion_text += "".join(choice_data["content_parts"])

# Add metrics to the chunk
if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
for metric in metrics:
enqueue_event(metric)
if self.telemetry_enabled:
if hasattr(chunk, "usage") and chunk.usage:
metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
for metric in metrics:
enqueue_event(metric)

yield chunk
finally:
@@ -471,3 +475,31 @@ async def stream_tokens_and_compute_metrics_openai_chat(
)
logger.debug(f"InferenceRouter.completion_response: {final_response}")
asyncio.create_task(self.store.store_chat_completion(final_response, messages))

async def wrap_completion_stream_with_metrics(
self,
response: AsyncIterator,
fully_qualified_model_id: str,
provider_id: str,
) -> AsyncIterator:
"""Stream OpenAI completion chunks and compute metrics on final chunk."""

async for chunk in response:
if hasattr(chunk, "model"):
chunk.model = fully_qualified_model_id

if getattr(chunk, "choices", None) and any(c.finish_reason for c in chunk.choices):
if self.telemetry_enabled:
if getattr(chunk, "usage", None):
usage = chunk.usage
metrics = self._construct_metrics(
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
for metric in metrics:
enqueue_event(metric)

yield chunk
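For context, here is a minimal self-contained sketch of the stream-wrapping pattern that wrap_completion_stream_with_metrics introduces above. The FakeChunk/FakeUsage classes, the provider_stream generator, and the print-based metric emission are illustrative stand-ins, not the real llama_stack types or telemetry API.

import asyncio
from collections.abc import AsyncIterator
from dataclasses import dataclass, field


@dataclass
class FakeUsage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class FakeChoice:
    finish_reason: str | None = None


@dataclass
class FakeChunk:
    model: str
    choices: list[FakeChoice] = field(default_factory=list)
    usage: FakeUsage | None = None


async def wrap_stream(response: AsyncIterator[FakeChunk], model_id: str) -> AsyncIterator[FakeChunk]:
    # Rewrite the model id on every chunk and emit metrics when the final chunk
    # (the one carrying finish_reason and usage) arrives.
    async for chunk in response:
        chunk.model = model_id
        if chunk.choices and any(c.finish_reason for c in chunk.choices) and chunk.usage:
            print("tokens:", chunk.usage.total_tokens)  # stand-in for enqueue_event(metric)
        yield chunk


async def main() -> None:
    async def provider_stream() -> AsyncIterator[FakeChunk]:
        yield FakeChunk(model="provider/model", choices=[FakeChoice()])
        yield FakeChunk(
            model="provider/model",
            choices=[FakeChoice(finish_reason="stop")],
            usage=FakeUsage(prompt_tokens=3, completion_tokens=5, total_tokens=8),
        )

    async for chunk in wrap_stream(provider_stream(), "router/fully-qualified-model"):
        print(chunk.model, [c.finish_reason for c in chunk.choices])


asyncio.run(main())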
@@ -8,7 +8,6 @@

from openai import AuthenticationError

from llama_stack.core.telemetry.tracing import get_current_span
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import (
@@ -82,14 +81,7 @@ async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Override to enable streaming usage metrics and handle authentication errors."""
# Enable streaming usage metrics when telemetry is active
if params.stream and get_current_span() is not None:
if params.stream_options is None:
params.stream_options = {"include_usage": True}
elif "include_usage" not in params.stream_options:
params.stream_options = {**params.stream_options, "include_usage": True}

"""Override to handle authentication errors and null responses."""
try:
logger.debug(f"Calling Bedrock OpenAI API with model={params.model}, stream={params.stream}")
result = await super().openai_chat_completion(params=params)
19 changes: 0 additions & 19 deletions src/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -4,14 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
)

from .config import RunpodImplConfig

@@ -29,15 +22,3 @@ class RunpodInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return str(self.config.base_url)

async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Override to add RunPod-specific stream_options requirement."""
params = params.model_copy()

if params.stream and not params.stream_options:
params.stream_options = {"include_usage": True}

return await super().openai_chat_completion(params)
12 changes: 1 addition & 11 deletions src/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -10,7 +10,6 @@
import litellm
import requests

from llama_stack.core.telemetry.tracing import get_current_span
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -56,15 +55,6 @@ async def openai_chat_completion(
Override parent method to add timeout and inject usage object when missing.
This works around a LiteLLM defect where usage block is sometimes dropped.
"""

# Add usage tracking for streaming when telemetry is active
stream_options = params.stream_options
if params.stream and get_current_span() is not None:
if stream_options is None:
stream_options = {"include_usage": True}
elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True}

model_obj = await self.model_store.get_model(params.model)

request_params = await prepare_openai_completion_params(
@@ -84,7 +74,7 @@
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=stream_options,
stream_options=params.stream_options,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
20 changes: 20 additions & 0 deletions src/llama_stack/providers/utils/inference/openai_mixin.py
@@ -271,6 +271,16 @@ async def openai_completion(
"""
Direct OpenAI completion API call.
"""
from llama_stack.core.telemetry.tracing import get_current_span

# inject if streaming AND telemetry active
if params.stream and get_current_span() is not None:
params = params.model_copy()
if params.stream_options is None:
params.stream_options = {"include_usage": True}
elif "include_usage" not in params.stream_options:
params.stream_options = {**params.stream_options, "include_usage": True}

# TODO: fix openai_completion to return type compatible with OpenAI's API response
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)
@@ -308,6 +318,16 @@ async def openai_chat_completion(
"""
Direct OpenAI chat completion API call.
"""
from llama_stack.core.telemetry.tracing import get_current_span

# inject if streaming AND telemetry active
if params.stream and get_current_span() is not None:
params = params.model_copy()
if params.stream_options is None:
params.stream_options = {"include_usage": True}
elif "include_usage" not in params.stream_options:
params.stream_options = {**params.stream_options, "include_usage": True}

provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)

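A hypothetical pure-function restatement of the injection logic now duplicated in openai_completion and openai_chat_completion above; the helper name inject_include_usage is not part of this PR, but it makes the merge rules easy to see at a glance.

def inject_include_usage(
    stream: bool, telemetry_active: bool, stream_options: dict | None
) -> dict | None:
    # Only touch stream_options when streaming with telemetry active.
    if not (stream and telemetry_active):
        return stream_options
    if stream_options is None:
        return {"include_usage": True}
    if "include_usage" not in stream_options:
        return {**stream_options, "include_usage": True}
    return stream_options  # an explicit include_usage (True or False) is preserved


assert inject_include_usage(True, True, None) == {"include_usage": True}
assert inject_include_usage(True, True, {"other_option": True}) == {"other_option": True, "include_usage": True}
assert inject_include_usage(True, True, {"include_usage": False}) == {"include_usage": False}
assert inject_include_usage(True, False, None) is None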
143 changes: 143 additions & 0 deletions tests/unit/providers/utils/inference/test_openai_mixin.py
@@ -934,3 +934,146 @@ async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)


class TestOpenAIMixinStreamingMetrics:
"""Test cases for streaming metrics injection in OpenAIMixin"""

async def test_openai_chat_completion_streaming_metrics_injection(self, mixin, mock_client_context):
"""Test that stream_options={"include_usage": True} is injected when streaming and telemetry is enabled"""

params = OpenAIChatCompletionRequestWithExtraBody(
model="test-model",
messages=[{"role": "user", "content": "hello"}],
stream=True,
stream_options=None,
)

mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

with mock_client_context(mixin, mock_client):
with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
mock_get_span.return_value = MagicMock()

with patch(
"llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
) as mock_prepare:
mock_prepare.return_value = {"model": "test-model"}

await mixin.openai_chat_completion(params)

call_kwargs = mock_prepare.call_args.kwargs
assert call_kwargs["stream_options"] == {"include_usage": True}

assert params.stream_options is None

async def test_openai_chat_completion_streaming_no_telemetry(self, mixin, mock_client_context):
"""Test that stream_options is NOT injected when telemetry is disabled"""

params = OpenAIChatCompletionRequestWithExtraBody(
model="test-model",
messages=[{"role": "user", "content": "hello"}],
stream=True,
stream_options=None,
)

mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

with mock_client_context(mixin, mock_client):
with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
mock_get_span.return_value = None

with patch(
"llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
) as mock_prepare:
mock_prepare.return_value = {"model": "test-model"}

await mixin.openai_chat_completion(params)

call_kwargs = mock_prepare.call_args.kwargs
assert call_kwargs["stream_options"] is None

async def test_openai_completion_streaming_metrics_injection(self, mixin, mock_client_context):
"""Test that stream_options={"include_usage": True} is injected for legacy completion"""

params = OpenAICompletionRequestWithExtraBody(
model="test-model",
prompt="hello",
stream=True,
stream_options=None,
)

mock_client = MagicMock()
mock_client.completions.create = AsyncMock(return_value=MagicMock())

with mock_client_context(mixin, mock_client):
with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
mock_get_span.return_value = MagicMock()

with patch(
"llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
) as mock_prepare:
mock_prepare.return_value = {"model": "test-model"}

await mixin.openai_completion(params)

call_kwargs = mock_prepare.call_args.kwargs
assert call_kwargs["stream_options"] == {"include_usage": True}
assert params.stream_options is None

async def test_preserves_existing_stream_options(self, mixin, mock_client_context):
"""Test that existing stream_options are preserved and merged"""

params = OpenAIChatCompletionRequestWithExtraBody(
model="test-model",
messages=[{"role": "user", "content": "hello"}],
stream=True,
stream_options={"include_usage": False},
)

mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

with mock_client_context(mixin, mock_client):
with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
mock_get_span.return_value = MagicMock()

with patch(
"llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
) as mock_prepare:
mock_prepare.return_value = {"model": "test-model"}

await mixin.openai_chat_completion(params)

call_kwargs = mock_prepare.call_args.kwargs
# It should stay False because it was present
assert call_kwargs["stream_options"] == {"include_usage": False}

async def test_merges_existing_stream_options(self, mixin, mock_client_context):
"""Test that existing stream_options are merged"""

params = OpenAIChatCompletionRequestWithExtraBody(
model="test-model",
messages=[{"role": "user", "content": "hello"}],
stream=True,
stream_options={"other_option": True},
)

mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

with mock_client_context(mixin, mock_client):
with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
mock_get_span.return_value = MagicMock()

with patch(
"llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
) as mock_prepare:
mock_prepare.return_value = {"model": "test-model"}

await mixin.openai_chat_completion(params)

call_kwargs = mock_prepare.call_args.kwargs
assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True}
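For completeness, a hedged sketch of what the injected option means for a consumer of a raw OpenAI-compatible stream: with include_usage set, only the final chunk carries a populated usage block, which is what the router-level wrappers read to emit token metrics. The client configuration and model name below are placeholders, not part of this PR.

import asyncio

from openai import AsyncOpenAI


async def consume() -> None:
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY (or a compatible base_url) is configured
    stream = await client.chat.completions.create(
        model="example-model",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in stream:
        if chunk.usage is not None:  # populated only on the final chunk
            print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens, chunk.usage.total_tokens)


asyncio.run(consume())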