@@ -17,6 +17,7 @@
 from pathlib import Path
 from typing import Any, Literal, cast

+from openai.pagination import AsyncPage
 from openai.types.chat import ChatCompletion, ChatCompletionChunk

 from llama_stack.log import get_logger
@@ -108,6 +109,7 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
         try:
             # Import the original class and reconstruct the object
             module_path, class_name = data["__type__"].rsplit(".", 1)
+
             module = __import__(module_path, fromlist=[class_name])
             cls = getattr(module, class_name)
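For context: the deserializer above rebuilds an object from its dotted type name using `__import__` with `fromlist`. A minimal, self-contained sketch of that pattern, using `collections.OrderedDict` purely as an example type name (not from the PR):

# Sketch of the dynamic-import pattern used by _deserialize_response.
data = {"__type__": "collections.OrderedDict"}

module_path, class_name = data["__type__"].rsplit(".", 1)
module = __import__(module_path, fromlist=[class_name])  # imports "collections"
cls = getattr(module, class_name)

assert cls.__name__ == "OrderedDict"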
@@ -298,8 +300,11 @@ async def replay_stream():
         # Determine if this is a streaming request based on request parameters
         is_streaming = body.get("stream", False)

-        if is_streaming:
-            # For streaming responses, we need to collect all chunks immediately before yielding
+        # Check if this is a paged response
+        is_paged = isinstance(response, AsyncPage)
+
+        if is_streaming or is_paged:
+            # For streaming and paged responses, we need to collect all chunks immediately before yielding
             # This ensures the recording is saved even if the generator isn't fully consumed
             chunks = []
             async for chunk in response:
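Why paged responses get the same eager-collection treatment as streams: an `AsyncPage` is consumed with `async for`, so if the recorder yielded items lazily, a caller that stops iterating early would leave the recording unsaved. A minimal sketch of the drain-then-replay idea; `FakeAsyncPage` is a stand-in invented for illustration, not part of the PR:

import asyncio


class FakeAsyncPage:
    # Stand-in for openai.pagination.AsyncPage: an async iterable of items.
    def __init__(self, items):
        self._items = items

    async def _gen(self):
        for item in self._items:
            yield item

    def __aiter__(self):
        return self._gen()


async def record_then_replay(page):
    # Drain every item up front so the recording can be written in full,
    # then replay the collected items to the caller.
    chunks = [chunk async for chunk in page]
    # ... the recorder would persist `chunks` here, before any yield ...
    for chunk in chunks:
        yield chunk


async def main():
    replayed = [c async for c in record_then_replay(FakeAsyncPage(["model-a", "model-b"]))]
    assert replayed == ["model-a", "model-b"]


asyncio.run(main())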
@@ -332,9 +337,11 @@ def patch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Store original methods for both OpenAI and Ollama clients
     _original_methods = {
+        "model_list": AsyncModels.list,
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
         "embeddings_create": AsyncEmbeddings.create,
@@ -347,6 +354,58 @@ def patch_inference_clients():
     }

     # Create patched methods for OpenAI client
+    def patched_model_list(self, *args, **kwargs):
+        # The original models.list() returns an AsyncPaginator that can be used with async for
+        # We need to create a wrapper that preserves this behavior
+        class PatchedAsyncPaginator:
+            def __init__(self, original_method, instance, client_type, endpoint, args, kwargs):
+                self.original_method = original_method
+                self.instance = instance
+                self.client_type = client_type
+                self.endpoint = endpoint
+                self.args = args
+                self.kwargs = kwargs
+                self._result = None
+                self._iter_index = 0
+
+            def __await__(self):
+                # Make it awaitable like the original AsyncPaginator
+                async def _await():
+                    self._result = await _patched_inference_method(
+                        self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
+                    )
+                    return self._result
+
+                return _await().__await__()
+
+            def __aiter__(self):
+                # Make it async iterable like the original AsyncPaginator
+                return self
+
+            async def __anext__(self):
+                # Get the result if we haven't already
+                if self._result is None:
+                    self._result = [
+                        r
+                        async for r in await _patched_inference_method(
+                            self.original_method,
+                            self.instance,
+                            self.client_type,
+                            self.endpoint,
+                            *self.args,
+                            **self.kwargs,
+                        )
+                    ]
+
+                # Return next item from the list
+                if self._iter_index >= len(self._result):
+                    raise StopAsyncIteration
+                item = self._result[self._iter_index]
+                self._iter_index += 1
+                return item
+
+        return PatchedAsyncPaginator(_original_methods["model_list"], self, "openai", "/v1/models", args, kwargs)
+

     async def patched_chat_completions_create(self, *args, **kwargs):
         return await _patched_inference_method(
             _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
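The wrapper above has to mimic OpenAI's `AsyncPaginator`, which callers may either `await` (to get the result object) or drive directly with `async for` (to walk the items). A standalone sketch of an object supporting both protocols via `__await__`, `__aiter__`, and `__anext__`; illustrative only, not the PR's code:

import asyncio


class AwaitableAsyncIterable:
    # Supports both `await obj` (returns the whole list) and `async for`.
    def __init__(self, items):
        self._items = items
        self._index = 0

    def __await__(self):
        async def _resolve():
            return list(self._items)

        return _resolve().__await__()

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._index >= len(self._items):
            raise StopAsyncIteration
        item = self._items[self._index]
        self._index += 1
        return item


async def main():
    assert await AwaitableAsyncIterable([1, 2, 3]) == [1, 2, 3]
    assert [x async for x in AwaitableAsyncIterable([1, 2, 3])] == [1, 2, 3]


asyncio.run(main())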
@@ -363,6 +422,7 @@ async def patched_embeddings_create(self, *args, **kwargs):
         )

     # Apply OpenAI patches
+    AsyncModels.list = patched_model_list
     AsyncChatCompletions.create = patched_chat_completions_create
     AsyncCompletions.create = patched_completions_create
     AsyncEmbeddings.create = patched_embeddings_create
@@ -419,8 +479,10 @@ def unpatch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Restore OpenAI client methods
+    AsyncModels.list = _original_methods["model_list"]
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
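The patch/unpatch pair follows the usual monkeypatching discipline: stash the original class attribute before replacing it, then assign it back to restore. A tiny self-contained sketch of that pattern; the names here are invented for illustration:

class FakeModels:
    # Hypothetical stand-in for a client resource class.
    def list(self):
        return "real"


_original = {"model_list": FakeModels.list}


def patched_list(self):
    return "recorded:" + _original["model_list"](self)


FakeModels.list = patched_list  # patch
assert FakeModels().list() == "recorded:real"

FakeModels.list = _original["model_list"]  # unpatch restores the original
assert FakeModels().list() == "real"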