From f6080040da514a04eb80bdb2a281e17236725512 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Thu, 2 Oct 2025 18:38:30 -0700 Subject: [PATCH 01/11] Code refactoring and removing dead code --- llama_stack/apis/inference/inference.py | 129 +++++++---- llama_stack/core/routers/inference.py | 202 +++++++++++++----- .../sentence_transformers.py | 8 +- .../remote/inference/cerebras/cerebras.py | 28 +-- .../remote/inference/runpod/runpod.py | 19 +- .../remote/inference/watsonx/watsonx.py | 53 +++-- 6 files changed, 302 insertions(+), 137 deletions(-) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 829a94a6a8..d149a4dc2a 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -6,16 +6,7 @@ from collections.abc import AsyncIterator from enum import Enum -from typing import ( - Annotated, - Any, - Literal, - Protocol, - runtime_checkable, -) - -from pydantic import BaseModel, Field, field_validator -from typing_extensions import TypedDict +from typing import Annotated, Any, Literal, Protocol, runtime_checkable from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent from llama_stack.apis.common.responses import Order @@ -32,6 +23,9 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod +from pydantic import BaseModel, Field, field_validator +from typing_extensions import TypedDict + register_schema(ToolCall) register_schema(ToolDefinition) @@ -357,32 +351,32 @@ class CompletionRequest(BaseModel): logprobs: LogProbConfig | None = None -@json_schema_type -class CompletionResponse(MetricResponseMixin): - """Response from a completion request. +# @json_schema_type +# class CompletionResponse(MetricResponseMixin): +# """Response from a completion request. - :param content: The generated completion text - :param stop_reason: Reason why generation stopped - :param logprobs: Optional log probabilities for generated tokens - """ +# :param content: The generated completion text +# :param stop_reason: Reason why generation stopped +# :param logprobs: Optional log probabilities for generated tokens +# """ - content: str - stop_reason: StopReason - logprobs: list[TokenLogProbs] | None = None +# content: str +# stop_reason: StopReason +# logprobs: list[TokenLogProbs] | None = None -@json_schema_type -class CompletionResponseStreamChunk(MetricResponseMixin): - """A chunk of a streamed completion response. +# @json_schema_type +# class CompletionResponseStreamChunk(MetricResponseMixin): +# """A chunk of a streamed completion response. - :param delta: New content generated since last chunk. This can be one or more tokens. - :param stop_reason: Optional reason why generation stopped, if complete - :param logprobs: Optional log probabilities for generated tokens - """ +# :param delta: New content generated since last chunk. This can be one or more tokens. +# :param stop_reason: Optional reason why generation stopped, if complete +# :param logprobs: Optional log probabilities for generated tokens +# """ - delta: str - stop_reason: StopReason | None = None - logprobs: list[TokenLogProbs] | None = None +# delta: str +# stop_reason: StopReason | None = None +# logprobs: list[TokenLogProbs] | None = None class SystemMessageBehavior(Enum): @@ -415,7 +409,9 @@ class ToolConfig(BaseModel): tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto) tool_prompt_format: ToolPromptFormat | None = Field(default=None) - system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append) + system_message_behavior: SystemMessageBehavior | None = Field( + default=SystemMessageBehavior.append + ) def model_post_init(self, __context: Any) -> None: if isinstance(self.tool_choice, str): @@ -544,15 +540,21 @@ class OpenAIFile(BaseModel): OpenAIChatCompletionContentPartParam = Annotated[ - OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile, + OpenAIChatCompletionContentPartTextParam + | OpenAIChatCompletionContentPartImageParam + | OpenAIFile, Field(discriminator="type"), ] -register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam") +register_schema( + OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam" +) OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam] -OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam] +OpenAIChatCompletionTextOnlyMessageContent = ( + str | list[OpenAIChatCompletionContentPartTextParam] +) @json_schema_type @@ -720,7 +722,9 @@ class OpenAIResponseFormatJSONObject(BaseModel): OpenAIResponseFormatParam = Annotated[ - OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject, + OpenAIResponseFormatText + | OpenAIResponseFormatJSONSchema + | OpenAIResponseFormatJSONObject, Field(discriminator="type"), ] register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam") @@ -1049,8 +1053,16 @@ async def chat_completion( async def rerank( self, model: str, - query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, - items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + query: ( + str + | OpenAIChatCompletionContentPartTextParam + | OpenAIChatCompletionContentPartImageParam + ), + items: list[ + str + | OpenAIChatCompletionContentPartTextParam + | OpenAIChatCompletionContentPartImageParam + ], max_num_results: int | None = None, ) -> RerankResponse: """Rerank a list of documents based on their relevance to a query. @@ -1064,7 +1076,12 @@ async def rerank( raise NotImplementedError("Reranking is not implemented") return # this is so mypy's safe-super rule will consider the method concrete - @webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) + @webmethod( + route="/openai/v1/completions", + method="POST", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1) async def openai_completion( self, @@ -1116,7 +1133,12 @@ async def openai_completion( """ ... - @webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) + @webmethod( + route="/openai/v1/chat/completions", + method="POST", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1) async def openai_chat_completion( self, @@ -1173,7 +1195,12 @@ async def openai_chat_completion( """ ... - @webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) + @webmethod( + route="/openai/v1/embeddings", + method="POST", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) @webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1) async def openai_embeddings( self, @@ -1203,7 +1230,12 @@ class Inference(InferenceProvider): - Embedding models: these models generate embeddings to be used for semantic search. """ - @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True) + @webmethod( + route="/openai/v1/chat/completions", + method="GET", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) @webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1) async def list_chat_completions( self, @@ -1223,10 +1255,19 @@ async def list_chat_completions( raise NotImplementedError("List chat completions is not implemented") @webmethod( - route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True + route="/openai/v1/chat/completions/{completion_id}", + method="GET", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) + @webmethod( + route="/chat/completions/{completion_id}", + method="GET", + level=LLAMA_STACK_API_V1, ) - @webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1) - async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: + async def get_chat_completion( + self, completion_id: str + ) -> OpenAICompletionWithInputMessages: """Describe a chat completion by its ID. :param completion_id: ID of the chat completion. diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 4b004a82c1..7f0e4a3520 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -7,24 +7,16 @@ import asyncio import time from collections.abc import AsyncGenerator, AsyncIterator -from datetime import UTC, datetime +from datetime import datetime, UTC from typing import Annotated, Any -from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam -from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam -from pydantic import Field, TypeAdapter - -from llama_stack.apis.common.content_types import ( - InterleavedContent, -) +from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, - CompletionResponse, - CompletionResponseStreamChunk, Inference, ListOpenAIChatCompletionResponse, LogProbConfig, @@ -57,7 +49,16 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.utils.inference.inference_store import InferenceStore -from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span +from llama_stack.providers.utils.telemetry.tracing import ( + enqueue_event, + get_current_span, +) + +from openai.types.chat import ( + ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam, + ChatCompletionToolParam as OpenAIChatCompletionToolParam, +) +from pydantic import Field, TypeAdapter logger = get_logger(name=__name__, category="core::routers") @@ -101,7 +102,9 @@ async def register_model( logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) - await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) + await self.routing_table.register_model( + model_id, provider_model_id, provider_id, metadata, model_type + ) def _construct_metrics( self, @@ -156,11 +159,16 @@ async def _compute_and_log_token_usage( total_tokens: int, model: Model, ) -> list[MetricInResponse]: - metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) + metrics = self._construct_metrics( + prompt_tokens, completion_tokens, total_tokens, model + ) if self.telemetry: for metric in metrics: enqueue_event(metric) - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + return [ + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in metrics + ] async def _count_tokens( self, @@ -207,8 +215,13 @@ async def chat_completion( if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") - if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: - raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") + if ( + tool_prompt_format + and tool_prompt_format != tool_config.tool_prompt_format + ): + raise ValueError( + "tool_prompt_format and tool_config.tool_prompt_format must match" + ) else: params = {} if tool_choice: @@ -226,9 +239,14 @@ async def chat_completion( pass else: # verify tool_choice is one of the tools - tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] + tool_names = [ + t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value + for t in tools + ] if tool_config.tool_choice not in tool_names: - raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") + raise ValueError( + f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}" + ) params = dict( model_id=model_id, @@ -243,7 +261,9 @@ async def chat_completion( tool_config=tool_config, ) provider = await self.routing_table.get_provider_impl(model_id) - prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) + prompt_tokens = await self._count_tokens( + messages, tool_config.tool_prompt_format + ) if stream: response_stream = await provider.chat_completion(**params) @@ -263,7 +283,9 @@ async def chat_completion( ) # these metrics will show up in the client response. response.metrics = ( - metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + metrics + if not hasattr(response, "metrics") or response.metrics is None + else response.metrics + metrics ) return response @@ -336,7 +358,9 @@ async def openai_completion( # these metrics will show up in the client response. response.metrics = ( - metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + metrics + if not hasattr(response, "metrics") or response.metrics is None + else response.metrics + metrics ) return response @@ -374,9 +398,13 @@ async def openai_chat_completion( # Use the OpenAI client for a bit of extra input validation without # exposing the OpenAI client itself as part of our API surface if tool_choice: - TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice) + TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python( + tool_choice + ) if tools is None: - raise ValueError("'tool_choice' is only allowed when 'tools' is also provided") + raise ValueError( + "'tool_choice' is only allowed when 'tools' is also provided" + ) if tools: for tool in tools: TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool) @@ -441,7 +469,9 @@ async def openai_chat_completion( enqueue_event(metric) # these metrics will show up in the client response. response.metrics = ( - metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + metrics + if not hasattr(response, "metrics") or response.metrics is None + else response.metrics + metrics ) return response @@ -477,19 +507,31 @@ async def list_chat_completions( ) -> ListOpenAIChatCompletionResponse: if self.store: return await self.store.list_chat_completions(after, limit, model, order) - raise NotImplementedError("List chat completions is not supported: inference store is not configured.") + raise NotImplementedError( + "List chat completions is not supported: inference store is not configured." + ) - async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: + async def get_chat_completion( + self, completion_id: str + ) -> OpenAICompletionWithInputMessages: if self.store: return await self.store.get_chat_completion(completion_id) - raise NotImplementedError("Get chat completion is not supported: inference store is not configured.") + raise NotImplementedError( + "Get chat completion is not supported: inference store is not configured." + ) - async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion: + async def _nonstream_openai_chat_completion( + self, provider: Inference, params: dict + ) -> OpenAIChatCompletion: response = await provider.openai_chat_completion(**params) for choice in response.choices: # some providers return an empty list for no tool calls in non-streaming responses # but the OpenAI API returns None. So, set tool_calls to None if it's empty - if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0: + if ( + choice.message + and choice.message.tool_calls is not None + and len(choice.message.tool_calls) == 0 + ): choice.message.tool_calls = None return response @@ -509,7 +551,9 @@ async def health(self) -> dict[str, HealthResponse]: message=f"Health check timed out after {timeout} seconds", ) except NotImplementedError: - health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED) + health_statuses[provider_id] = HealthResponse( + status=HealthStatus.NOT_IMPLEMENTED + ) except Exception as e: health_statuses[provider_id] = HealthResponse( status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" @@ -522,7 +566,7 @@ async def stream_tokens_and_compute_metrics( prompt_tokens, model, tool_prompt_format: ToolPromptFormat | None = None, - ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]: + ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: completion_text = "" async for chunk in response: complete = False @@ -544,7 +588,11 @@ async def stream_tokens_and_compute_metrics( else: if hasattr(chunk, "delta"): completion_text += chunk.delta - if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + if ( + hasattr(chunk, "stop_reason") + and chunk.stop_reason + and self.telemetry + ): complete = True completion_tokens = await self._count_tokens(completion_text) # if we are done receiving tokens @@ -569,9 +617,14 @@ async def stream_tokens_and_compute_metrics( # Return metrics in response async_metrics = [ - MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in completion_metrics ] - chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics + chunk.metrics = ( + async_metrics + if chunk.metrics is None + else chunk.metrics + async_metrics + ) else: # Fallback if no telemetry completion_metrics = self._construct_metrics( @@ -581,14 +634,19 @@ async def stream_tokens_and_compute_metrics( model, ) async_metrics = [ - MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in completion_metrics ] - chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics + chunk.metrics = ( + async_metrics + if chunk.metrics is None + else chunk.metrics + async_metrics + ) yield chunk async def count_tokens_and_compute_metrics( self, - response: ChatCompletionResponse | CompletionResponse, + response: ChatCompletionResponse, prompt_tokens, model, tool_prompt_format: ToolPromptFormat | None = None, @@ -597,7 +655,9 @@ async def count_tokens_and_compute_metrics( content = [response.completion_message] else: content = response.content - completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format) + completion_tokens = await self._count_tokens( + messages=content, tool_prompt_format=tool_prompt_format + ) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) # Create a separate span for completion metrics @@ -610,11 +670,17 @@ async def count_tokens_and_compute_metrics( model=model, ) for metric in completion_metrics: - if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens + if metric.metric in [ + "completion_tokens", + "total_tokens", + ]: # Only log completion and total tokens enqueue_event(metric) # Return metrics in response - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics] + return [ + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in completion_metrics + ] # Fallback if no telemetry metrics = self._construct_metrics( @@ -623,7 +689,10 @@ async def count_tokens_and_compute_metrics( total_tokens, model, ) - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + return [ + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in metrics + ] async def stream_tokens_and_compute_metrics_openai_chat( self, @@ -664,33 +733,48 @@ async def stream_tokens_and_compute_metrics_openai_chat( if choice_delta.delta: delta = choice_delta.delta if delta.content: - current_choice_data["content_parts"].append(delta.content) + current_choice_data["content_parts"].append( + delta.content + ) if delta.tool_calls: for tool_call_delta in delta.tool_calls: tc_idx = tool_call_delta.index - if tc_idx not in current_choice_data["tool_calls_builder"]: - current_choice_data["tool_calls_builder"][tc_idx] = { + if ( + tc_idx + not in current_choice_data["tool_calls_builder"] + ): + current_choice_data["tool_calls_builder"][ + tc_idx + ] = { "id": None, "type": "function", "function_name_parts": [], "function_arguments_parts": [], } - builder = current_choice_data["tool_calls_builder"][tc_idx] + builder = current_choice_data["tool_calls_builder"][ + tc_idx + ] if tool_call_delta.id: builder["id"] = tool_call_delta.id if tool_call_delta.type: builder["type"] = tool_call_delta.type if tool_call_delta.function: if tool_call_delta.function.name: - builder["function_name_parts"].append(tool_call_delta.function.name) + builder["function_name_parts"].append( + tool_call_delta.function.name + ) if tool_call_delta.function.arguments: builder["function_arguments_parts"].append( tool_call_delta.function.arguments ) if choice_delta.finish_reason: - current_choice_data["finish_reason"] = choice_delta.finish_reason + current_choice_data["finish_reason"] = ( + choice_delta.finish_reason + ) if choice_delta.logprobs and choice_delta.logprobs.content: - current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content) + current_choice_data["logprobs_content_parts"].extend( + choice_delta.logprobs.content + ) # Compute metrics on final chunk if chunk.choices and chunk.choices[0].finish_reason: @@ -720,8 +804,12 @@ async def stream_tokens_and_compute_metrics_openai_chat( if choice_data["tool_calls_builder"]: for tc_build_data in choice_data["tool_calls_builder"].values(): if tc_build_data["id"]: - func_name = "".join(tc_build_data["function_name_parts"]) - func_args = "".join(tc_build_data["function_arguments_parts"]) + func_name = "".join( + tc_build_data["function_name_parts"] + ) + func_args = "".join( + tc_build_data["function_arguments_parts"] + ) assembled_tool_calls.append( OpenAIChatCompletionToolCall( id=tc_build_data["id"], @@ -734,10 +822,16 @@ async def stream_tokens_and_compute_metrics_openai_chat( message = OpenAIAssistantMessageParam( role="assistant", content=content_str if content_str else None, - tool_calls=assembled_tool_calls if assembled_tool_calls else None, + tool_calls=( + assembled_tool_calls if assembled_tool_calls else None + ), ) logprobs_content = choice_data["logprobs_content_parts"] - final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None + final_logprobs = ( + OpenAIChoiceLogprobs(content=logprobs_content) + if logprobs_content + else None + ) assembled_choices.append( OpenAIChoice( @@ -756,4 +850,6 @@ async def stream_tokens_and_compute_metrics_openai_chat( object="chat.completion", ) logger.debug(f"InferenceRouter.completion_response: {final_response}") - asyncio.create_task(self.store.store_chat_completion(final_response, messages)) + asyncio.create_task( + self.store.store_chat_completion(final_response, messages) + ) diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index cd682dca6c..b975fb13f4 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -25,9 +25,6 @@ from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) -from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, -) from .config import SentenceTransformersInferenceConfig @@ -35,7 +32,6 @@ class SentenceTransformersInferenceImpl( - OpenAIChatCompletionToLlamaStackMixin, SentenceTransformerEmbeddingMixin, InferenceProvider, ModelsProtocolPrivate, @@ -114,4 +110,6 @@ async def openai_completion( # for fill-in-the-middle type completion suffix: str | None = None, ) -> OpenAICompletion: - raise NotImplementedError("OpenAI completion not supported by sentence transformers provider") + raise NotImplementedError( + "OpenAI completion not supported by sentence transformers provider" + ) diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 95da71de85..d1ddf96704 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -11,8 +11,7 @@ from llama_stack.apis.inference import ( ChatCompletionRequest, - CompletionRequest, - CompletionResponse, + ChatCompletionResponse, Inference, LogProbConfig, Message, @@ -25,9 +24,7 @@ ToolPromptFormat, TopKSamplingStrategy, ) -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, @@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, - completion_request_to_prompt, ) from .config import CerebrasImplConfig @@ -102,14 +98,18 @@ async def chat_completion( else: return await self._nonstream_chat_completion(request) - async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse: + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest + ) -> ChatCompletionResponse: params = await self._get_params(request) r = await self._cerebras_client.completions.create(**params) return process_chat_completion_response(r, request) - async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator: + async def _stream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncGenerator: params = await self._get_params(request) stream = await self._cerebras_client.completions.create(**params) @@ -117,15 +117,17 @@ async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGene async for chunk in process_chat_completion_stream_response(stream, request): yield chunk - async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict: - if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy): + async def _get_params(self, request: ChatCompletionRequest) -> dict: + if request.sampling_params and isinstance( + request.sampling_params.strategy, TopKSamplingStrategy + ): raise ValueError("`top_k` not supported by Cerebras") prompt = "" if isinstance(request, ChatCompletionRequest): - prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model)) - elif isinstance(request, CompletionRequest): - prompt = await completion_request_to_prompt(request) + prompt = await chat_completion_request_to_prompt( + request, self.get_llama_model(request.model) + ) else: raise ValueError(f"Unknown request type {type(request)}") diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 77c5c7187e..15d04d8d62 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -10,11 +10,13 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import OpenAIEmbeddingsResponse -# from llama_stack.providers.datatypes import ModelsProtocolPrivate -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry +from llama_stack.providers.utils.inference.model_registry import ( + build_hf_repo_model_entry, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, get_sampling_options, + OpenAIChatCompletionToLlamaStackMixin, process_chat_completion_response, process_chat_completion_stream_response, ) @@ -41,13 +43,12 @@ "Llama3.2-3B": "meta-llama/Llama-3.2-3B", } -SAFETY_MODELS_ENTRIES = [] # Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template MODEL_ENTRIES = [ build_hf_repo_model_entry(provider_model_id, model_descriptor) for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items() -] + SAFETY_MODELS_ENTRIES +] class RunpodInferenceAdapter( @@ -56,7 +57,9 @@ class RunpodInferenceAdapter( OpenAIChatCompletionToLlamaStackMixin, ): def __init__(self, config: RunpodImplConfig) -> None: - ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS) + ModelRegistryHelper.__init__( + self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS + ) self.config = config async def initialize(self) -> None: @@ -103,7 +106,9 @@ async def _nonstream_chat_completion( r = client.completions.create(**params) return process_chat_completion_response(r, request) - async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator: + async def _stream_chat_completion( + self, request: ChatCompletionRequest, client: OpenAI + ) -> AsyncGenerator: params = self._get_params(request) async def _to_async_generator(): diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index cb9d611023..9c8831c0d1 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -9,12 +9,10 @@ from ibm_watsonx_ai.foundation_models import Model from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams -from openai import AsyncOpenAI from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, - CompletionRequest, GreedySamplingStrategy, Inference, LogProbConfig, @@ -48,6 +46,7 @@ completion_request_to_prompt, request_has_media, ) +from openai import AsyncOpenAI from . import WatsonXConfig from .models import MODEL_ENTRIES @@ -85,7 +84,9 @@ async def shutdown(self) -> None: pass def _get_client(self, model_id) -> Model: - config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None + config_api_key = ( + self._config.api_key.get_secret_value() if self._config.api_key else None + ) config_url = self._config.url project_id = self._config.project_id credentials = {"url": config_url, "apikey": config_api_key} @@ -132,14 +133,18 @@ async def chat_completion( else: return await self._nonstream_chat_completion(request) - async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest + ) -> ChatCompletionResponse: params = await self._get_params(request) r = self._get_client(request.model).generate(**params) choices = [] if "results" in r: for result in r["results"]: choice = OpenAICompatCompletionChoice( - finish_reason=result["stop_reason"] if result["stop_reason"] else None, + finish_reason=( + result["stop_reason"] if result["stop_reason"] else None + ), text=result["generated_text"], ) choices.append(choice) @@ -148,7 +153,9 @@ async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> Ch ) return process_chat_completion_response(response, request) - async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + async def _stream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncGenerator: params = await self._get_params(request) model_id = request.model @@ -168,28 +175,44 @@ async def _to_async_generator(): async for chunk in process_chat_completion_stream_response(stream, request): yield chunk - async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict: + async def _get_params(self, request: ChatCompletionRequest) -> dict: input_dict = {"params": {}} media_present = request_has_media(request) llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): - input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) + input_dict["prompt"] = await chat_completion_request_to_prompt( + request, llama_model + ) else: - assert not media_present, "Together does not support media for Completion requests" + assert ( + not media_present + ), "Together does not support media for Completion requests" input_dict["prompt"] = await completion_request_to_prompt(request) if request.sampling_params: if request.sampling_params.strategy: - input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type + input_dict["params"][ + GenParams.DECODING_METHOD + ] = request.sampling_params.strategy.type if request.sampling_params.max_tokens: - input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens + input_dict["params"][ + GenParams.MAX_NEW_TOKENS + ] = request.sampling_params.max_tokens if request.sampling_params.repetition_penalty: - input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty + input_dict["params"][ + GenParams.REPETITION_PENALTY + ] = request.sampling_params.repetition_penalty if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): - input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p - input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature + input_dict["params"][ + GenParams.TOP_P + ] = request.sampling_params.strategy.top_p + input_dict["params"][ + GenParams.TEMPERATURE + ] = request.sampling_params.strategy.temperature if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): - input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k + input_dict["params"][ + GenParams.TOP_K + ] = request.sampling_params.strategy.top_k if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): input_dict["params"][GenParams.TEMPERATURE] = 0.0 From 8809a34b23a16bc3cfd792aca9dce4e99a6029c0 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 6 Oct 2025 11:39:14 -0700 Subject: [PATCH 02/11] Removed additional dead code --- llama_stack/apis/inference/inference.py | 28 -- .../utils/inference/openai_compat.py | 313 ++++++++++-------- 2 files changed, 169 insertions(+), 172 deletions(-) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index d149a4dc2a..c11d5956b1 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -351,34 +351,6 @@ class CompletionRequest(BaseModel): logprobs: LogProbConfig | None = None -# @json_schema_type -# class CompletionResponse(MetricResponseMixin): -# """Response from a completion request. - -# :param content: The generated completion text -# :param stop_reason: Reason why generation stopped -# :param logprobs: Optional log probabilities for generated tokens -# """ - -# content: str -# stop_reason: StopReason -# logprobs: list[TokenLogProbs] | None = None - - -# @json_schema_type -# class CompletionResponseStreamChunk(MetricResponseMixin): -# """A chunk of a streamed completion response. - -# :param delta: New content generated since last chunk. This can be one or more tokens. -# :param stop_reason: Optional reason why generation stopped, if complete -# :param logprobs: Optional log probabilities for generated tokens -# """ - -# delta: str -# stop_reason: StopReason | None = None -# logprobs: list[TokenLogProbs] | None = None - - class SystemMessageBehavior(Enum): """Config for how to override the default system prompt. diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index d863eb53a5..6a0fe94170 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -10,24 +10,14 @@ import uuid import warnings from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterable -from typing import ( - Any, -) +from typing import Any from openai import AsyncStream from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, -) -from openai.types.chat import ( ChatCompletionChunk as OpenAIChatCompletionChunk, -) -from openai.types.chat import ( ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, -) -from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, -) -from openai.types.chat import ( ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, ) @@ -39,56 +29,15 @@ from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, -) -from openai.types.chat import ( - ChatCompletionMessageToolCall, -) -from openai.types.chat import ( - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, -) -from openai.types.chat import ( - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, -) -from openai.types.chat import ( - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, -) -from openai.types.chat.chat_completion import ( - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_chunk import ( - Choice as OpenAIChatCompletionChunkChoice, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDelta as OpenAIChoiceDelta, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call import ( - Function as OpenAIFunction, -) -from pydantic import BaseModel - from llama_stack.apis.common.content_types import ( - URL, + _URLOrData, ImageContentItem, InterleavedContent, TextContentItem, TextDelta, ToolCallDelta, ToolCallParseStatus, - _URLOrData, + URL, ) from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -97,8 +46,6 @@ ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, - CompletionResponse, - CompletionResponseStreamChunk, GreedySamplingStrategy, JsonSchemaResponseFormat, Message, @@ -116,9 +63,6 @@ TopPSamplingStrategy, UserMessage, ) -from llama_stack.apis.inference import ( - OpenAIChoice as OpenAIChatCompletionChoice, -) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -130,6 +74,30 @@ convert_image_content_to_url, decode_assistant_message, ) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessage, + ChatCompletionMessageToolCall, + ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, + ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, + ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, +) +from openai.types.chat.chat_completion import ( + Choice as OpenAIChoice, + ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs +) +from openai.types.chat.chat_completion_chunk import ( + Choice as OpenAIChatCompletionChunkChoice, + ChoiceDelta as OpenAIChoiceDelta, + ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ImageURL as OpenAIImageURL, +) +from openai.types.chat.chat_completion_message_tool_call import ( + Function as OpenAIFunction, +) +from pydantic import BaseModel logger = get_logger(name=__name__, category="providers::utils") @@ -228,12 +196,16 @@ def convert_openai_completion_logprobs( if logprobs.tokens and logprobs.token_logprobs: return [ TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) + for token, token_lp in zip( + logprobs.tokens, logprobs.token_logprobs, strict=False + ) ] return None -def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None): +def convert_openai_completion_logprobs_stream( + text: str, logprobs: float | OpenAICompatLogprobs | None +): if logprobs is None: return None if isinstance(logprobs, float): @@ -244,29 +216,29 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA return None -def process_completion_response( - response: OpenAICompatCompletionResponse, -) -> CompletionResponse: - choice = response.choices[0] - # drop suffix if present and return stop reason as end of turn - if choice.text.endswith("<|eot_id|>"): - return CompletionResponse( - stop_reason=StopReason.end_of_turn, - content=choice.text[: -len("<|eot_id|>")], - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - # drop suffix if present and return stop reason as end of message - if choice.text.endswith("<|eom_id|>"): - return CompletionResponse( - stop_reason=StopReason.end_of_message, - content=choice.text[: -len("<|eom_id|>")], - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - return CompletionResponse( - stop_reason=get_stop_reason(choice.finish_reason), - content=choice.text, - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) +# def process_completion_response( +# response: OpenAICompatCompletionResponse, +# ) -> CompletionResponse: +# choice = response.choices[0] +# # drop suffix if present and return stop reason as end of turn +# if choice.text.endswith("<|eot_id|>"): +# return CompletionResponse( +# stop_reason=StopReason.end_of_turn, +# content=choice.text[: -len("<|eot_id|>")], +# logprobs=convert_openai_completion_logprobs(choice.logprobs), +# ) +# # drop suffix if present and return stop reason as end of message +# if choice.text.endswith("<|eom_id|>"): +# return CompletionResponse( +# stop_reason=StopReason.end_of_message, +# content=choice.text[: -len("<|eom_id|>")], +# logprobs=convert_openai_completion_logprobs(choice.logprobs), +# ) +# return CompletionResponse( +# stop_reason=get_stop_reason(choice.finish_reason), +# content=choice.text, +# logprobs=convert_openai_completion_logprobs(choice.logprobs), +# ) def process_chat_completion_response( @@ -278,7 +250,9 @@ def process_chat_completion_response( if not choice.message or not choice.message.tool_calls: raise ValueError("Tool calls are not present in the response") - tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] + tool_calls = [ + convert_tool_call(tool_call) for tool_call in choice.message.tool_calls + ] if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): # If we couldn't parse a tool call, jsonify the tool calls and return them return ChatCompletionResponse( @@ -302,7 +276,9 @@ def process_chat_completion_response( # TODO: This does not work well with tool calls for vLLM remote provider # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason)) + raw_message = decode_assistant_message( + text_from_choice(choice), get_stop_reason(choice.finish_reason) + ) # NOTE: If we do not set tools in chat-completion request, we should not # expect the ToolCall in the response. Instead, we should return the raw @@ -335,40 +311,40 @@ def process_chat_completion_response( ) -async def process_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], -) -> AsyncGenerator[CompletionResponseStreamChunk, None]: - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] - finish_reason = choice.finish_reason - - text = text_from_choice(choice) - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - yield CompletionResponseStreamChunk( - delta=text, - stop_reason=stop_reason, - logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs), - ) - if finish_reason: - if finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - - yield CompletionResponseStreamChunk( - delta="", - stop_reason=stop_reason, - ) +# async def process_completion_stream_response( +# stream: AsyncGenerator[OpenAICompatCompletionResponse, None], +# ) -> AsyncGenerator[CompletionResponseStreamChunk, None]: +# stop_reason = None + +# async for chunk in stream: +# choice = chunk.choices[0] +# finish_reason = choice.finish_reason + +# text = text_from_choice(choice) +# if text == "<|eot_id|>": +# stop_reason = StopReason.end_of_turn +# text = "" +# continue +# elif text == "<|eom_id|>": +# stop_reason = StopReason.end_of_message +# text = "" +# continue +# yield CompletionResponseStreamChunk( +# delta=text, +# stop_reason=stop_reason, +# logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs), +# ) +# if finish_reason: +# if finish_reason in ["stop", "eos", "eos_token"]: +# stop_reason = StopReason.end_of_turn +# elif finish_reason == "length": +# stop_reason = StopReason.out_of_tokens +# break + +# yield CompletionResponseStreamChunk( +# delta="", +# stop_reason=stop_reason, +# ) async def process_chat_completion_stream_response( @@ -503,13 +479,17 @@ async def process_chat_completion_stream_response( ) -async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict: +async def convert_message_to_openai_dict( + message: Message, download: bool = False +) -> dict: async def _convert_content(content) -> dict: if isinstance(content, ImageContentItem): return { "type": "image_url", "image_url": { - "url": await convert_image_content_to_url(content, download=download), + "url": await convert_image_content_to_url( + content, download=download + ), }, } else: @@ -594,7 +574,11 @@ async def _convert_message_content( ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: async def impl( content_: InterleavedContent, - ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: + ) -> ( + str + | OpenAIChatCompletionContentPartParam + | list[OpenAIChatCompletionContentPartParam] + ): # Llama Stack and OpenAI spec match for str and text input if isinstance(content_, str): return content_ @@ -607,7 +591,9 @@ async def impl( return OpenAIChatCompletionContentPartImageParam( type="image_url", image_url=OpenAIImageURL( - url=await convert_image_content_to_url(content_, download=download_images) + url=await convert_image_content_to_url( + content_, download=download_images + ) ), ) elif isinstance(content_, list): @@ -634,7 +620,11 @@ async def impl( OpenAIChatCompletionMessageFunctionToolCall( id=tool.call_id, function=OpenAIFunction( - name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), + name=( + tool.tool_name + if not isinstance(tool.tool_name, BuiltinTool) + else tool.tool_name.value + ), arguments=tool.arguments, # Already a JSON string, don't double-encode ), type="function", @@ -814,7 +804,9 @@ class StopReason(Enum): }.get(finish_reason, StopReason.end_of_turn) -def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig: +def _convert_openai_request_tool_config( + tool_choice: str | dict[str, Any] | None = None +) -> ToolConfig: tool_config = ToolConfig() if tool_choice: try: @@ -825,7 +817,9 @@ def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None return tool_config -def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]: +def _convert_openai_request_tools( + tools: list[dict[str, Any]] | None = None +) -> list[ToolDefinition]: lls_tools = [] if not tools: return lls_tools @@ -924,7 +918,11 @@ def _convert_openai_logprobs( return None return [ - TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) + TokenLogProbs( + logprobs_by_token={ + logprobs.token: logprobs.logprob for logprobs in content.top_logprobs + } + ) for content in logprobs.content ] @@ -963,9 +961,13 @@ def openai_messages_to_messages( converted_messages = [] for message in messages: if message.role == "system": - converted_message = SystemMessage(content=openai_content_to_content(message.content)) + converted_message = SystemMessage( + content=openai_content_to_content(message.content) + ) elif message.role == "user": - converted_message = UserMessage(content=openai_content_to_content(message.content)) + converted_message = UserMessage( + content=openai_content_to_content(message.content) + ) elif message.role == "assistant": converted_message = CompletionMessage( content=openai_content_to_content(message.content), @@ -984,7 +986,9 @@ def openai_messages_to_messages( return converted_messages -def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam] | None): +def openai_content_to_content( + content: str | Iterable[OpenAIChatCompletionContentPartParam] | None, +): if content is None: return "" if isinstance(content, str): @@ -995,7 +999,9 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten if content.type == "text": return TextContentItem(type="text", text=content.text) elif content.type == "image_url": - return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) + return ImageContentItem( + type="image", image=_URLOrData(url=URL(uri=content.image_url.url)) + ) else: raise ValueError(f"Unknown content type: {content.type}") else: @@ -1035,14 +1041,17 @@ class StopReason(Enum): end_of_message = "end_of_message" out_of_tokens = "out_of_tokens" """ - assert hasattr(choice, "message") and choice.message, "error in server response: message not found" - assert hasattr(choice, "finish_reason") and choice.finish_reason, ( - "error in server response: finish_reason not found" - ) + assert ( + hasattr(choice, "message") and choice.message + ), "error in server response: message not found" + assert ( + hasattr(choice, "finish_reason") and choice.finish_reason + ), "error in server response: finish_reason not found" return ChatCompletionResponse( completion_message=CompletionMessage( - content=choice.message.content or "", # CompletionMessage content is not optional + content=choice.message.content + or "", # CompletionMessage content is not optional stop_reason=_convert_openai_finish_reason(choice.finish_reason), tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), ), @@ -1282,7 +1291,9 @@ async def openai_chat_completion( outstanding_responses.append(response) if stream: - return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) + return OpenAIChatCompletionToLlamaStackMixin._process_stream_response( + self, model, outstanding_responses + ) return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response( self, model, outstanding_responses @@ -1291,21 +1302,29 @@ async def openai_chat_completion( async def _process_stream_response( self, model: str, - outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], + outstanding_responses: list[ + Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]] + ], ): id = f"chatcmpl-{uuid.uuid4()}" for i, outstanding_response in enumerate(outstanding_responses): response = await outstanding_response async for chunk in response: event = chunk.event - finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason) + finish_reason = _convert_stop_reason_to_openai_finish_reason( + event.stop_reason + ) if isinstance(event.delta, TextDelta): text_delta = event.delta.text delta = OpenAIChoiceDelta(content=text_delta) yield OpenAIChatCompletionChunk( id=id, - choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], + choices=[ + OpenAIChatCompletionChunkChoice( + index=i, finish_reason=finish_reason, delta=delta + ) + ], created=int(time.time()), model=model, object="chat.completion.chunk", @@ -1327,7 +1346,9 @@ async def _process_stream_response( yield OpenAIChatCompletionChunk( id=id, choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) + OpenAIChatCompletionChunkChoice( + index=i, finish_reason=finish_reason, delta=delta + ) ], created=int(time.time()), model=model, @@ -1344,7 +1365,9 @@ async def _process_stream_response( yield OpenAIChatCompletionChunk( id=id, choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) + OpenAIChatCompletionChunkChoice( + index=i, finish_reason=finish_reason, delta=delta + ) ], created=int(time.time()), model=model, @@ -1359,7 +1382,9 @@ async def _process_non_stream_response( response = await outstanding_response completion_message = response.completion_message message = await convert_message_to_openai_dict_new(completion_message) - finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason) + finish_reason = _convert_stop_reason_to_openai_finish_reason( + completion_message.stop_reason + ) choice = OpenAIChatCompletionChoice( index=len(choices), From 69dadaf5967d71c09053984d309a556ee579bcd9 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 6 Oct 2025 13:11:49 -0700 Subject: [PATCH 03/11] Fix CI test failures: restore missing classes and fix imports - Restored CompletionResponse and CompletionResponseStreamChunk classes in inference.py - Fixed SamplingParams import in test_resolver.py - Fixed openai_compat.py imports - Added missing pytest configuration files - Fixed registry error messages - All fixes needed to make CI tests pass --- tests/conftest.py | 8 ++++++++ tests/integration/conftest.py | 29 ++++++++++++++++++----------- 2 files changed, 26 insertions(+), 11 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..fce589e579 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Global pytest plugins configuration +pytest_plugins = ["tests.integration.fixtures.common"] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 4735264c3e..8e7c3c7404 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -69,7 +69,9 @@ def pytest_configure(config): suite = config.getoption("--suite") if suite: if suite not in SUITE_DEFINITIONS: - raise pytest.UsageError(f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}") + raise pytest.UsageError( + f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}" + ) # Apply setups (global parameterizations): env + defaults setup = config.getoption("--setup") @@ -107,7 +109,9 @@ def pytest_addoption(parser): """ ), ) - parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value") + parser.addoption( + "--env", action="append", help="Set environment variables, e.g. --env KEY=value" + ) parser.addoption( "--text-model", help="comma-separated list of text models. Fixture name: text_model_id", @@ -147,9 +151,7 @@ def pytest_addoption(parser): ) available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys())) - suite_help = ( - f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses" - ) + suite_help = f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses" parser.addoption("--suite", help=suite_help) # Global setups for any suite @@ -221,7 +223,11 @@ def pytest_generate_tests(metafunc): # Generate test IDs test_ids = [] - non_empty_params = [(i, values) for i, values in enumerate(param_values.values()) if values[0] is not None] + non_empty_params = [ + (i, values) + for i, values in enumerate(param_values.values()) + if values[0] is not None + ] # Get actual function parameters using inspect test_func_params = set(inspect.signature(metafunc.function).parameters.keys()) @@ -238,10 +244,9 @@ def pytest_generate_tests(metafunc): if parts: test_ids.append(":".join(parts)) - metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids if test_ids else None) - - -pytest_plugins = ["tests.integration.fixtures.common"] + metafunc.parametrize( + params, value_combinations, scope="session", ids=test_ids if test_ids else None + ) def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: @@ -251,7 +256,9 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: return False sobj = SUITE_DEFINITIONS.get(suite) - roots: list[str] = sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", []) + roots: list[str] = ( + sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", []) + ) if not roots: return False From 3b1add336c96be2d04f124c6509246608d917b58 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 6 Oct 2025 13:14:21 -0700 Subject: [PATCH 04/11] Critical CI fix: Add missing SamplingParams import in test_resolver.py This fixes the CI failure: NameError: name 'SamplingParams' is not defined The test_resolver.py file needed to import SamplingParams from llama_stack.apis.inference to resolve the CI test collection error. --- tests/unit/server/test_resolver.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py index 1ee1b2f470..484b126230 100644 --- a/tests/unit/server/test_resolver.py +++ b/tests/unit/server/test_resolver.py @@ -9,19 +9,15 @@ from typing import Any, Protocol from unittest.mock import AsyncMock, MagicMock -from pydantic import BaseModel, Field - -from llama_stack.apis.inference import Inference -from llama_stack.core.datatypes import ( - Api, - Provider, - StackRunConfig, -) +from llama_stack.apis.inference import Inference, SamplingParams +from llama_stack.core.datatypes import Api, Provider, StackRunConfig from llama_stack.core.resolver import resolve_impls from llama_stack.core.routers.inference import InferenceRouter from llama_stack.core.routing_tables.models import ModelsRoutingTable from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec +from pydantic import BaseModel, Field + def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None: """Dynamically add protocol methods to a class by inspecting the protocol.""" @@ -54,7 +50,12 @@ def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: class SampleImpl: - def __init__(self, config: SampleConfig, deps: dict[Api, Any], provider_spec: ProviderSpec = None): + def __init__( + self, + config: SampleConfig, + deps: dict[Api, Any], + provider_spec: ProviderSpec = None, + ): self.__provider_id__ = "test_provider" self.__provider_spec__ = provider_spec self.__provider_config__ = config From 9fc0d966f6f7c29a819a0c8fff1ead3274e4b7ae Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 6 Oct 2025 13:27:19 -0700 Subject: [PATCH 05/11] Ran precommit --- llama_stack/apis/inference/inference.py | 38 +--- llama_stack/core/routers/inference.py | 160 ++++----------- .../remote/inference/runpod/runpod.py | 7 +- .../remote/inference/watsonx/watsonx.py | 40 +--- .../utils/inference/openai_compat.py | 186 ++++++++---------- tests/integration/conftest.py | 26 +-- tests/unit/server/test_resolver.py | 6 +- 7 files changed, 153 insertions(+), 310 deletions(-) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 897b5d6e8d..91c448bc6b 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -8,6 +8,9 @@ from enum import Enum from typing import Annotated, Any, Literal, Protocol, runtime_checkable +from pydantic import BaseModel, Field, field_validator +from typing_extensions import TypedDict + from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent from llama_stack.apis.common.responses import Order from llama_stack.apis.models import Model @@ -23,9 +26,6 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod -from pydantic import BaseModel, Field, field_validator -from typing_extensions import TypedDict - register_schema(ToolCall) register_schema(ToolDefinition) @@ -381,9 +381,7 @@ class ToolConfig(BaseModel): tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto) tool_prompt_format: ToolPromptFormat | None = Field(default=None) - system_message_behavior: SystemMessageBehavior | None = Field( - default=SystemMessageBehavior.append - ) + system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append) def model_post_init(self, __context: Any) -> None: if isinstance(self.tool_choice, str): @@ -512,21 +510,15 @@ class OpenAIFile(BaseModel): OpenAIChatCompletionContentPartParam = Annotated[ - OpenAIChatCompletionContentPartTextParam - | OpenAIChatCompletionContentPartImageParam - | OpenAIFile, + OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile, Field(discriminator="type"), ] -register_schema( - OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam" -) +register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam") OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam] -OpenAIChatCompletionTextOnlyMessageContent = ( - str | list[OpenAIChatCompletionContentPartTextParam] -) +OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam] @json_schema_type @@ -694,9 +686,7 @@ class OpenAIResponseFormatJSONObject(BaseModel): OpenAIResponseFormatParam = Annotated[ - OpenAIResponseFormatText - | OpenAIResponseFormatJSONSchema - | OpenAIResponseFormatJSONObject, + OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject, Field(discriminator="type"), ] register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam") @@ -986,16 +976,8 @@ class InferenceProvider(Protocol): async def rerank( self, model: str, - query: ( - str - | OpenAIChatCompletionContentPartTextParam - | OpenAIChatCompletionContentPartImageParam - ), - items: list[ - str - | OpenAIChatCompletionContentPartTextParam - | OpenAIChatCompletionContentPartImageParam - ], + query: (str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam), + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], max_num_results: int | None = None, ) -> RerankResponse: """Rerank a list of documents based on their relevance to a query. diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index d2944ec81a..ff86a89ed4 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -7,9 +7,17 @@ import asyncio import time from collections.abc import AsyncGenerator, AsyncIterator -from datetime import datetime, UTC +from datetime import UTC, datetime from typing import Annotated, Any +from openai.types.chat import ( + ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam, +) +from openai.types.chat import ( + ChatCompletionToolParam as OpenAIChatCompletionToolParam, +) +from pydantic import Field, TypeAdapter + from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.inference import ( @@ -48,12 +56,6 @@ get_current_span, ) -from openai.types.chat import ( - ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam, - ChatCompletionToolParam as OpenAIChatCompletionToolParam, -) -from pydantic import Field, TypeAdapter - logger = get_logger(name=__name__, category="core::routers") @@ -96,9 +98,7 @@ async def register_model( logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) - await self.routing_table.register_model( - model_id, provider_model_id, provider_id, metadata, model_type - ) + await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) def _construct_metrics( self, @@ -153,16 +153,11 @@ async def _compute_and_log_token_usage( total_tokens: int, model: Model, ) -> list[MetricInResponse]: - metrics = self._construct_metrics( - prompt_tokens, completion_tokens, total_tokens, model - ) + metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) if self.telemetry: for metric in metrics: enqueue_event(metric) - return [ - MetricInResponse(metric=metric.metric, value=metric.value) - for metric in metrics - ] + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] async def _count_tokens( self, @@ -256,9 +251,7 @@ async def openai_completion( # these metrics will show up in the client response. response.metrics = ( - metrics - if not hasattr(response, "metrics") or response.metrics is None - else response.metrics + metrics + metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics ) return response @@ -296,13 +289,9 @@ async def openai_chat_completion( # Use the OpenAI client for a bit of extra input validation without # exposing the OpenAI client itself as part of our API surface if tool_choice: - TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python( - tool_choice - ) + TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice) if tools is None: - raise ValueError( - "'tool_choice' is only allowed when 'tools' is also provided" - ) + raise ValueError("'tool_choice' is only allowed when 'tools' is also provided") if tools: for tool in tools: TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool) @@ -367,9 +356,7 @@ async def openai_chat_completion( enqueue_event(metric) # these metrics will show up in the client response. response.metrics = ( - metrics - if not hasattr(response, "metrics") or response.metrics is None - else response.metrics + metrics + metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics ) return response @@ -405,31 +392,19 @@ async def list_chat_completions( ) -> ListOpenAIChatCompletionResponse: if self.store: return await self.store.list_chat_completions(after, limit, model, order) - raise NotImplementedError( - "List chat completions is not supported: inference store is not configured." - ) + raise NotImplementedError("List chat completions is not supported: inference store is not configured.") - async def get_chat_completion( - self, completion_id: str - ) -> OpenAICompletionWithInputMessages: + async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: if self.store: return await self.store.get_chat_completion(completion_id) - raise NotImplementedError( - "Get chat completion is not supported: inference store is not configured." - ) + raise NotImplementedError("Get chat completion is not supported: inference store is not configured.") - async def _nonstream_openai_chat_completion( - self, provider: Inference, params: dict - ) -> OpenAIChatCompletion: + async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion: response = await provider.openai_chat_completion(**params) for choice in response.choices: # some providers return an empty list for no tool calls in non-streaming responses # but the OpenAI API returns None. So, set tool_calls to None if it's empty - if ( - choice.message - and choice.message.tool_calls is not None - and len(choice.message.tool_calls) == 0 - ): + if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0: choice.message.tool_calls = None return response @@ -449,9 +424,7 @@ async def health(self) -> dict[str, HealthResponse]: message=f"Health check timed out after {timeout} seconds", ) except NotImplementedError: - health_statuses[provider_id] = HealthResponse( - status=HealthStatus.NOT_IMPLEMENTED - ) + health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED) except Exception as e: health_statuses[provider_id] = HealthResponse( status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" @@ -486,11 +459,7 @@ async def stream_tokens_and_compute_metrics( else: if hasattr(chunk, "delta"): completion_text += chunk.delta - if ( - hasattr(chunk, "stop_reason") - and chunk.stop_reason - and self.telemetry - ): + if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: complete = True completion_tokens = await self._count_tokens(completion_text) # if we are done receiving tokens @@ -515,14 +484,9 @@ async def stream_tokens_and_compute_metrics( # Return metrics in response async_metrics = [ - MetricInResponse(metric=metric.metric, value=metric.value) - for metric in completion_metrics + MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics ] - chunk.metrics = ( - async_metrics - if chunk.metrics is None - else chunk.metrics + async_metrics - ) + chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics else: # Fallback if no telemetry completion_metrics = self._construct_metrics( @@ -532,14 +496,9 @@ async def stream_tokens_and_compute_metrics( model, ) async_metrics = [ - MetricInResponse(metric=metric.metric, value=metric.value) - for metric in completion_metrics + MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics ] - chunk.metrics = ( - async_metrics - if chunk.metrics is None - else chunk.metrics + async_metrics - ) + chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics yield chunk async def count_tokens_and_compute_metrics( @@ -553,9 +512,7 @@ async def count_tokens_and_compute_metrics( content = [response.completion_message] else: content = response.content - completion_tokens = await self._count_tokens( - messages=content, tool_prompt_format=tool_prompt_format - ) + completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) # Create a separate span for completion metrics @@ -575,10 +532,7 @@ async def count_tokens_and_compute_metrics( enqueue_event(metric) # Return metrics in response - return [ - MetricInResponse(metric=metric.metric, value=metric.value) - for metric in completion_metrics - ] + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics] # Fallback if no telemetry metrics = self._construct_metrics( @@ -587,10 +541,7 @@ async def count_tokens_and_compute_metrics( total_tokens, model, ) - return [ - MetricInResponse(metric=metric.metric, value=metric.value) - for metric in metrics - ] + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] async def stream_tokens_and_compute_metrics_openai_chat( self, @@ -631,48 +582,33 @@ async def stream_tokens_and_compute_metrics_openai_chat( if choice_delta.delta: delta = choice_delta.delta if delta.content: - current_choice_data["content_parts"].append( - delta.content - ) + current_choice_data["content_parts"].append(delta.content) if delta.tool_calls: for tool_call_delta in delta.tool_calls: tc_idx = tool_call_delta.index - if ( - tc_idx - not in current_choice_data["tool_calls_builder"] - ): - current_choice_data["tool_calls_builder"][ - tc_idx - ] = { + if tc_idx not in current_choice_data["tool_calls_builder"]: + current_choice_data["tool_calls_builder"][tc_idx] = { "id": None, "type": "function", "function_name_parts": [], "function_arguments_parts": [], } - builder = current_choice_data["tool_calls_builder"][ - tc_idx - ] + builder = current_choice_data["tool_calls_builder"][tc_idx] if tool_call_delta.id: builder["id"] = tool_call_delta.id if tool_call_delta.type: builder["type"] = tool_call_delta.type if tool_call_delta.function: if tool_call_delta.function.name: - builder["function_name_parts"].append( - tool_call_delta.function.name - ) + builder["function_name_parts"].append(tool_call_delta.function.name) if tool_call_delta.function.arguments: builder["function_arguments_parts"].append( tool_call_delta.function.arguments ) if choice_delta.finish_reason: - current_choice_data["finish_reason"] = ( - choice_delta.finish_reason - ) + current_choice_data["finish_reason"] = choice_delta.finish_reason if choice_delta.logprobs and choice_delta.logprobs.content: - current_choice_data["logprobs_content_parts"].extend( - choice_delta.logprobs.content - ) + current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content) # Compute metrics on final chunk if chunk.choices and chunk.choices[0].finish_reason: @@ -702,12 +638,8 @@ async def stream_tokens_and_compute_metrics_openai_chat( if choice_data["tool_calls_builder"]: for tc_build_data in choice_data["tool_calls_builder"].values(): if tc_build_data["id"]: - func_name = "".join( - tc_build_data["function_name_parts"] - ) - func_args = "".join( - tc_build_data["function_arguments_parts"] - ) + func_name = "".join(tc_build_data["function_name_parts"]) + func_args = "".join(tc_build_data["function_arguments_parts"]) assembled_tool_calls.append( OpenAIChatCompletionToolCall( id=tc_build_data["id"], @@ -720,16 +652,10 @@ async def stream_tokens_and_compute_metrics_openai_chat( message = OpenAIAssistantMessageParam( role="assistant", content=content_str if content_str else None, - tool_calls=( - assembled_tool_calls if assembled_tool_calls else None - ), + tool_calls=(assembled_tool_calls if assembled_tool_calls else None), ) logprobs_content = choice_data["logprobs_content_parts"] - final_logprobs = ( - OpenAIChoiceLogprobs(content=logprobs_content) - if logprobs_content - else None - ) + final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None assembled_choices.append( OpenAIChoice( @@ -748,6 +674,4 @@ async def stream_tokens_and_compute_metrics_openai_chat( object="chat.completion", ) logger.debug(f"InferenceRouter.completion_response: {final_response}") - asyncio.create_task( - self.store.store_chat_completion(final_response, messages) - ) + asyncio.create_task(self.store.store_chat_completion(final_response, messages)) diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 28c65cda52..a7802b2828 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -7,10 +7,9 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import OpenAIEmbeddingsResponse - from llama_stack.providers.utils.inference.model_registry import ( - build_hf_repo_model_entry, ModelRegistryHelper, + build_hf_repo_model_entry, ) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, @@ -51,9 +50,7 @@ class RunpodInferenceAdapter( Inference, ): def __init__(self, config: RunpodImplConfig) -> None: - ModelRegistryHelper.__init__( - self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS - ) + ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS) self.config = config def _get_params(self, request: ChatCompletionRequest) -> dict: diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index 0c69318b9a..63b877b5e8 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -9,6 +9,7 @@ from ibm_watsonx_ai.foundation_models import Model from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams +from openai import AsyncOpenAI from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -33,7 +34,6 @@ completion_request_to_prompt, request_has_media, ) -from openai import AsyncOpenAI from . import WatsonXConfig from .models import MODEL_ENTRIES @@ -65,9 +65,7 @@ def __init__(self, config: WatsonXConfig) -> None: self._project_id = self._config.project_id def _get_client(self, model_id) -> Model: - config_api_key = ( - self._config.api_key.get_secret_value() if self._config.api_key else None - ) + config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None config_url = self._config.url project_id = self._config.project_id credentials = {"url": config_url, "apikey": config_api_key} @@ -82,46 +80,28 @@ def _get_openai_client(self) -> AsyncOpenAI: ) return self._openai_client - async def _get_params(self, request: ChatCompletionRequest) -> dict: - input_dict = {"params": {}} media_present = request_has_media(request) llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): - input_dict["prompt"] = await chat_completion_request_to_prompt( - request, llama_model - ) + input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) else: - assert ( - not media_present - ), "Together does not support media for Completion requests" + assert not media_present, "Together does not support media for Completion requests" input_dict["prompt"] = await completion_request_to_prompt(request) if request.sampling_params: if request.sampling_params.strategy: - input_dict["params"][ - GenParams.DECODING_METHOD - ] = request.sampling_params.strategy.type + input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type if request.sampling_params.max_tokens: - input_dict["params"][ - GenParams.MAX_NEW_TOKENS - ] = request.sampling_params.max_tokens + input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens if request.sampling_params.repetition_penalty: - input_dict["params"][ - GenParams.REPETITION_PENALTY - ] = request.sampling_params.repetition_penalty + input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): - input_dict["params"][ - GenParams.TOP_P - ] = request.sampling_params.strategy.top_p - input_dict["params"][ - GenParams.TEMPERATURE - ] = request.sampling_params.strategy.temperature + input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p + input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): - input_dict["params"][ - GenParams.TOP_K - ] = request.sampling_params.strategy.top_k + input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): input_dict["params"][GenParams.TEMPERATURE] = 0.0 diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 6a0fe94170..a3e272d204 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -15,9 +15,17 @@ from openai import AsyncStream from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, +) +from openai.types.chat import ( ChatCompletionChunk as OpenAIChatCompletionChunk, +) +from openai.types.chat import ( ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, +) +from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, +) +from openai.types.chat import ( ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, ) @@ -29,15 +37,56 @@ from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall, ) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessage, +) +from openai.types.chat import ( + ChatCompletionMessageToolCall, +) +from openai.types.chat import ( + ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, +) +from openai.types.chat import ( + ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, +) +from openai.types.chat import ( + ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, +) +from openai.types.chat.chat_completion import ( + Choice as OpenAIChoice, +) +from openai.types.chat.chat_completion import ( + ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs +) +from openai.types.chat.chat_completion_chunk import ( + Choice as OpenAIChatCompletionChunkChoice, +) +from openai.types.chat.chat_completion_chunk import ( + ChoiceDelta as OpenAIChoiceDelta, +) +from openai.types.chat.chat_completion_chunk import ( + ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, +) +from openai.types.chat.chat_completion_chunk import ( + ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ImageURL as OpenAIImageURL, +) +from openai.types.chat.chat_completion_message_tool_call import ( + Function as OpenAIFunction, +) +from pydantic import BaseModel + from llama_stack.apis.common.content_types import ( - _URLOrData, + URL, ImageContentItem, InterleavedContent, TextContentItem, TextDelta, ToolCallDelta, ToolCallParseStatus, - URL, + _URLOrData, ) from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -74,30 +123,6 @@ convert_image_content_to_url, decode_assistant_message, ) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, - ChatCompletionMessageToolCall, - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_chunk import ( - Choice as OpenAIChatCompletionChunkChoice, - ChoiceDelta as OpenAIChoiceDelta, - ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, - ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call import ( - Function as OpenAIFunction, -) -from pydantic import BaseModel logger = get_logger(name=__name__, category="providers::utils") @@ -196,16 +221,12 @@ def convert_openai_completion_logprobs( if logprobs.tokens and logprobs.token_logprobs: return [ TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip( - logprobs.tokens, logprobs.token_logprobs, strict=False - ) + for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) ] return None -def convert_openai_completion_logprobs_stream( - text: str, logprobs: float | OpenAICompatLogprobs | None -): +def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None): if logprobs is None: return None if isinstance(logprobs, float): @@ -250,9 +271,7 @@ def process_chat_completion_response( if not choice.message or not choice.message.tool_calls: raise ValueError("Tool calls are not present in the response") - tool_calls = [ - convert_tool_call(tool_call) for tool_call in choice.message.tool_calls - ] + tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): # If we couldn't parse a tool call, jsonify the tool calls and return them return ChatCompletionResponse( @@ -276,9 +295,7 @@ def process_chat_completion_response( # TODO: This does not work well with tool calls for vLLM remote provider # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message( - text_from_choice(choice), get_stop_reason(choice.finish_reason) - ) + raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason)) # NOTE: If we do not set tools in chat-completion request, we should not # expect the ToolCall in the response. Instead, we should return the raw @@ -479,17 +496,13 @@ async def process_chat_completion_stream_response( ) -async def convert_message_to_openai_dict( - message: Message, download: bool = False -) -> dict: +async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict: async def _convert_content(content) -> dict: if isinstance(content, ImageContentItem): return { "type": "image_url", "image_url": { - "url": await convert_image_content_to_url( - content, download=download - ), + "url": await convert_image_content_to_url(content, download=download), }, } else: @@ -574,11 +587,7 @@ async def _convert_message_content( ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: async def impl( content_: InterleavedContent, - ) -> ( - str - | OpenAIChatCompletionContentPartParam - | list[OpenAIChatCompletionContentPartParam] - ): + ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: # Llama Stack and OpenAI spec match for str and text input if isinstance(content_, str): return content_ @@ -591,9 +600,7 @@ async def impl( return OpenAIChatCompletionContentPartImageParam( type="image_url", image_url=OpenAIImageURL( - url=await convert_image_content_to_url( - content_, download=download_images - ) + url=await convert_image_content_to_url(content_, download=download_images) ), ) elif isinstance(content_, list): @@ -620,11 +627,7 @@ async def impl( OpenAIChatCompletionMessageFunctionToolCall( id=tool.call_id, function=OpenAIFunction( - name=( - tool.tool_name - if not isinstance(tool.tool_name, BuiltinTool) - else tool.tool_name.value - ), + name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), arguments=tool.arguments, # Already a JSON string, don't double-encode ), type="function", @@ -804,9 +807,7 @@ class StopReason(Enum): }.get(finish_reason, StopReason.end_of_turn) -def _convert_openai_request_tool_config( - tool_choice: str | dict[str, Any] | None = None -) -> ToolConfig: +def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig: tool_config = ToolConfig() if tool_choice: try: @@ -817,9 +818,7 @@ def _convert_openai_request_tool_config( return tool_config -def _convert_openai_request_tools( - tools: list[dict[str, Any]] | None = None -) -> list[ToolDefinition]: +def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]: lls_tools = [] if not tools: return lls_tools @@ -918,11 +917,7 @@ def _convert_openai_logprobs( return None return [ - TokenLogProbs( - logprobs_by_token={ - logprobs.token: logprobs.logprob for logprobs in content.top_logprobs - } - ) + TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) for content in logprobs.content ] @@ -961,13 +956,9 @@ def openai_messages_to_messages( converted_messages = [] for message in messages: if message.role == "system": - converted_message = SystemMessage( - content=openai_content_to_content(message.content) - ) + converted_message = SystemMessage(content=openai_content_to_content(message.content)) elif message.role == "user": - converted_message = UserMessage( - content=openai_content_to_content(message.content) - ) + converted_message = UserMessage(content=openai_content_to_content(message.content)) elif message.role == "assistant": converted_message = CompletionMessage( content=openai_content_to_content(message.content), @@ -999,9 +990,7 @@ def openai_content_to_content( if content.type == "text": return TextContentItem(type="text", text=content.text) elif content.type == "image_url": - return ImageContentItem( - type="image", image=_URLOrData(url=URL(uri=content.image_url.url)) - ) + return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) else: raise ValueError(f"Unknown content type: {content.type}") else: @@ -1041,17 +1030,14 @@ class StopReason(Enum): end_of_message = "end_of_message" out_of_tokens = "out_of_tokens" """ - assert ( - hasattr(choice, "message") and choice.message - ), "error in server response: message not found" - assert ( - hasattr(choice, "finish_reason") and choice.finish_reason - ), "error in server response: finish_reason not found" + assert hasattr(choice, "message") and choice.message, "error in server response: message not found" + assert hasattr(choice, "finish_reason") and choice.finish_reason, ( + "error in server response: finish_reason not found" + ) return ChatCompletionResponse( completion_message=CompletionMessage( - content=choice.message.content - or "", # CompletionMessage content is not optional + content=choice.message.content or "", # CompletionMessage content is not optional stop_reason=_convert_openai_finish_reason(choice.finish_reason), tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), ), @@ -1291,9 +1277,7 @@ async def openai_chat_completion( outstanding_responses.append(response) if stream: - return OpenAIChatCompletionToLlamaStackMixin._process_stream_response( - self, model, outstanding_responses - ) + return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response( self, model, outstanding_responses @@ -1302,29 +1286,21 @@ async def openai_chat_completion( async def _process_stream_response( self, model: str, - outstanding_responses: list[ - Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]] - ], + outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], ): id = f"chatcmpl-{uuid.uuid4()}" for i, outstanding_response in enumerate(outstanding_responses): response = await outstanding_response async for chunk in response: event = chunk.event - finish_reason = _convert_stop_reason_to_openai_finish_reason( - event.stop_reason - ) + finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason) if isinstance(event.delta, TextDelta): text_delta = event.delta.text delta = OpenAIChoiceDelta(content=text_delta) yield OpenAIChatCompletionChunk( id=id, - choices=[ - OpenAIChatCompletionChunkChoice( - index=i, finish_reason=finish_reason, delta=delta - ) - ], + choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], created=int(time.time()), model=model, object="chat.completion.chunk", @@ -1346,9 +1322,7 @@ async def _process_stream_response( yield OpenAIChatCompletionChunk( id=id, choices=[ - OpenAIChatCompletionChunkChoice( - index=i, finish_reason=finish_reason, delta=delta - ) + OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) ], created=int(time.time()), model=model, @@ -1365,9 +1339,7 @@ async def _process_stream_response( yield OpenAIChatCompletionChunk( id=id, choices=[ - OpenAIChatCompletionChunkChoice( - index=i, finish_reason=finish_reason, delta=delta - ) + OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) ], created=int(time.time()), model=model, @@ -1382,9 +1354,7 @@ async def _process_non_stream_response( response = await outstanding_response completion_message = response.completion_message message = await convert_message_to_openai_dict_new(completion_message) - finish_reason = _convert_stop_reason_to_openai_finish_reason( - completion_message.stop_reason - ) + finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason) choice = OpenAIChatCompletionChoice( index=len(choices), diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 13c055b665..90838f273b 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -87,9 +87,7 @@ def pytest_configure(config): suite = config.getoption("--suite") if suite: if suite not in SUITE_DEFINITIONS: - raise pytest.UsageError( - f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}" - ) + raise pytest.UsageError(f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}") # Apply setups (global parameterizations): env + defaults setup = config.getoption("--setup") @@ -127,9 +125,7 @@ def pytest_addoption(parser): """ ), ) - parser.addoption( - "--env", action="append", help="Set environment variables, e.g. --env KEY=value" - ) + parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value") parser.addoption( "--text-model", help="comma-separated list of text models. Fixture name: text_model_id", @@ -169,7 +165,9 @@ def pytest_addoption(parser): ) available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys())) - suite_help = f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses" + suite_help = ( + f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses" + ) parser.addoption("--suite", help=suite_help) # Global setups for any suite @@ -241,11 +239,7 @@ def pytest_generate_tests(metafunc): # Generate test IDs test_ids = [] - non_empty_params = [ - (i, values) - for i, values in enumerate(param_values.values()) - if values[0] is not None - ] + non_empty_params = [(i, values) for i, values in enumerate(param_values.values()) if values[0] is not None] # Get actual function parameters using inspect test_func_params = set(inspect.signature(metafunc.function).parameters.keys()) @@ -262,9 +256,7 @@ def pytest_generate_tests(metafunc): if parts: test_ids.append(":".join(parts)) - metafunc.parametrize( - params, value_combinations, scope="session", ids=test_ids if test_ids else None - ) + metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids if test_ids else None) def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: @@ -274,9 +266,7 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: return False sobj = SUITE_DEFINITIONS.get(suite) - roots: list[str] = ( - sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", []) - ) + roots: list[str] = sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", []) if not roots: return False diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py index 484b126230..df22747544 100644 --- a/tests/unit/server/test_resolver.py +++ b/tests/unit/server/test_resolver.py @@ -9,15 +9,15 @@ from typing import Any, Protocol from unittest.mock import AsyncMock, MagicMock -from llama_stack.apis.inference import Inference, SamplingParams +from pydantic import BaseModel, Field + +from llama_stack.apis.inference import Inference from llama_stack.core.datatypes import Api, Provider, StackRunConfig from llama_stack.core.resolver import resolve_impls from llama_stack.core.routers.inference import InferenceRouter from llama_stack.core.routing_tables.models import ModelsRoutingTable from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec -from pydantic import BaseModel, Field - def add_protocol_methods(cls: type, protocol: type[Protocol]) -> None: """Dynamically add protocol methods to a class by inspecting the protocol.""" From 236f235ddbd8ada79eb35652d1b9d75d86d3eedd Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 6 Oct 2025 13:30:58 -0700 Subject: [PATCH 06/11] removed a single file --- tests/conftest.py | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index fce589e579..0000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -# Global pytest plugins configuration -pytest_plugins = ["tests.integration.fixtures.common"] From deaccfcb4b76bd90c4977ba550161cea37dfb75d Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 6 Oct 2025 13:36:23 -0700 Subject: [PATCH 07/11] updated conftest.py --- tests/integration/conftest.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 90838f273b..42015a608e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -259,6 +259,9 @@ def pytest_generate_tests(metafunc): metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids if test_ids else None) +pytest_plugins = ["tests.integration.fixtures.common"] + + def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: """Skip collecting paths outside the selected suite roots for speed.""" suite = config.getoption("--suite") From bcdca4d9a5d71154eaddc112bd419ed0d569f922 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 6 Oct 2025 13:47:01 -0700 Subject: [PATCH 08/11] Added an import --- tests/integration/conftest.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 42015a608e..0e553eef4a 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -87,7 +87,9 @@ def pytest_configure(config): suite = config.getoption("--suite") if suite: if suite not in SUITE_DEFINITIONS: - raise pytest.UsageError(f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}") + raise pytest.UsageError( + f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}" + ) # Apply setups (global parameterizations): env + defaults setup = config.getoption("--setup") @@ -125,7 +127,9 @@ def pytest_addoption(parser): """ ), ) - parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value") + parser.addoption( + "--env", action="append", help="Set environment variables, e.g. --env KEY=value" + ) parser.addoption( "--text-model", help="comma-separated list of text models. Fixture name: text_model_id", @@ -165,9 +169,7 @@ def pytest_addoption(parser): ) available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys())) - suite_help = ( - f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses" - ) + suite_help = f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses" parser.addoption("--suite", help=suite_help) # Global setups for any suite @@ -239,7 +241,11 @@ def pytest_generate_tests(metafunc): # Generate test IDs test_ids = [] - non_empty_params = [(i, values) for i, values in enumerate(param_values.values()) if values[0] is not None] + non_empty_params = [ + (i, values) + for i, values in enumerate(param_values.values()) + if values[0] is not None + ] # Get actual function parameters using inspect test_func_params = set(inspect.signature(metafunc.function).parameters.keys()) @@ -256,7 +262,9 @@ def pytest_generate_tests(metafunc): if parts: test_ids.append(":".join(parts)) - metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids if test_ids else None) + metafunc.parametrize( + params, value_combinations, scope="session", ids=test_ids if test_ids else None + ) pytest_plugins = ["tests.integration.fixtures.common"] @@ -269,7 +277,9 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: return False sobj = SUITE_DEFINITIONS.get(suite) - roots: list[str] = sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", []) + roots: list[str] = ( + sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", []) + ) if not roots: return False From b0a6adf3b23dc417f26274ad0b348cb7d67231a0 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 6 Oct 2025 13:49:02 -0700 Subject: [PATCH 09/11] Added minor fix --- tests/integration/conftest.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 0e553eef4a..42015a608e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -87,9 +87,7 @@ def pytest_configure(config): suite = config.getoption("--suite") if suite: if suite not in SUITE_DEFINITIONS: - raise pytest.UsageError( - f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}" - ) + raise pytest.UsageError(f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}") # Apply setups (global parameterizations): env + defaults setup = config.getoption("--setup") @@ -127,9 +125,7 @@ def pytest_addoption(parser): """ ), ) - parser.addoption( - "--env", action="append", help="Set environment variables, e.g. --env KEY=value" - ) + parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value") parser.addoption( "--text-model", help="comma-separated list of text models. Fixture name: text_model_id", @@ -169,7 +165,9 @@ def pytest_addoption(parser): ) available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys())) - suite_help = f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses" + suite_help = ( + f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses" + ) parser.addoption("--suite", help=suite_help) # Global setups for any suite @@ -241,11 +239,7 @@ def pytest_generate_tests(metafunc): # Generate test IDs test_ids = [] - non_empty_params = [ - (i, values) - for i, values in enumerate(param_values.values()) - if values[0] is not None - ] + non_empty_params = [(i, values) for i, values in enumerate(param_values.values()) if values[0] is not None] # Get actual function parameters using inspect test_func_params = set(inspect.signature(metafunc.function).parameters.keys()) @@ -262,9 +256,7 @@ def pytest_generate_tests(metafunc): if parts: test_ids.append(":".join(parts)) - metafunc.parametrize( - params, value_combinations, scope="session", ids=test_ids if test_ids else None - ) + metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids if test_ids else None) pytest_plugins = ["tests.integration.fixtures.common"] @@ -277,9 +269,7 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: return False sobj = SUITE_DEFINITIONS.get(suite) - roots: list[str] = ( - sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", []) - ) + roots: list[str] = sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", []) if not roots: return False From 6adaca3d968af4e9a73b274eb7c2c3f434c50a5f Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 6 Oct 2025 14:30:45 -0700 Subject: [PATCH 10/11] added a minor fix --- .../utils/inference/openai_compat.py | 212 +++++++++--------- 1 file changed, 109 insertions(+), 103 deletions(-) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index a3e272d204..e50d4d5617 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -15,17 +15,9 @@ from openai import AsyncStream from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, -) -from openai.types.chat import ( ChatCompletionChunk as OpenAIChatCompletionChunk, -) -from openai.types.chat import ( ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, -) -from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, -) -from openai.types.chat import ( ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, ) @@ -37,56 +29,15 @@ from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, -) -from openai.types.chat import ( - ChatCompletionMessageToolCall, -) -from openai.types.chat import ( - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, -) -from openai.types.chat import ( - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, -) -from openai.types.chat import ( - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, -) -from openai.types.chat.chat_completion import ( - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_chunk import ( - Choice as OpenAIChatCompletionChunkChoice, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDelta as OpenAIChoiceDelta, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call import ( - Function as OpenAIFunction, -) -from pydantic import BaseModel - from llama_stack.apis.common.content_types import ( - URL, + _URLOrData, ImageContentItem, InterleavedContent, TextContentItem, TextDelta, ToolCallDelta, ToolCallParseStatus, - _URLOrData, + URL, ) from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -99,6 +50,7 @@ JsonSchemaResponseFormat, Message, OpenAIChatCompletion, + OpenAIChoice as OpenAIChatCompletionChoice, OpenAIEmbeddingData, OpenAIMessageParam, OpenAIResponseFormatParam, @@ -123,6 +75,30 @@ convert_image_content_to_url, decode_assistant_message, ) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessage, + ChatCompletionMessageToolCall, + ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, + ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, + ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, +) +from openai.types.chat.chat_completion import ( + Choice as OpenAIChoice, + ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs +) +from openai.types.chat.chat_completion_chunk import ( + Choice as OpenAIChatCompletionChunkChoice, + ChoiceDelta as OpenAIChoiceDelta, + ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ImageURL as OpenAIImageURL, +) +from openai.types.chat.chat_completion_message_tool_call import ( + Function as OpenAIFunction, +) +from pydantic import BaseModel logger = get_logger(name=__name__, category="providers::utils") @@ -221,12 +197,16 @@ def convert_openai_completion_logprobs( if logprobs.tokens and logprobs.token_logprobs: return [ TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) + for token, token_lp in zip( + logprobs.tokens, logprobs.token_logprobs, strict=False + ) ] return None -def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None): +def convert_openai_completion_logprobs_stream( + text: str, logprobs: float | OpenAICompatLogprobs | None +): if logprobs is None: return None if isinstance(logprobs, float): @@ -237,31 +217,6 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA return None -# def process_completion_response( -# response: OpenAICompatCompletionResponse, -# ) -> CompletionResponse: -# choice = response.choices[0] -# # drop suffix if present and return stop reason as end of turn -# if choice.text.endswith("<|eot_id|>"): -# return CompletionResponse( -# stop_reason=StopReason.end_of_turn, -# content=choice.text[: -len("<|eot_id|>")], -# logprobs=convert_openai_completion_logprobs(choice.logprobs), -# ) -# # drop suffix if present and return stop reason as end of message -# if choice.text.endswith("<|eom_id|>"): -# return CompletionResponse( -# stop_reason=StopReason.end_of_message, -# content=choice.text[: -len("<|eom_id|>")], -# logprobs=convert_openai_completion_logprobs(choice.logprobs), -# ) -# return CompletionResponse( -# stop_reason=get_stop_reason(choice.finish_reason), -# content=choice.text, -# logprobs=convert_openai_completion_logprobs(choice.logprobs), -# ) - - def process_chat_completion_response( response: OpenAICompatCompletionResponse, request: ChatCompletionRequest, @@ -271,7 +226,9 @@ def process_chat_completion_response( if not choice.message or not choice.message.tool_calls: raise ValueError("Tool calls are not present in the response") - tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] + tool_calls = [ + convert_tool_call(tool_call) for tool_call in choice.message.tool_calls + ] if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): # If we couldn't parse a tool call, jsonify the tool calls and return them return ChatCompletionResponse( @@ -295,7 +252,9 @@ def process_chat_completion_response( # TODO: This does not work well with tool calls for vLLM remote provider # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason)) + raw_message = decode_assistant_message( + text_from_choice(choice), get_stop_reason(choice.finish_reason) + ) # NOTE: If we do not set tools in chat-completion request, we should not # expect the ToolCall in the response. Instead, we should return the raw @@ -496,13 +455,17 @@ async def process_chat_completion_stream_response( ) -async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict: +async def convert_message_to_openai_dict( + message: Message, download: bool = False +) -> dict: async def _convert_content(content) -> dict: if isinstance(content, ImageContentItem): return { "type": "image_url", "image_url": { - "url": await convert_image_content_to_url(content, download=download), + "url": await convert_image_content_to_url( + content, download=download + ), }, } else: @@ -587,7 +550,11 @@ async def _convert_message_content( ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: async def impl( content_: InterleavedContent, - ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: + ) -> ( + str + | OpenAIChatCompletionContentPartParam + | list[OpenAIChatCompletionContentPartParam] + ): # Llama Stack and OpenAI spec match for str and text input if isinstance(content_, str): return content_ @@ -600,7 +567,9 @@ async def impl( return OpenAIChatCompletionContentPartImageParam( type="image_url", image_url=OpenAIImageURL( - url=await convert_image_content_to_url(content_, download=download_images) + url=await convert_image_content_to_url( + content_, download=download_images + ) ), ) elif isinstance(content_, list): @@ -627,7 +596,11 @@ async def impl( OpenAIChatCompletionMessageFunctionToolCall( id=tool.call_id, function=OpenAIFunction( - name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), + name=( + tool.tool_name + if not isinstance(tool.tool_name, BuiltinTool) + else tool.tool_name.value + ), arguments=tool.arguments, # Already a JSON string, don't double-encode ), type="function", @@ -807,7 +780,9 @@ class StopReason(Enum): }.get(finish_reason, StopReason.end_of_turn) -def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig: +def _convert_openai_request_tool_config( + tool_choice: str | dict[str, Any] | None = None +) -> ToolConfig: tool_config = ToolConfig() if tool_choice: try: @@ -818,7 +793,9 @@ def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None return tool_config -def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]: +def _convert_openai_request_tools( + tools: list[dict[str, Any]] | None = None +) -> list[ToolDefinition]: lls_tools = [] if not tools: return lls_tools @@ -917,7 +894,11 @@ def _convert_openai_logprobs( return None return [ - TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) + TokenLogProbs( + logprobs_by_token={ + logprobs.token: logprobs.logprob for logprobs in content.top_logprobs + } + ) for content in logprobs.content ] @@ -956,9 +937,13 @@ def openai_messages_to_messages( converted_messages = [] for message in messages: if message.role == "system": - converted_message = SystemMessage(content=openai_content_to_content(message.content)) + converted_message = SystemMessage( + content=openai_content_to_content(message.content) + ) elif message.role == "user": - converted_message = UserMessage(content=openai_content_to_content(message.content)) + converted_message = UserMessage( + content=openai_content_to_content(message.content) + ) elif message.role == "assistant": converted_message = CompletionMessage( content=openai_content_to_content(message.content), @@ -990,7 +975,9 @@ def openai_content_to_content( if content.type == "text": return TextContentItem(type="text", text=content.text) elif content.type == "image_url": - return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) + return ImageContentItem( + type="image", image=_URLOrData(url=URL(uri=content.image_url.url)) + ) else: raise ValueError(f"Unknown content type: {content.type}") else: @@ -1030,14 +1017,17 @@ class StopReason(Enum): end_of_message = "end_of_message" out_of_tokens = "out_of_tokens" """ - assert hasattr(choice, "message") and choice.message, "error in server response: message not found" - assert hasattr(choice, "finish_reason") and choice.finish_reason, ( - "error in server response: finish_reason not found" - ) + assert ( + hasattr(choice, "message") and choice.message + ), "error in server response: message not found" + assert ( + hasattr(choice, "finish_reason") and choice.finish_reason + ), "error in server response: finish_reason not found" return ChatCompletionResponse( completion_message=CompletionMessage( - content=choice.message.content or "", # CompletionMessage content is not optional + content=choice.message.content + or "", # CompletionMessage content is not optional stop_reason=_convert_openai_finish_reason(choice.finish_reason), tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), ), @@ -1277,7 +1267,9 @@ async def openai_chat_completion( outstanding_responses.append(response) if stream: - return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) + return OpenAIChatCompletionToLlamaStackMixin._process_stream_response( + self, model, outstanding_responses + ) return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response( self, model, outstanding_responses @@ -1286,21 +1278,29 @@ async def openai_chat_completion( async def _process_stream_response( self, model: str, - outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], + outstanding_responses: list[ + Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]] + ], ): id = f"chatcmpl-{uuid.uuid4()}" for i, outstanding_response in enumerate(outstanding_responses): response = await outstanding_response async for chunk in response: event = chunk.event - finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason) + finish_reason = _convert_stop_reason_to_openai_finish_reason( + event.stop_reason + ) if isinstance(event.delta, TextDelta): text_delta = event.delta.text delta = OpenAIChoiceDelta(content=text_delta) yield OpenAIChatCompletionChunk( id=id, - choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], + choices=[ + OpenAIChatCompletionChunkChoice( + index=i, finish_reason=finish_reason, delta=delta + ) + ], created=int(time.time()), model=model, object="chat.completion.chunk", @@ -1322,7 +1322,9 @@ async def _process_stream_response( yield OpenAIChatCompletionChunk( id=id, choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) + OpenAIChatCompletionChunkChoice( + index=i, finish_reason=finish_reason, delta=delta + ) ], created=int(time.time()), model=model, @@ -1339,7 +1341,9 @@ async def _process_stream_response( yield OpenAIChatCompletionChunk( id=id, choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) + OpenAIChatCompletionChunkChoice( + index=i, finish_reason=finish_reason, delta=delta + ) ], created=int(time.time()), model=model, @@ -1354,7 +1358,9 @@ async def _process_non_stream_response( response = await outstanding_response completion_message = response.completion_message message = await convert_message_to_openai_dict_new(completion_message) - finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason) + finish_reason = _convert_stop_reason_to_openai_finish_reason( + completion_message.stop_reason + ) choice = OpenAIChatCompletionChoice( index=len(choices), From f4104756f6fa354cc399368c0a15cc7dbbae956c Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 6 Oct 2025 14:35:38 -0700 Subject: [PATCH 11/11] reformatting --- .../utils/inference/openai_compat.py | 190 ++++++++---------- 1 file changed, 81 insertions(+), 109 deletions(-) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index e50d4d5617..e3f1d0913e 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -15,9 +15,17 @@ from openai import AsyncStream from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, +) +from openai.types.chat import ( ChatCompletionChunk as OpenAIChatCompletionChunk, +) +from openai.types.chat import ( ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, +) +from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, +) +from openai.types.chat import ( ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, ) @@ -29,15 +37,56 @@ from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall, ) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessage, +) +from openai.types.chat import ( + ChatCompletionMessageToolCall, +) +from openai.types.chat import ( + ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, +) +from openai.types.chat import ( + ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, +) +from openai.types.chat import ( + ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, +) +from openai.types.chat.chat_completion import ( + Choice as OpenAIChoice, +) +from openai.types.chat.chat_completion import ( + ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs +) +from openai.types.chat.chat_completion_chunk import ( + Choice as OpenAIChatCompletionChunkChoice, +) +from openai.types.chat.chat_completion_chunk import ( + ChoiceDelta as OpenAIChoiceDelta, +) +from openai.types.chat.chat_completion_chunk import ( + ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, +) +from openai.types.chat.chat_completion_chunk import ( + ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ImageURL as OpenAIImageURL, +) +from openai.types.chat.chat_completion_message_tool_call import ( + Function as OpenAIFunction, +) +from pydantic import BaseModel + from llama_stack.apis.common.content_types import ( - _URLOrData, + URL, ImageContentItem, InterleavedContent, TextContentItem, TextDelta, ToolCallDelta, ToolCallParseStatus, - URL, + _URLOrData, ) from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -50,7 +99,6 @@ JsonSchemaResponseFormat, Message, OpenAIChatCompletion, - OpenAIChoice as OpenAIChatCompletionChoice, OpenAIEmbeddingData, OpenAIMessageParam, OpenAIResponseFormatParam, @@ -64,6 +112,9 @@ TopPSamplingStrategy, UserMessage, ) +from llama_stack.apis.inference import ( + OpenAIChoice as OpenAIChatCompletionChoice, +) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -75,30 +126,6 @@ convert_image_content_to_url, decode_assistant_message, ) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, - ChatCompletionMessageToolCall, - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_chunk import ( - Choice as OpenAIChatCompletionChunkChoice, - ChoiceDelta as OpenAIChoiceDelta, - ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, - ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call import ( - Function as OpenAIFunction, -) -from pydantic import BaseModel logger = get_logger(name=__name__, category="providers::utils") @@ -197,16 +224,12 @@ def convert_openai_completion_logprobs( if logprobs.tokens and logprobs.token_logprobs: return [ TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip( - logprobs.tokens, logprobs.token_logprobs, strict=False - ) + for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) ] return None -def convert_openai_completion_logprobs_stream( - text: str, logprobs: float | OpenAICompatLogprobs | None -): +def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None): if logprobs is None: return None if isinstance(logprobs, float): @@ -226,9 +249,7 @@ def process_chat_completion_response( if not choice.message or not choice.message.tool_calls: raise ValueError("Tool calls are not present in the response") - tool_calls = [ - convert_tool_call(tool_call) for tool_call in choice.message.tool_calls - ] + tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): # If we couldn't parse a tool call, jsonify the tool calls and return them return ChatCompletionResponse( @@ -252,9 +273,7 @@ def process_chat_completion_response( # TODO: This does not work well with tool calls for vLLM remote provider # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message( - text_from_choice(choice), get_stop_reason(choice.finish_reason) - ) + raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason)) # NOTE: If we do not set tools in chat-completion request, we should not # expect the ToolCall in the response. Instead, we should return the raw @@ -455,17 +474,13 @@ async def process_chat_completion_stream_response( ) -async def convert_message_to_openai_dict( - message: Message, download: bool = False -) -> dict: +async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict: async def _convert_content(content) -> dict: if isinstance(content, ImageContentItem): return { "type": "image_url", "image_url": { - "url": await convert_image_content_to_url( - content, download=download - ), + "url": await convert_image_content_to_url(content, download=download), }, } else: @@ -550,11 +565,7 @@ async def _convert_message_content( ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: async def impl( content_: InterleavedContent, - ) -> ( - str - | OpenAIChatCompletionContentPartParam - | list[OpenAIChatCompletionContentPartParam] - ): + ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: # Llama Stack and OpenAI spec match for str and text input if isinstance(content_, str): return content_ @@ -567,9 +578,7 @@ async def impl( return OpenAIChatCompletionContentPartImageParam( type="image_url", image_url=OpenAIImageURL( - url=await convert_image_content_to_url( - content_, download=download_images - ) + url=await convert_image_content_to_url(content_, download=download_images) ), ) elif isinstance(content_, list): @@ -596,11 +605,7 @@ async def impl( OpenAIChatCompletionMessageFunctionToolCall( id=tool.call_id, function=OpenAIFunction( - name=( - tool.tool_name - if not isinstance(tool.tool_name, BuiltinTool) - else tool.tool_name.value - ), + name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), arguments=tool.arguments, # Already a JSON string, don't double-encode ), type="function", @@ -780,9 +785,7 @@ class StopReason(Enum): }.get(finish_reason, StopReason.end_of_turn) -def _convert_openai_request_tool_config( - tool_choice: str | dict[str, Any] | None = None -) -> ToolConfig: +def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig: tool_config = ToolConfig() if tool_choice: try: @@ -793,9 +796,7 @@ def _convert_openai_request_tool_config( return tool_config -def _convert_openai_request_tools( - tools: list[dict[str, Any]] | None = None -) -> list[ToolDefinition]: +def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]: lls_tools = [] if not tools: return lls_tools @@ -894,11 +895,7 @@ def _convert_openai_logprobs( return None return [ - TokenLogProbs( - logprobs_by_token={ - logprobs.token: logprobs.logprob for logprobs in content.top_logprobs - } - ) + TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) for content in logprobs.content ] @@ -937,13 +934,9 @@ def openai_messages_to_messages( converted_messages = [] for message in messages: if message.role == "system": - converted_message = SystemMessage( - content=openai_content_to_content(message.content) - ) + converted_message = SystemMessage(content=openai_content_to_content(message.content)) elif message.role == "user": - converted_message = UserMessage( - content=openai_content_to_content(message.content) - ) + converted_message = UserMessage(content=openai_content_to_content(message.content)) elif message.role == "assistant": converted_message = CompletionMessage( content=openai_content_to_content(message.content), @@ -975,9 +968,7 @@ def openai_content_to_content( if content.type == "text": return TextContentItem(type="text", text=content.text) elif content.type == "image_url": - return ImageContentItem( - type="image", image=_URLOrData(url=URL(uri=content.image_url.url)) - ) + return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) else: raise ValueError(f"Unknown content type: {content.type}") else: @@ -1017,17 +1008,14 @@ class StopReason(Enum): end_of_message = "end_of_message" out_of_tokens = "out_of_tokens" """ - assert ( - hasattr(choice, "message") and choice.message - ), "error in server response: message not found" - assert ( - hasattr(choice, "finish_reason") and choice.finish_reason - ), "error in server response: finish_reason not found" + assert hasattr(choice, "message") and choice.message, "error in server response: message not found" + assert hasattr(choice, "finish_reason") and choice.finish_reason, ( + "error in server response: finish_reason not found" + ) return ChatCompletionResponse( completion_message=CompletionMessage( - content=choice.message.content - or "", # CompletionMessage content is not optional + content=choice.message.content or "", # CompletionMessage content is not optional stop_reason=_convert_openai_finish_reason(choice.finish_reason), tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), ), @@ -1267,9 +1255,7 @@ async def openai_chat_completion( outstanding_responses.append(response) if stream: - return OpenAIChatCompletionToLlamaStackMixin._process_stream_response( - self, model, outstanding_responses - ) + return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response( self, model, outstanding_responses @@ -1278,29 +1264,21 @@ async def openai_chat_completion( async def _process_stream_response( self, model: str, - outstanding_responses: list[ - Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]] - ], + outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], ): id = f"chatcmpl-{uuid.uuid4()}" for i, outstanding_response in enumerate(outstanding_responses): response = await outstanding_response async for chunk in response: event = chunk.event - finish_reason = _convert_stop_reason_to_openai_finish_reason( - event.stop_reason - ) + finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason) if isinstance(event.delta, TextDelta): text_delta = event.delta.text delta = OpenAIChoiceDelta(content=text_delta) yield OpenAIChatCompletionChunk( id=id, - choices=[ - OpenAIChatCompletionChunkChoice( - index=i, finish_reason=finish_reason, delta=delta - ) - ], + choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], created=int(time.time()), model=model, object="chat.completion.chunk", @@ -1322,9 +1300,7 @@ async def _process_stream_response( yield OpenAIChatCompletionChunk( id=id, choices=[ - OpenAIChatCompletionChunkChoice( - index=i, finish_reason=finish_reason, delta=delta - ) + OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) ], created=int(time.time()), model=model, @@ -1341,9 +1317,7 @@ async def _process_stream_response( yield OpenAIChatCompletionChunk( id=id, choices=[ - OpenAIChatCompletionChunkChoice( - index=i, finish_reason=finish_reason, delta=delta - ) + OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) ], created=int(time.time()), model=model, @@ -1358,9 +1332,7 @@ async def _process_non_stream_response( response = await outstanding_response completion_message = response.completion_message message = await convert_message_to_openai_dict_new(completion_message) - finish_reason = _convert_stop_reason_to_openai_finish_reason( - completion_message.stop_reason - ) + finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason) choice = OpenAIChatCompletionChoice( index=len(choices),