1111from typing import Any
1212
1313import orjson
14+ from asgi_tools import ResponseWebSocket
1415from asgiref import typing as asgi_types
1516from asgiref.compatibility import guarantee_single_callable
1617from servestatic import ServeStaticASGI
2627 AsgiHttpApp,
2728 AsgiLifespanApp,
2829 AsgiWebsocketApp,
30+ AsgiWebsocketReceive,
31+ AsgiWebsocketSend,
2932 Connection,
3033 Location,
3134 ReactPyConfig,
@@ -153,41 +156,56 @@ async def __call__(
153156 send: asgi_types.ASGISendCallable,
154157 ) -> None:
155158 """ASGI app for rendering ReactPy Python components."""
156- dispatcher: asyncio.Task[Any] | None = None
157- recv_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
158-
159159 # Start a loop that handles ASGI websocket events
160- while True:
161- event = await receive()
162- if event["type"] == "websocket.connect":
163- await send(
164- {"type": "websocket.accept", "subprotocol": None, "headers": []}
165- )
166- dispatcher = asyncio.create_task(
167- self.run_dispatcher(scope, receive, send, recv_queue)
168- )
169-
170- elif event["type"] == "websocket.disconnect":
171- if dispatcher:
172- dispatcher.cancel()
173- break
174-
175- elif event["type"] == "websocket.receive" and event["text"]:
176- queue_put_func = recv_queue.put(orjson.loads(event["text"]))
177- await queue_put_func
178-
179- async def run_dispatcher(
160+ async with ReactPyWebsocket(scope, receive, send, parent=self.parent) as ws: # type: ignore
161+ while True:
162+ # Wait for the webserver to notify us of a new event
163+ event: dict[str, Any] = await ws.receive(raw=True) # type: ignore
164+
165+ # If the event is a `receive` event, parse the message and send it to the rendering queue
166+ if event["type"] == "websocket.receive":
167+ msg: dict[str, str] = orjson.loads(event["text"])
168+ if msg.get("type") == "layout-event":
169+ await ws.rendering_queue.put(msg)
170+ else: # pragma: no cover
171+ await asyncio.to_thread(
172+ _logger.warning, f"Unknown message type: {msg.get('type')}"
173+ )
174+
175+ # If the event is a `disconnect` event, break the rendering loop and close the connection
176+ elif event["type"] == "websocket.disconnect":
177+ break
178+
179+
180+ class ReactPyWebsocket(ResponseWebSocket):
181+ def __init__(
180182 self,
181183 scope: asgi_types.WebSocketScope,
182- receive: asgi_types.ASGIReceiveCallable ,
183- send: asgi_types.ASGISendCallable ,
184- recv_queue: asyncio.Queue[dict[str, Any]] ,
184+ receive: AsgiWebsocketReceive ,
185+ send: AsgiWebsocketSend ,
186+ parent: ReactPyMiddleware ,
185187 ) -> None:
186- """Asyncio background task that renders and transmits layout updates of ReactPy components."""
188+ super().__init__(scope=scope, receive=receive, send=send) # type: ignore
189+ self.scope = scope
190+ self.parent = parent
191+ self.rendering_queue: asyncio.Queue[dict[str, str]] = asyncio.Queue()
192+ self.dispatcher: asyncio.Task[Any] | None = None
193+
194+ async def __aenter__(self) -> ReactPyWebsocket:
195+ self.dispatcher = asyncio.create_task(self.run_dispatcher())
196+ return await super().__aenter__() # type: ignore
197+
198+ async def __aexit__(self, *_: Any) -> None:
199+ if self.dispatcher:
200+ self.dispatcher.cancel()
201+ await super().__aexit__() # type: ignore
202+
203+ async def run_dispatcher(self) -> None:
204+ """Async background task that renders ReactPy components over a websocket."""
187205 try:
188206 # Determine component to serve by analyzing the URL and/or class parameters.
189207 if self.parent.multiple_root_components:
190- url_match = re.match(self.parent.dispatcher_pattern, scope["path"])
208+ url_match = re.match(self.parent.dispatcher_pattern, self. scope["path"])
191209 if not url_match: # pragma: no cover
192210 raise RuntimeError("Could not find component in URL path.")
193211 dotted_path = url_match["dotted_path"]
@@ -203,10 +221,10 @@ async def run_dispatcher(
203221
204222 # Create a connection object by analyzing the websocket's query string.
205223 ws_query_string = urllib.parse.parse_qs(
206- scope["query_string"].decode(), strict_parsing=True
224+ self. scope["query_string"].decode(), strict_parsing=True
207225 )
208226 connection = Connection(
209- scope=scope,
227+ scope=self. scope,
210228 location=Location(
211229 path=ws_query_string.get("http_pathname", [""])[0],
212230 query_string=ws_query_string.get("http_query_string", [""])[0],
@@ -217,20 +235,19 @@ async def run_dispatcher(
217235 # Start the ReactPy component rendering loop
218236 await serve_layout(
219237 Layout(ConnectionContext(component(), value=connection)),
220- lambda msg: send(
221- {
222- "type": "websocket.send",
223- "text": orjson.dumps(msg).decode(),
224- "bytes": None,
225- }
226- ),
227- recv_queue.get, # type: ignore
238+ self.send_json,
239+ self.rendering_queue.get, # type: ignore
228240 )
229241
230242 # Manually log exceptions since this function is running in a separate asyncio task.
231243 except Exception as error:
232244 await asyncio.to_thread(_logger.error, f"{error}\n{traceback.format_exc()}")
233245
246+ async def send_json(self, data: Any) -> None:
247+ return await self._send(
248+ {"type": "websocket.send", "text": orjson.dumps(data).decode()}
249+ )
250+
234251
235252@dataclass
236253class StaticFileApp:
0 commit comments