@@ -577,97 +577,6 @@ def _record_errors_and_metrics(
577577 was_error = user_exception is not None ,
578578 )
579579
async def _call_streaming(
    self,
    request_metadata: RequestMetadata,
    request_args: Tuple[Any],
    request_kwargs: Dict[str, Any],
    status_code_callback: StatusCodeCallback,
) -> AsyncGenerator[Any, None]:
    """Run a streaming user call and relay everything it produces.

    The user callable wrapper is started with a callback that pushes each
    result onto a `MessageQueue`; this coroutine drains that queue and yields
    what it finds until the user call has finished.

    For HTTP requests each drained batch is yielded as one pickled payload and
    `status_code_callback` is invoked with the status of the first message
    when that message is an "http.response.start". For all other requests the
    messages are yielded one at a time, unmodified.

    On exit (normal, error, or early close of this generator) any still
    pending user call and queue-wait task are cancelled.
    """
    user_call = None
    message_waiter = None
    try:
        queue = MessageQueue()

        # The queue is backed by `asyncio.Event`s, which are not thread safe,
        # so the user callable thread must hand items over through
        # `call_soon_threadsafe` instead of touching the queue directly.
        def _put_from_user_thread(item: Any):
            self._event_loop.call_soon_threadsafe(queue.put_nowait, item)

        start_user_call = (
            self._user_callable_wrapper.call_http_entrypoint
            if request_metadata.is_http_request
            else self._user_callable_wrapper.call_user_generator
        )
        user_call = start_user_call(
            request_metadata,
            request_args,
            request_kwargs,
            generator_result_callback=_put_from_user_thread,
        )

        peeked_first_message = False
        user_call_finished = False
        while not user_call_finished:
            message_waiter = self._event_loop.create_task(
                queue.wait_for_message()
            )
            finished, _ = await asyncio.wait(
                [user_call, message_waiter],
                return_when=asyncio.FIRST_COMPLETED,
            )

            # Drain whatever is currently queued and forward it.
            batch = queue.get_messages_nowait()
            if batch:
                if not request_metadata.is_http_request:
                    for item in batch:
                        yield item
                else:
                    # HTTP (ASGI) messages are only consumed by the proxy, so
                    # they are sent as a batch using vanilla pickle (safe
                    # because they only contain primitive Python types).
                    if not peeked_first_message:
                        first = batch[0]
                        peeked_first_message = True
                        # An HTTP response starts with exactly one
                        # "http.response.start" message carrying "status";
                        # other response types (e.g. WebSockets) may not.
                        if first["type"] == "http.response.start":
                            status_code_callback(str(first["status"]))

                    yield pickle.dumps(batch)

            # Once the user call is done every message has already been
            # drained above, so this was the last iteration.
            user_call_finished = user_call in finished

        failure = user_call.exception()
        if failure is not None:
            raise failure from None
    finally:
        # Cancel whatever is still outstanding: the user call first, then the
        # task waiting on the queue.
        for outstanding in (user_call, message_waiter):
            if outstanding is not None and not outstanding.done():
                outstanding.cancel()
