-
Notifications
You must be signed in to change notification settings - Fork 54
V2 Arrow Clients #918
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
V2 Arrow Clients #918
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
5bd0418
Introduce AuthenticatedArrowClient
DarthMax adb8dbb
Introduce JobClient
DarthMax 5f1b3cd
Add MutationClient
DarthMax d16947f
Add MutationClient
DarthMax 1a9ecbb
Add WriteBackClient
DarthMax 5b7e29f
Do not ignore retry_config parameter in AuthenticatedFlightClient
DarthMax File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
|
||
|
||
class ArrowAuthentication(ABC): | ||
@abstractmethod | ||
def auth_pair(self) -> tuple[str, str]: | ||
"""Returns the auth pair used for authentication.""" | ||
pass | ||
|
||
|
||
@dataclass | ||
class UsernamePasswordAuthentication(ArrowAuthentication): | ||
username: str | ||
password: str | ||
|
||
def auth_pair(self) -> tuple[str, str]: | ||
return self.username, self.password |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from typing import Any | ||
|
||
from pydantic import BaseModel, ConfigDict | ||
from pydantic.alias_generators import to_camel | ||
|
||
|
||
class ArrowBaseModel(BaseModel): | ||
model_config = ConfigDict(alias_generator=to_camel) | ||
|
||
def dump_camel(self) -> dict[str, Any]: | ||
return self.model_dump(by_alias=True) | ||
|
||
def dump_json(self) -> str: | ||
return self.model_dump_json(by_alias=True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
|
||
from ..query_runner.query_runner import QueryRunner | ||
from ..server_version.server_version import ServerVersion | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ArrowInfo: | ||
listenAddress: str | ||
enabled: bool | ||
running: bool | ||
versions: list[str] | ||
|
||
@staticmethod | ||
def create(query_runner: QueryRunner) -> ArrowInfo: | ||
debugYields = ["listenAddress", "enabled", "running"] | ||
if query_runner.server_version() > ServerVersion(2, 6, 0): | ||
debugYields.append("versions") | ||
|
||
procResult = query_runner.call_procedure( | ||
endpoint="gds.debug.arrow", custom_error=False, yields=debugYields | ||
).iloc[0] | ||
|
||
return ArrowInfo( | ||
listenAddress=procResult["listenAddress"], | ||
enabled=procResult["enabled"], | ||
running=procResult["running"], | ||
versions=procResult.get("versions", []), | ||
) |
195 changes: 195 additions & 0 deletions
195
graphdatascience/arrow_client/authenticated_flight_client.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from dataclasses import dataclass | ||
from typing import Any, Iterator, Optional | ||
|
||
from pyarrow import __version__ as arrow_version | ||
from pyarrow import flight | ||
from pyarrow._flight import ( | ||
Action, | ||
FlightInternalError, | ||
FlightStreamReader, | ||
FlightTimedOutError, | ||
FlightUnavailableError, | ||
Result, | ||
Ticket, | ||
) | ||
from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential | ||
|
||
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication | ||
from graphdatascience.arrow_client.arrow_info import ArrowInfo | ||
from graphdatascience.retry_utils.retry_config import RetryConfig | ||
|
||
from ..retry_utils.retry_utils import before_log | ||
from ..version import __version__ | ||
from .middleware.auth_middleware import AuthFactory, AuthMiddleware | ||
from .middleware.user_agent_middleware import UserAgentFactory | ||
|
||
|
||
class AuthenticatedArrowClient: | ||
@staticmethod | ||
def create( | ||
arrow_info: ArrowInfo, | ||
auth: Optional[ArrowAuthentication] = None, | ||
encrypted: bool = False, | ||
arrow_client_options: Optional[dict[str, Any]] = None, | ||
connection_string_override: Optional[str] = None, | ||
retry_config: Optional[RetryConfig] = None, | ||
) -> AuthenticatedArrowClient: | ||
connection_string: str | ||
if connection_string_override is not None: | ||
connection_string = connection_string_override | ||
else: | ||
connection_string = arrow_info.listenAddress | ||
|
||
host, port = connection_string.split(":") | ||
|
||
if retry_config is None: | ||
retry_config = RetryConfig( | ||
retry=retry_any( | ||
retry_if_exception_type(FlightTimedOutError), | ||
retry_if_exception_type(FlightUnavailableError), | ||
retry_if_exception_type(FlightInternalError), | ||
), | ||
stop=(stop_after_delay(10) | stop_after_attempt(5)), | ||
wait=wait_exponential(multiplier=1, min=1, max=10), | ||
) | ||
|
||
return AuthenticatedArrowClient( | ||
host=host, | ||
retry_config=retry_config, | ||
port=int(port), | ||
auth=auth, | ||
encrypted=encrypted, | ||
arrow_client_options=arrow_client_options, | ||
) | ||
|
||
def __init__( | ||
self, | ||
host: str, | ||
retry_config: RetryConfig, | ||
port: int = 8491, | ||
auth: Optional[ArrowAuthentication] = None, | ||
encrypted: bool = False, | ||
arrow_client_options: Optional[dict[str, Any]] = None, | ||
user_agent: Optional[str] = None, | ||
): | ||
"""Creates a new GdsArrowClient instance. | ||
|
||
Parameters | ||
---------- | ||
host: str | ||
The host address of the GDS Arrow server | ||
port: int | ||
The host port of the GDS Arrow server (default is 8491) | ||
auth: Optional[ArrowAuthentication] | ||
An implementation of ArrowAuthentication providing a pair to be used for basic authentication | ||
encrypted: bool | ||
A flag that indicates whether the connection should be encrypted (default is False) | ||
arrow_client_options: Optional[dict[str, Any]] | ||
Additional options to be passed to the Arrow Flight client. | ||
user_agent: Optional[str] | ||
The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION]) | ||
retry_config: Optional[RetryConfig] | ||
The retry configuration to use for the Arrow requests send by the client. | ||
""" | ||
self._host = host | ||
self._port = port | ||
self._auth = None | ||
self._encrypted = encrypted | ||
self._arrow_client_options = arrow_client_options | ||
self._user_agent = user_agent | ||
self._retry_config = retry_config | ||
self._logger = logging.getLogger("gds_arrow_client") | ||
self._retry_config = retry_config | ||
if auth: | ||
self._auth = auth | ||
self._auth_middleware = AuthMiddleware(auth) | ||
|
||
self._flight_client = self._instantiate_flight_client() | ||
|
||
def connection_info(self) -> ConnectionInfo: | ||
""" | ||
Returns the host and port of the GDS Arrow server. | ||
|
||
Returns | ||
------- | ||
tuple[str, int] | ||
the host and port of the GDS Arrow server | ||
""" | ||
return ConnectionInfo(self._host, self._port, self._encrypted) | ||
|
||
def request_token(self) -> Optional[str]: | ||
""" | ||
Requests a token from the server and returns it. | ||
|
||
Returns | ||
------- | ||
Optional[str] | ||
a token from the server and returns it. | ||
""" | ||
|
||
@retry( | ||
reraise=True, | ||
before=before_log("Request token", self._logger, logging.DEBUG), | ||
retry=self._retry_config.retry, | ||
stop=self._retry_config.stop, | ||
wait=self._retry_config.wait, | ||
) | ||
def auth_with_retry() -> None: | ||
client = self._flight_client | ||
if self._auth: | ||
auth_pair = self._auth.auth_pair() | ||
client.authenticate_basic_token(auth_pair[0], auth_pair[1]) | ||
|
||
if self._auth: | ||
auth_with_retry() | ||
return self._auth_middleware.token() | ||
else: | ||
return "IGNORED" | ||
|
||
def get_stream(self, ticket: Ticket) -> FlightStreamReader: | ||
return self._flight_client.do_get(ticket) | ||
|
||
def do_action(self, endpoint: str, payload: bytes) -> Iterator[Result]: | ||
return self._flight_client.do_action(Action(endpoint, payload)) # type: ignore | ||
|
||
def do_action_with_retry(self, endpoint: str, payload: bytes) -> Iterator[Result]: | ||
@retry( | ||
reraise=True, | ||
before=before_log("Send action", self._logger, logging.DEBUG), | ||
retry=self._retry_config.retry, | ||
stop=self._retry_config.stop, | ||
wait=self._retry_config.wait, | ||
) | ||
def run_with_retry() -> Iterator[Result]: | ||
return self.do_action(endpoint, payload) | ||
|
||
return run_with_retry() | ||
|
||
def _instantiate_flight_client(self) -> flight.FlightClient: | ||
location = ( | ||
flight.Location.for_grpc_tls(self._host, self._port) | ||
if self._encrypted | ||
else flight.Location.for_grpc_tcp(self._host, self._port) | ||
) | ||
client_options: dict[str, Any] = (self._arrow_client_options or {}).copy() | ||
if self._auth: | ||
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" | ||
if self._user_agent: | ||
user_agent = self._user_agent | ||
|
||
client_options["middleware"] = [ | ||
AuthFactory(self._auth_middleware), | ||
UserAgentFactory(useragent=user_agent), | ||
] | ||
|
||
return flight.FlightClient(location, **client_options) | ||
|
||
|
||
@dataclass | ||
class ConnectionInfo: | ||
host: str | ||
port: int | ||
encrypted: bool |
63 changes: 63 additions & 0 deletions
63
graphdatascience/arrow_client/middleware/auth_middleware.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from __future__ import annotations | ||
|
||
import base64 | ||
import time | ||
from typing import Any, Optional | ||
|
||
from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory | ||
|
||
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication | ||
|
||
|
||
class AuthFactory(ClientMiddlewareFactory): # type: ignore | ||
def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None: | ||
super().__init__(*args, **kwargs) | ||
self._middleware = middleware | ||
|
||
def start_call(self, info: Any) -> AuthMiddleware: | ||
return self._middleware | ||
|
||
|
||
class AuthMiddleware(ClientMiddleware): # type: ignore | ||
def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None: | ||
super().__init__(*args, **kwargs) | ||
self._auth = auth | ||
self._token: Optional[str] = None | ||
self._token_timestamp = 0 | ||
|
||
def token(self) -> Optional[str]: | ||
# check whether the token is older than 10 minutes. If so, reset it. | ||
if self._token and int(time.time()) - self._token_timestamp > 600: | ||
self._token = None | ||
|
||
return self._token | ||
|
||
def _set_token(self, token: str) -> None: | ||
self._token = token | ||
self._token_timestamp = int(time.time()) | ||
|
||
def received_headers(self, headers: dict[str, Any]) -> None: | ||
auth_header = headers.get("authorization", None) | ||
if not auth_header: | ||
return | ||
|
||
# the result is always a list | ||
header_value = auth_header[0] | ||
|
||
if not isinstance(header_value, str): | ||
raise ValueError(f"Incompatible header value received from server: `{header_value}`") | ||
|
||
auth_type, token = header_value.split(" ", 1) | ||
if auth_type == "Bearer": | ||
self._set_token(token) | ||
|
||
def sending_headers(self) -> dict[str, str]: | ||
token = self.token() | ||
if token is not None: | ||
return {"authorization": "Bearer " + token} | ||
|
||
auth_pair = self._auth.auth_pair() | ||
auth_token = f"{auth_pair[0]}:{auth_pair[1]}" | ||
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII") | ||
# There seems to be a bug, `authorization` must be lower key | ||
return {"authorization": auth_token} |
26 changes: 26 additions & 0 deletions
26
graphdatascience/arrow_client/middleware/user_agent_middleware.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory | ||
|
||
|
||
class UserAgentFactory(ClientMiddlewareFactory): # type: ignore | ||
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: | ||
super().__init__(*args, **kwargs) | ||
self._middleware = UserAgentMiddleware(useragent) | ||
|
||
def start_call(self, info: Any) -> ClientMiddleware: | ||
return self._middleware | ||
|
||
|
||
class UserAgentMiddleware(ClientMiddleware): # type: ignore | ||
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: | ||
super().__init__(*args, **kwargs) | ||
self._useragent = useragent | ||
|
||
def sending_headers(self) -> dict[str, str]: | ||
return {"x-gds-user-agent": self._useragent} | ||
|
||
def received_headers(self, headers: dict[str, Any]) -> None: | ||
pass |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from graphdatascience.arrow_client.arrow_base_model import ArrowBaseModel | ||
|
||
|
||
class JobIdConfig(ArrowBaseModel): | ||
job_id: str | ||
|
||
|
||
class JobStatus(ArrowBaseModel): | ||
job_id: str | ||
status: str | ||
progress: float | ||
|
||
|
||
class MutateResult(ArrowBaseModel): | ||
node_properties_written: int | ||
relationships_written: int |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import json | ||
from typing import Any, Iterator | ||
|
||
from pyarrow._flight import Result | ||
|
||
|
||
def deserialize_single(input_stream: Iterator[Result]) -> dict[str, Any]: | ||
rows = deserialize(input_stream) | ||
if len(rows) != 1: | ||
raise ValueError(f"Expected exactly one result, got {len(rows)}") | ||
|
||
return rows[0] | ||
|
||
|
||
def deserialize(input_stream: Iterator[Result]) -> list[dict[str, Any]]: | ||
def deserialize_row(row: Result): # type:ignore | ||
return json.loads(row.body.to_pybytes().decode()) | ||
|
||
return [deserialize_row(row) for row in list(input_stream)] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.