Commit d7aaa6b
test: improve generic type handling in response deserialization
Enhance the inference recorder's deserialization logic to handle generic types like AsyncPage[Model] by recording them as a list, then replicate an AsyncPaginator for sharing the result with the caller.

Signed-off-by: Derek Higgins <[email protected]>
1 parent 7117358 commit d7aaa6b
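
In short: when a response is a paged object such as AsyncPage[Model], the recorder stores the page's items as a plain list rather than trying to round-trip the generic type name. A minimal sketch of that idea, using a hypothetical helper (record_page_as_list is not part of the recorder; its actual serialization format is only partially visible in the diff below):

from openai.pagination import AsyncPage

async def record_page_as_list(page: AsyncPage) -> dict:
    # Drain the paginator; each item is a pydantic model.
    items = [item async for item in page]
    # Store items as plain dicts so replay never needs to reconstruct
    # the generic AsyncPage[Model] type, only a list.
    return {
        "__type__": f"{type(page).__module__}.{type(page).__name__}",
        "items": [item.model_dump() for item in items],
    }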

File tree

1 file changed: +64 -2 lines changed

llama_stack/testing/inference_recorder.py

Lines changed: 64 additions & 2 deletions
@@ -17,6 +17,7 @@
 from pathlib import Path
 from typing import Any, Literal, cast

+from openai.pagination import AsyncPage
 from openai.types.chat import ChatCompletion, ChatCompletionChunk

 from llama_stack.log import get_logger
@@ -108,6 +109,7 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
     try:
         # Import the original class and reconstruct the object
         module_path, class_name = data["__type__"].rsplit(".", 1)
+
         module = __import__(module_path, fromlist=[class_name])
         cls = getattr(module, class_name)
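The reconstruction above relies on Python's dynamic import machinery: __import__(module_path, fromlist=[class_name]) returns the leaf module, and getattr pulls the class out of it. A short illustration of why this breaks for recorded generic type names, which is what motivates recording pages as plain lists:

# Works for a plain class name:
module = __import__("openai.types.chat", fromlist=["ChatCompletion"])
cls = getattr(module, "ChatCompletion")

# But a recorded generic name like "openai.pagination.AsyncPage[Model]"
# splits into class_name == "AsyncPage[Model]", and
# getattr(module, "AsyncPage[Model]") raises AttributeError:
# no attribute with that literal name exists on the module.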

@@ -298,8 +300,11 @@ async def replay_stream():
         # Determine if this is a streaming request based on request parameters
         is_streaming = body.get("stream", False)

-        if is_streaming:
-            # For streaming responses, we need to collect all chunks immediately before yielding
+        # Check if this is a paged response
+        is_paged = isinstance(response, AsyncPage)
+
+        if is_streaming or is_paged:
+            # For streaming and paged responses, we need to collect all chunks immediately before yielding
             # This ensures the recording is saved even if the generator isn't fully consumed
             chunks = []
             async for chunk in response:
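
Both streaming and paged responses are async iterables, so the same collect-then-replay trick covers both: drain the live response eagerly so the recording is complete, then hand the caller a generator over the collected chunks. A minimal sketch, with save_recording standing in for the recorder's actual persistence step:

async def collect_then_replay(response):
    # Consume everything up front so the recording is saved even if the
    # caller abandons the generator halfway through.
    chunks = [chunk async for chunk in response]
    save_recording(chunks)  # hypothetical persistence step

    async def replay_stream():
        for chunk in chunks:
            yield chunk

    return replay_stream()
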
@@ -332,9 +337,11 @@ def patch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Store original methods for both OpenAI and Ollama clients
     _original_methods = {
+        "model_list": AsyncModels.list,
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
         "embeddings_create": AsyncEmbeddings.create,
@@ -347,6 +354,58 @@ def patch_inference_clients():
     }

     # Create patched methods for OpenAI client
+    def patched_model_list(self, *args, **kwargs):
+        # The original models.list() returns an AsyncPaginator that can be used with async for
+        # We need to create a wrapper that preserves this behavior
+        class PatchedAsyncPaginator:
+            def __init__(self, original_method, instance, client_type, endpoint, args, kwargs):
+                self.original_method = original_method
+                self.instance = instance
+                self.client_type = client_type
+                self.endpoint = endpoint
+                self.args = args
+                self.kwargs = kwargs
+                self._result = None
+                self._iter_index = 0
+
+            def __await__(self):
+                # Make it awaitable like the original AsyncPaginator
+                async def _await():
+                    self._result = await _patched_inference_method(
+                        self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
+                    )
+                    return self._result
+
+                return _await().__await__()
+
+            def __aiter__(self):
+                # Make it async iterable like the original AsyncPaginator
+                return self
+
+            async def __anext__(self):
+                # Get the result if we haven't already
+                if self._result is None:
+                    self._result = [
+                        r
+                        async for r in await _patched_inference_method(
+                            self.original_method,
+                            self.instance,
+                            self.client_type,
+                            self.endpoint,
+                            *self.args,
+                            **self.kwargs,
+                        )
+                    ]
+
+                # Return next item from the list
+                if self._iter_index >= len(self._result):
+                    raise StopAsyncIteration
+                item = self._result[self._iter_index]
+                self._iter_index += 1
+                return item
+
+        return PatchedAsyncPaginator(_original_methods["model_list"], self, "openai", "/v1/models", args, kwargs)
+
     async def patched_chat_completions_create(self, *args, **kwargs):
         return await _patched_inference_method(
             _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
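
The wrapper has to mimic both halves of the AsyncPaginator interface because callers use models.list() in two ways. Roughly how the patched method behaves at call sites (illustrative usage only, assuming a configured client):

import asyncio
from openai import AsyncOpenAI

async def main() -> None:
    client = AsyncOpenAI()

    # 1. Awaited, like the original AsyncPaginator: yields the page object.
    page = await client.models.list()
    print(len(page.data))

    # 2. Async-iterated: yields individual models one at a time.
    async for model in client.models.list():
        print(model.id)

asyncio.run(main())
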
@@ -363,6 +422,7 @@ async def patched_embeddings_create(self, *args, **kwargs):
         )

     # Apply OpenAI patches
+    AsyncModels.list = patched_model_list
     AsyncChatCompletions.create = patched_chat_completions_create
     AsyncCompletions.create = patched_completions_create
     AsyncEmbeddings.create = patched_embeddings_create
@@ -419,8 +479,10 @@ def unpatch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Restore OpenAI client methods
+    AsyncModels.list = _original_methods["model_list"]
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
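
Together, patch_inference_clients and unpatch_inference_clients form a save/patch/restore round trip over the class attributes. Reduced to a single method, the pattern looks like this (with a stand-in for the patched function):

from openai.resources.models import AsyncModels

def _stand_in_list(self, *args, **kwargs):
    return "patched"  # stand-in for patched_model_list

_originals = {"model_list": AsyncModels.list}    # save the real method
AsyncModels.list = _stand_in_list                # patch
try:
    pass  # tests run here with recording/replay active
finally:
    AsyncModels.list = _originals["model_list"]  # restore, as unpatch does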
