diff --git a/newsfragments/3749.internal.rst b/newsfragments/3749.internal.rst new file mode 100644 index 0000000000..2c4a8c63ed --- /dev/null +++ b/newsfragments/3749.internal.rst @@ -0,0 +1 @@ +Remove websockets deprecation warning by using the asyncio websocket provider diff --git a/tests/utils.py b/tests/utils.py index 4779d33658..827857d70c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,8 +4,6 @@ from websockets import ( WebSocketException, -) -from websockets.legacy.client import ( connect, ) diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index b1f6314fa1..fa393b036d 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -54,10 +54,6 @@ ) if TYPE_CHECKING: - from websockets.legacy.client import ( - WebSocketClientProtocol, - ) - from web3 import ( # noqa: F401 AsyncWeb3, WebSocketProvider, @@ -174,9 +170,6 @@ async def disconnect(self) -> None: "Persistent connection providers must implement this method" ) - # WebSocket typing - _ws: "WebSocketClientProtocol" - # IPC typing _reader: Optional[asyncio.StreamReader] _writer: Optional[asyncio.StreamWriter] diff --git a/web3/providers/legacy_websocket.py b/web3/providers/legacy_websocket.py index a4907c3326..0744124bc7 100644 --- a/web3/providers/legacy_websocket.py +++ b/web3/providers/legacy_websocket.py @@ -1,3 +1,7 @@ +from __future__ import ( + annotations, +) + import asyncio import json import logging @@ -9,22 +13,15 @@ TracebackType, ) from typing import ( + TYPE_CHECKING, Any, List, - Optional, - Tuple, - Type, - Union, cast, ) from eth_typing import ( URI, ) -from websockets.legacy.client import ( - WebSocketClientProtocol, - connect, -) from web3._utils.batching import ( sort_batch_response_by_response_ids, @@ -43,6 +40,11 @@ RPCResponse, ) +if TYPE_CHECKING: + from websockets.legacy.client import ( + WebSocketClientProtocol, + ) + RESTRICTED_WEBSOCKET_KWARGS = {"uri", "loop"} DEFAULT_WEBSOCKET_TIMEOUT = 30 @@ -66,18 +68,22 @@ def get_default_endpoint() -> URI: class PersistentWebSocket: def __init__(self, endpoint_uri: URI, websocket_kwargs: Any) -> None: - self.ws: Optional[WebSocketClientProtocol] = None + self.ws: WebSocketClientProtocol | None = None self.endpoint_uri = endpoint_uri self.websocket_kwargs = websocket_kwargs async def __aenter__(self) -> WebSocketClientProtocol: if self.ws is None: + from websockets.legacy.client import ( + connect, + ) + self.ws = await connect(uri=self.endpoint_uri, **self.websocket_kwargs) return self.ws async def __aexit__( self, - exc_type: Type[BaseException], + exc_type: type[BaseException], exc_val: BaseException, exc_tb: TracebackType, ) -> None: @@ -95,8 +101,8 @@ class LegacyWebSocketProvider(JSONBaseProvider): def __init__( self, - endpoint_uri: Optional[Union[URI, str]] = None, - websocket_kwargs: Optional[Any] = None, + endpoint_uri: URI | str | None = None, + websocket_kwargs: Any | None = None, websocket_timeout: int = DEFAULT_WEBSOCKET_TIMEOUT, **kwargs: Any, ) -> None: @@ -144,8 +150,8 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: return future.result() def make_batch_request( - self, requests: List[Tuple[RPCEndpoint, Any]] - ) -> List[RPCResponse]: + self, requests: list[tuple[RPCEndpoint, Any]] + ) -> list[RPCResponse]: self.logger.debug( "Making batch request WebSocket. URI: %s, Methods: %s", self.endpoint_uri, diff --git a/web3/providers/persistent/websocket.py b/web3/providers/persistent/websocket.py index f5035adf76..c05560a101 100644 --- a/web3/providers/persistent/websocket.py +++ b/web3/providers/persistent/websocket.py @@ -1,12 +1,13 @@ +from __future__ import ( + annotations, +) + import asyncio import json import logging import os from typing import ( Any, - Dict, - Optional, - Union, ) from eth_typing import ( @@ -15,14 +16,21 @@ from toolz import ( merge, ) + +# python3.8 supports up to version 13, +# which does not default to the asyncio implementation yet. +# For this reason connect and ClientConnection need to be imported +# from asyncio.client explicitly. +# When web3.py stops supporting python3.8, +# it'll be possible to use `from websockets import connect, ClientConnection`. +from websockets.asyncio.client import ( + ClientConnection, + connect, +) from websockets.exceptions import ( ConnectionClosedOK, WebSocketException, ) -from websockets.legacy.client import ( - WebSocketClientProtocol, - connect, -) from web3.exceptions import ( PersistentConnectionClosedOK, @@ -57,12 +65,14 @@ class WebSocketProvider(PersistentConnectionProvider): logger = logging.getLogger("web3.providers.WebSocketProvider") is_async: bool = True + _ws: ClientConnection + def __init__( self, - endpoint_uri: Optional[Union[URI, str]] = None, - websocket_kwargs: Optional[Dict[str, Any]] = None, + endpoint_uri: URI | str | None = None, + websocket_kwargs: dict[str, Any] | None = None, # uses binary frames by default - use_text_frames: Optional[bool] = False, + use_text_frames: bool | None = False, # `PersistentConnectionProvider` kwargs can be passed through **kwargs: Any, ) -> None: @@ -72,7 +82,7 @@ def __init__( ) super().__init__(**kwargs) self.use_text_frames = use_text_frames - self._ws: Optional[WebSocketClientProtocol] = None + self._ws: ClientConnection | None = None if not any( self.endpoint_uri.startswith(prefix) @@ -119,7 +129,7 @@ async def socket_send(self, request_data: bytes) -> None: "Connection to websocket has not been initiated for the provider." ) - payload: Union[bytes, str] = request_data + payload: bytes | str = request_data if self.use_text_frames: payload = request_data.decode("utf-8") @@ -136,7 +146,7 @@ async def _provider_specific_connect(self) -> None: async def _provider_specific_disconnect(self) -> None: # this should remain idempotent - if self._ws is not None and not self._ws.closed: + if self._ws is not None: await self._ws.close() self._ws = None