Skip to content

Commit 57125ad

Browse files
authored
[serve] Move logic into user callable wrapper (#54177)
## Why are these changes needed?

Move the logic in `_call_streaming` into `UserCallableWrapper` and `MessageQueue`. All `handle_request_...` methods now call into `UserCallableWrapper` directly, since all of the logic for calling user methods lives in `UserCallableWrapper`.

---------

Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
1 parent aec6fe9 commit 57125ad

File tree

4 files changed: +176 -118 lines changed

python/ray/serve/_private/http_util.py

Lines changed: 54 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,17 @@
77
from collections import deque
88
from copy import deepcopy
99
from dataclasses import dataclass
10-
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Type, Union
10+
from typing import (
11+
Any,
12+
AsyncGenerator,
13+
Awaitable,
14+
Callable,
15+
List,
16+
Optional,
17+
Tuple,
18+
Type,
19+
Union,
20+
)
1121

1222
import starlette
1323
import uvicorn
@@ -265,6 +275,49 @@ async def get_one_message(self) -> Message:
265275
elif len(self._message_queue) == 0 and self._closed:
266276
raise StopAsyncIteration
267277

278+
async def fetch_messages_from_queue(
279+
self, call_fut: asyncio.Future
280+
) -> AsyncGenerator[List[Any], None]:
281+
"""Repeatedly consume messages from the queue and yield them.
282+
283+
This is used to fetch queue messages in the system event loop in
284+
a thread-safe manner.
285+
286+
Args:
287+
call_fut: The async Future pointing to the task from the user
288+
code event loop that is pushing messages onto the queue.
289+
290+
Yields:
291+
List[Any]: Messages from the queue.
292+
"""
293+
# Repeatedly consume messages from the queue.
294+
wait_for_msg_task = None
295+
try:
296+
while True:
297+
wait_for_msg_task = asyncio.create_task(self.wait_for_message())
298+
done, _ = await asyncio.wait(
299+
[call_fut, wait_for_msg_task], return_when=asyncio.FIRST_COMPLETED
300+
)
301+
302+
messages = self.get_messages_nowait()
303+
if messages:
304+
yield messages
305+
306+
# Exit once `call_fut` has finished. In this case, all
307+
# messages must have already been sent.
308+
if call_fut in done:
309+
break
310+
311+
e = call_fut.exception()
312+
if e is not None:
313+
raise e from None
314+
finally:
315+
if not call_fut.done():
316+
call_fut.cancel()
317+
318+
if wait_for_msg_task is not None and not wait_for_msg_task.done():
319+
wait_for_msg_task.cancel()
320+
268321

269322
class ASGIReceiveProxy:
270323
"""Proxies ASGI receive from an actor.

python/ray/serve/_private/local_testing_mode.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -299,8 +299,15 @@ def generator_result_callback(item: Any):
299299
generator_result_callback = None
300300

301301
# Conform to the router interface of returning a future to the ReplicaResult.
302-
if request_meta.is_streaming:
303-
fut = self._user_callable_wrapper.call_user_generator(
302+
if request_meta.is_http_request:
303+
fut = self._user_callable_wrapper._call_http_entrypoint(
304+
request_meta,
305+
request_args,
306+
request_kwargs,
307+
generator_result_callback=generator_result_callback,
308+
)
309+
elif request_meta.is_streaming:
310+
fut = self._user_callable_wrapper._call_user_generator(
304311
request_meta,
305312
request_args,
306313
request_kwargs,

python/ray/serve/_private/replica.py

Lines changed: 102 additions & 104 deletions
Original file line number | Diff line number | Diff line change
@@ -577,97 +577,6 @@ def _record_errors_and_metrics(
577577
was_error=user_exception is not None,
578578
)
579579

580-
async def _call_streaming(
581-
self,
582-
request_metadata: RequestMetadata,
583-
request_args: Tuple[Any],
584-
request_kwargs: Dict[str, Any],
585-
status_code_callback: StatusCodeCallback,
586-
) -> AsyncGenerator[Any, None]:
587-
"""Calls a user method for a streaming call and yields its results.
588-
589-
The user method is called in an asyncio `Task` and places its results on a
590-
`result_queue`. This method pulls and yields from the `result_queue`.
591-
"""
592-
call_user_method_future = None
593-
wait_for_message_task = None
594-
try:
595-
result_queue = MessageQueue()
596-
597-
# `asyncio.Event`s are not thread safe, so `call_soon_threadsafe` must be
598-
# used to interact with the result queue from the user callable thread.
599-
def _enqueue_thread_safe(item: Any):
600-
self._event_loop.call_soon_threadsafe(result_queue.put_nowait, item)
601-
602-
if request_metadata.is_http_request:
603-
call_user_method_future = (
604-
self._user_callable_wrapper.call_http_entrypoint(
605-
request_metadata,
606-
request_args,
607-
request_kwargs,
608-
generator_result_callback=_enqueue_thread_safe,
609-
)
610-
)
611-
else:
612-
call_user_method_future = (
613-
self._user_callable_wrapper.call_user_generator(
614-
request_metadata,
615-
request_args,
616-
request_kwargs,
617-
generator_result_callback=_enqueue_thread_safe,
618-
)
619-
)
620-
621-
first_message_peeked = False
622-
while True:
623-
wait_for_message_task = self._event_loop.create_task(
624-
result_queue.wait_for_message()
625-
)
626-
done, _ = await asyncio.wait(
627-
[call_user_method_future, wait_for_message_task],
628-
return_when=asyncio.FIRST_COMPLETED,
629-
)
630-
631-
# Consume and yield all available messages in the queue.
632-
messages = result_queue.get_messages_nowait()
633-
if messages:
634-
# HTTP (ASGI) messages are only consumed by the proxy so batch them
635-
# and use vanilla pickle (we know it's safe because these messages
636-
# only contain primitive Python types).
637-
if request_metadata.is_http_request:
638-
# Peek the first ASGI message to determine the status code.
639-
if not first_message_peeked:
640-
msg = messages[0]
641-
first_message_peeked = True
642-
if msg["type"] == "http.response.start":
643-
# HTTP responses begin with exactly one
644-
# "http.response.start" message containing the "status"
645-
# field. Other response types like WebSockets may not.
646-
status_code_callback(str(msg["status"]))
647-
648-
yield pickle.dumps(messages)
649-
else:
650-
for msg in messages:
651-
yield msg
652-
653-
# Exit once `call_user_generator` has finished. In this case, all
654-
# messages must have already been sent.
655-
if call_user_method_future in done:
656-
break
657-
658-
e = call_user_method_future.exception()
659-
if e is not None:
660-
raise e from None
661-
finally:
662-
if (
663-
call_user_method_future is not None
664-
and not call_user_method_future.done()
665-
):
666-
call_user_method_future.cancel()
667-
668-
if wait_for_message_task is not None and not wait_for_message_task.done():
669-
wait_for_message_task.cancel()
670-
671580
async def handle_request(
672581
self, request_metadata: RequestMetadata, *request_args, **request_kwargs
673582
) -> Tuple[bytes, Any]:
@@ -683,13 +592,21 @@ async def handle_request_streaming(
683592
async with self._wrap_user_method_call(
684593
request_metadata, request_args
685594
) as status_code_callback:
686-
async for result in self._call_streaming(
687-
request_metadata,
688-
request_args,
689-
request_kwargs,
690-
status_code_callback=status_code_callback,
691-
):
692-
yield result
595+
if request_metadata.is_http_request:
596+
async for result in self._user_callable_wrapper.call_http_entrypoint(
597+
request_metadata,
598+
request_args,
599+
request_kwargs,
600+
status_code_callback=status_code_callback,
601+
):
602+
yield result
603+
else:
604+
async for result in self._user_callable_wrapper.call_user_generator(
605+
request_metadata,
606+
request_args,
607+
request_kwargs,
608+
):
609+
yield result
693610

694611
async def handle_request_with_rejection(
695612
self, request_metadata: RequestMetadata, *request_args, **request_kwargs
@@ -715,14 +632,21 @@ async def handle_request_with_rejection(
715632
num_ongoing_requests=self.get_num_ongoing_requests(),
716633
)
717634

718-
if request_metadata.is_streaming:
719-
async for result in self._call_streaming(
635+
if request_metadata.is_http_request:
636+
async for result in self._user_callable_wrapper.call_http_entrypoint(
720637
request_metadata,
721638
request_args,
722639
request_kwargs,
723640
status_code_callback=status_code_callback,
724641
):
725642
yield result
643+
elif request_metadata.is_streaming:
644+
async for result in self._user_callable_wrapper.call_user_generator(
645+
request_metadata,
646+
request_args,
647+
request_kwargs,
648+
):
649+
yield result
726650
else:
727651
yield await self._user_callable_wrapper.call_user_method(
728652
request_metadata, request_args, request_kwargs
@@ -1674,8 +1598,49 @@ async def _handle_user_method_result(
16741598

16751599
return result
16761600

1677-
@_run_user_code
16781601
async def call_http_entrypoint(
1602+
self,
1603+
request_metadata: RequestMetadata,
1604+
request_args: Tuple[Any],
1605+
request_kwargs: Dict[str, Any],
1606+
status_code_callback: StatusCodeCallback,
1607+
) -> Any:
1608+
result_queue = MessageQueue()
1609+
1610+
# `asyncio.Event`s are not thread safe, so `call_soon_threadsafe` must be
1611+
# used to interact with the result queue from the user callable thread.
1612+
system_event_loop = asyncio.get_running_loop()
1613+
1614+
def _enqueue_thread_safe(item: Any):
1615+
system_event_loop.call_soon_threadsafe(result_queue.put_nowait, item)
1616+
1617+
call_user_method_future = self._call_http_entrypoint(
1618+
request_metadata,
1619+
request_args,
1620+
request_kwargs,
1621+
generator_result_callback=_enqueue_thread_safe,
1622+
)
1623+
first_message_peeked = False
1624+
async for messages in result_queue.fetch_messages_from_queue(
1625+
call_user_method_future
1626+
):
1627+
# HTTP (ASGI) messages are only consumed by the proxy so batch them
1628+
# and use vanilla pickle (we know it's safe because these messages
1629+
# only contain primitive Python types).
1630+
# Peek the first ASGI message to determine the status code.
1631+
if not first_message_peeked:
1632+
msg = messages[0]
1633+
first_message_peeked = True
1634+
if msg["type"] == "http.response.start":
1635+
# HTTP responses begin with exactly one
1636+
# "http.response.start" message containing the "status"
1637+
# field. Other response types like WebSockets may not.
1638+
status_code_callback(str(msg["status"]))
1639+
1640+
yield pickle.dumps(messages)
1641+
1642+
@_run_user_code
1643+
async def _call_http_entrypoint(
16791644
self,
16801645
request_metadata: RequestMetadata,
16811646
request_args: Tuple[Any],
@@ -1691,7 +1656,7 @@ async def call_http_entrypoint(
16911656
Raises any exception raised by the user code so it can be propagated as a
16921657
`RayTaskError`.
16931658
"""
1694-
self._raise_if_not_initialized("call_http_entrypoint")
1659+
self._raise_if_not_initialized("_call_http_entrypoint")
16951660

16961661
logger.info(
16971662
f"Started executing request to method '{request_metadata.call_method}'.",
@@ -1764,12 +1729,45 @@ async def call_http_entrypoint(
17641729

17651730
raise
17661731

1767-
@_run_user_code
17681732
async def call_user_generator(
17691733
self,
17701734
request_metadata: RequestMetadata,
17711735
request_args: Tuple[Any],
17721736
request_kwargs: Dict[str, Any],
1737+
) -> AsyncGenerator[Any, None]:
1738+
"""Calls a user method for a streaming call and yields its results.
1739+
1740+
The user method is called in an asyncio `Task` and places its results on a
1741+
`result_queue`. This method pulls and yields from the `result_queue`.
1742+
"""
1743+
result_queue = MessageQueue()
1744+
1745+
# `asyncio.Event`s are not thread safe, so `call_soon_threadsafe` must be
1746+
# used to interact with the result queue from the user callable thread.
1747+
system_event_loop = asyncio.get_running_loop()
1748+
1749+
def _enqueue_thread_safe(item: Any):
1750+
system_event_loop.call_soon_threadsafe(result_queue.put_nowait, item)
1751+
1752+
call_user_method_future = self._call_user_generator(
1753+
request_metadata,
1754+
request_args,
1755+
request_kwargs,
1756+
generator_result_callback=_enqueue_thread_safe,
1757+
)
1758+
1759+
async for messages in result_queue.fetch_messages_from_queue(
1760+
call_user_method_future
1761+
):
1762+
for msg in messages:
1763+
yield msg
1764+
1765+
@_run_user_code
1766+
async def _call_user_generator(
1767+
self,
1768+
request_metadata: RequestMetadata,
1769+
request_args: Tuple[Any],
1770+
request_kwargs: Dict[str, Any],
17731771
*,
17741772
generator_result_callback: Optional[Callable] = None,
17751773
) -> Any:
@@ -1781,7 +1779,7 @@ async def call_user_generator(
17811779
Raises any exception raised by the user code so it can be propagated as a
17821780
`RayTaskError`.
17831781
"""
1784-
self._raise_if_not_initialized("call_user_generator")
1782+
self._raise_if_not_initialized("_call_user_generator")
17851783

17861784
logger.info(
17871785
f"Started executing request to method '{request_metadata.call_method}'.",

0 commit comments

Comments
 (0)