Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3749.internal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove websockets deprecation warning by using the asyncio websocket provider
2 changes: 0 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from websockets import (
WebSocketException,
)
from websockets.legacy.client import (
connect,
)

Expand Down
7 changes: 0 additions & 7 deletions web3/providers/async_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@
)

if TYPE_CHECKING:
from websockets.legacy.client import (
WebSocketClientProtocol,
)

from web3 import ( # noqa: F401
AsyncWeb3,
WebSocketProvider,
Expand Down Expand Up @@ -174,9 +170,6 @@ async def disconnect(self) -> None:
"Persistent connection providers must implement this method"
)

# WebSocket typing
_ws: "WebSocketClientProtocol"

# IPC typing
_reader: Optional[asyncio.StreamReader]
_writer: Optional[asyncio.StreamWriter]
Expand Down
34 changes: 20 additions & 14 deletions web3/providers/legacy_websocket.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import (
annotations,
)

import asyncio
import json
import logging
Expand All @@ -8,23 +12,16 @@
from types import (
TracebackType,
)
import typing
from typing import (
Any,
List,
Optional,
Tuple,
Type,
Union,
cast,
)

from eth_typing import (
URI,
)
from websockets.legacy.client import (
WebSocketClientProtocol,
connect,
)

from web3._utils.batching import (
sort_batch_response_by_response_ids,
Expand All @@ -43,6 +40,11 @@
RPCResponse,
)

if typing.TYPE_CHECKING:
from websockets.legacy.client import (
WebSocketClientProtocol,
)

RESTRICTED_WEBSOCKET_KWARGS = {"uri", "loop"}
DEFAULT_WEBSOCKET_TIMEOUT = 30

Expand All @@ -66,18 +68,22 @@ def get_default_endpoint() -> URI:

class PersistentWebSocket:
def __init__(self, endpoint_uri: URI, websocket_kwargs: Any) -> None:
self.ws: Optional[WebSocketClientProtocol] = None
self.ws: WebSocketClientProtocol | None = None
self.endpoint_uri = endpoint_uri
self.websocket_kwargs = websocket_kwargs

async def __aenter__(self) -> WebSocketClientProtocol:
if self.ws is None:
from websockets.legacy.client import (
connect,
)

self.ws = await connect(uri=self.endpoint_uri, **self.websocket_kwargs)
return self.ws

async def __aexit__(
self,
exc_type: Type[BaseException],
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
) -> None:
Expand All @@ -95,8 +101,8 @@ class LegacyWebSocketProvider(JSONBaseProvider):

def __init__(
self,
endpoint_uri: Optional[Union[URI, str]] = None,
websocket_kwargs: Optional[Any] = None,
endpoint_uri: URI | str | None = None,
websocket_kwargs: Any | None = None,
websocket_timeout: int = DEFAULT_WEBSOCKET_TIMEOUT,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -144,8 +150,8 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
return future.result()

def make_batch_request(
self, requests: List[Tuple[RPCEndpoint, Any]]
) -> List[RPCResponse]:
self, requests: list[tuple[RPCEndpoint, Any]]
) -> list[RPCResponse]:
self.logger.debug(
"Making batch request WebSocket. URI: %s, Methods: %s",
self.endpoint_uri,
Expand Down
36 changes: 23 additions & 13 deletions web3/providers/persistent/websocket.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import (
annotations,
)

import asyncio
import json
import logging
import os
from typing import (
Any,
Dict,
Optional,
Union,
)

from eth_typing import (
Expand All @@ -15,14 +16,21 @@
from toolz import (
merge,
)

# python3.8 supports up to version 13,
# which does not default to the asyncio implementation yet.
# For this reason connect and ClientConnection need to be imported
# from asyncio.client explicitly.
# When web3.py stops supporting python3.8,
# it'll be possible to use `from websockets import connect, ClientConnection`.
from websockets.asyncio.client import (
ClientConnection,
connect,
)
from websockets.exceptions import (
ConnectionClosedOK,
WebSocketException,
)
from websockets.legacy.client import (
WebSocketClientProtocol,
connect,
)

from web3.exceptions import (
PersistentConnectionClosedOK,
Expand Down Expand Up @@ -57,12 +65,14 @@ class WebSocketProvider(PersistentConnectionProvider):
logger = logging.getLogger("web3.providers.WebSocketProvider")
is_async: bool = True

_ws: ClientConnection

def __init__(
self,
endpoint_uri: Optional[Union[URI, str]] = None,
websocket_kwargs: Optional[Dict[str, Any]] = None,
endpoint_uri: URI | str | None = None,
websocket_kwargs: dict[str, Any] | None = None,
# uses binary frames by default
use_text_frames: Optional[bool] = False,
use_text_frames: bool | None = False,
# `PersistentConnectionProvider` kwargs can be passed through
**kwargs: Any,
) -> None:
Expand All @@ -72,7 +82,7 @@ def __init__(
)
super().__init__(**kwargs)
self.use_text_frames = use_text_frames
self._ws: Optional[WebSocketClientProtocol] = None
self._ws: ClientConnection | None = None

if not any(
self.endpoint_uri.startswith(prefix)
Expand Down Expand Up @@ -119,7 +129,7 @@ async def socket_send(self, request_data: bytes) -> None:
"Connection to websocket has not been initiated for the provider."
)

payload: Union[bytes, str] = request_data
payload: bytes | str = request_data
if self.use_text_frames:
payload = request_data.decode("utf-8")

Expand All @@ -136,7 +146,7 @@ async def _provider_specific_connect(self) -> None:

async def _provider_specific_disconnect(self) -> None:
# this should remain idempotent
if self._ws is not None and not self._ws.closed:
if self._ws is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we still need to figure out how to check for self._ws.closed, even if it's not that same method name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, I did check this and I found out that self._ws.close() is idempotent already, what I didn't realise though is that web3.py still supports websockets >=10.0,<13.0, meaning me importing directly from websockets.asyncio would break that compatibility.
So I would need to rework my PR to take that into account.

I also noticed this issue #3679 which plans to move websockets bottom pin to >=14 #3530, which would essentially remove this problem completely.
Should I just do that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could either go down the path of supporting all the variations (aka <13; >=13,<14; >=14), i.e.

import websockets
websockets_version = tuple(int(x) for x in websockets.__version__.split(".") if x.isdigit())

if websockets_version < (13, 0):
    from websockets.legacy.client import (
        WebSocketClientProtocol as ClientConnection,  # we are safe to steal the name as the scope of interface we use is the same
        connect,
    )
else:
    # python3.8 supports up to version 13,
    # which does not default to the asyncio implementation yet.
    # For this reason connect and ClientConnection need to be imported
    # from asyncio.client explicitly.
    # When web3.py stops supporting python3.8,
    # it'll be possible to use `from websockets import connect, ClientConnection`.
    from websockets.asyncio.client import (
        ClientConnection,
        connect,
    )

or simply drop support for at least version <13.

Given the amount of usage of the websockets library interface either solution is fine, not a big deal.

await self._ws.close()
self._ws = None

Expand Down