670-
671580 async def handle_request (
672581 self , request_metadata : RequestMetadata , * request_args , ** request_kwargs
673582 ) -> Tuple [bytes , Any ]:
@@ -683,13 +592,21 @@ async def handle_request_streaming(
683592 async with self ._wrap_user_method_call (
684593 request_metadata , request_args
685594 ) as status_code_callback :
686- async for result in self ._call_streaming (
687- request_metadata ,
688- request_args ,
689- request_kwargs ,
690- status_code_callback = status_code_callback ,
691- ):
692- yield result
595+ if request_metadata .is_http_request :
596+ async for result in self ._user_callable_wrapper .call_http_entrypoint (
597+ request_metadata ,
598+ request_args ,
599+ request_kwargs ,
600+ status_code_callback = status_code_callback ,
601+ ):
602+ yield result
603+ else :
604+ async for result in self ._user_callable_wrapper .call_user_generator (
605+ request_metadata ,
606+ request_args ,
607+ request_kwargs ,
608+ ):
609+ yield result
693610
694611 async def handle_request_with_rejection (
695612 self , request_metadata : RequestMetadata , * request_args , ** request_kwargs
@@ -715,14 +632,21 @@ async def handle_request_with_rejection(
715632 num_ongoing_requests = self .get_num_ongoing_requests (),
716633 )
717634
718- if request_metadata .is_streaming :
719- async for result in self ._call_streaming (
635+ if request_metadata .is_http_request :
636+ async for result in self ._user_callable_wrapper . call_http_entrypoint (
720637 request_metadata ,
721638 request_args ,
722639 request_kwargs ,
723640 status_code_callback = status_code_callback ,
724641 ):
725642 yield result
643+ elif request_metadata .is_streaming :
644+ async for result in self ._user_callable_wrapper .call_user_generator (
645+ request_metadata ,
646+ request_args ,
647+ request_kwargs ,
648+ ):
649+ yield result
726650 else :
727651 yield await self ._user_callable_wrapper .call_user_method (
728652 request_metadata , request_args , request_kwargs
@@ -1674,8 +1598,49 @@ async def _handle_user_method_result(
16741598
16751599 return result
16761600
1677- @_run_user_code
async def call_http_entrypoint(
    self,
    request_metadata: RequestMetadata,
    request_args: Tuple[Any],
    request_kwargs: Dict[str, Any],
    status_code_callback: StatusCodeCallback,
) -> AsyncGenerator[bytes, None]:
    """Call the HTTP entrypoint and yield its ASGI messages as pickled batches.

    Starts `_call_http_entrypoint` with a callback that places every message
    it produces on a `MessageQueue`, then drains that queue through
    `MessageQueue.fetch_messages_from_queue` and yields each batch serialized
    with `pickle.dumps`.

    Args:
        request_metadata: Metadata of the request, forwarded unchanged.
        request_args: Positional arguments, forwarded unchanged.
        request_kwargs: Keyword arguments, forwarded unchanged.
        status_code_callback: Called at most once, with `str(status)` of the
            first message received, and only if that message has type
            "http.response.start".

    Yields:
        One `bytes` payload (a pickled list of messages) per non-empty batch.
    """
    result_queue = MessageQueue()
    system_event_loop = asyncio.get_running_loop()

    # `asyncio.Event`s are not thread safe, so `call_soon_threadsafe` must be
    # used to interact with the result queue from the user callable thread.
    def _enqueue_thread_safe(item: Any):
        system_event_loop.call_soon_threadsafe(result_queue.put_nowait, item)

    call_user_method_future = self._call_http_entrypoint(
        request_metadata,
        request_args,
        request_kwargs,
        generator_result_callback=_enqueue_thread_safe,
    )

    first_message_peeked = False
    async for messages in result_queue.fetch_messages_from_queue(
        call_user_method_future
    ):
        # Never index into or forward an empty batch. This keeps the
        # `messages[0]` peek below safe regardless of what the queue helper
        # yields.
        if not messages:
            continue

        # Peek the first ASGI message to determine the status code.
        if not first_message_peeked:
            msg = messages[0]
            first_message_peeked = True
            if msg["type"] == "http.response.start":
                # HTTP responses begin with exactly one
                # "http.response.start" message containing the "status"
                # field. Other response types like WebSockets may not.
                status_code_callback(str(msg["status"]))

        # HTTP (ASGI) messages are only consumed by the proxy so batch them
        # and use vanilla pickle (we know it's safe because these messages
        # only contain primitive Python types).
        yield pickle.dumps(messages)
1641+
1642+ @_run_user_code
1643+ async def _call_http_entrypoint (
16791644 self ,
16801645 request_metadata : RequestMetadata ,
16811646 request_args : Tuple [Any ],
@@ -1691,7 +1656,7 @@ async def call_http_entrypoint(
16911656 Raises any exception raised by the user code so it can be propagated as a
16921657 `RayTaskError`.
16931658 """
1694- self ._raise_if_not_initialized ("call_http_entrypoint " )
1659+ self ._raise_if_not_initialized ("_call_http_entrypoint " )
16951660
16961661 logger .info (
16971662 f"Started executing request to method '{ request_metadata .call_method } '." ,
@@ -1764,12 +1729,45 @@ async def call_http_entrypoint(
17641729
17651730 raise
17661731
1767- @_run_user_code
async def call_user_generator(
    self,
    request_metadata: RequestMetadata,
    request_args: Tuple[Any],
    request_kwargs: Dict[str, Any],
) -> AsyncGenerator[Any, None]:
    """Run a streaming user method and yield its results one by one.

    `_call_user_generator` is started with a callback that pushes each result
    onto a `MessageQueue`. This generator drains that queue via
    `fetch_messages_from_queue` and yields every message of every batch
    individually, in order.
    """
    queue = MessageQueue()

    # The queue relies on `asyncio.Event`s, which are not thread safe, so
    # items coming from the user callable thread are handed to the running
    # loop with `call_soon_threadsafe` rather than enqueued directly.
    loop = asyncio.get_running_loop()

    def _put_from_user_thread(item: Any):
        loop.call_soon_threadsafe(queue.put_nowait, item)

    user_call = self._call_user_generator(
        request_metadata,
        request_args,
        request_kwargs,
        generator_result_callback=_put_from_user_thread,
    )

    async for batch in queue.fetch_messages_from_queue(user_call):
        for item in batch:
            yield item
1764+
1765+ @_run_user_code
1766+ async def _call_user_generator (
1767+ self ,
1768+ request_metadata : RequestMetadata ,
1769+ request_args : Tuple [Any ],
1770+ request_kwargs : Dict [str , Any ],
17731771 * ,
17741772 generator_result_callback : Optional [Callable ] = None ,
17751773 ) -> Any :
@@ -1781,7 +1779,7 @@ async def call_user_generator(
17811779 Raises any exception raised by the user code so it can be propagated as a
17821780 `RayTaskError`.
17831781 """
1784- self ._raise_if_not_initialized ("call_user_generator " )
1782+ self ._raise_if_not_initialized ("_call_user_generator " )
17851783
17861784 logger .info (
17871785 f"Started executing request to method '{ request_metadata .call_method } '." ,
0 commit comments