Skip to content
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from dataclasses import dataclass

from ..query_runner.query_runner import QueryRunner
from ..server_version.server_version import ServerVersion
from graphdatascience.query_runner.query_runner import QueryRunner
from graphdatascience.server_version.server_version import ServerVersion


@dataclass(frozen=True)
Expand Down
210 changes: 210 additions & 0 deletions graphdatascience/arrow_client/authenticated_arrow_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any, Optional, Union

from pyarrow import __version__ as arrow_version
from pyarrow import flight
from pyarrow._flight import (
Action,
FlightInternalError,
FlightStreamReader,
FlightTimedOutError,
FlightUnavailableError,
Ticket,
)
from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential

from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
from graphdatascience.arrow_client.arrow_info import ArrowInfo
from graphdatascience.retry_utils.retry_config import RetryConfig

from ..retry_utils.retry_utils import before_log
from ..version import __version__
Comment on lines +20 to +25
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

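NIT: this import block mixes absolute imports (`from graphdatascience...`) with relative ones (`from ..`); please stick to one style.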

from .middleware.AuthMiddleware import AuthFactory, AuthMiddleware
from .middleware.UserAgentMiddleware import UserAgentFactory


class AuthenticatedArrowClient:
    """Arrow Flight client for the GDS Arrow server with optional authentication.

    The instance owns a ``pyarrow.flight.FlightClient`` that is created once in
    ``__init__``. When authentication is configured, an ``AuthMiddleware`` and a
    user-agent middleware are attached to every call made through that client.
    """

    @staticmethod
    def _default_retry_config() -> RetryConfig:
        """Build the retry policy used when the caller does not supply one.

        Retries on Flight timeout, unavailable and internal errors, gives up after
        10 seconds or 5 attempts (whichever comes first) and backs off exponentially
        with waits clamped between 1 and 10 seconds.
        """
        return RetryConfig(
            retry=retry_any(
                retry_if_exception_type(FlightTimedOutError),
                retry_if_exception_type(FlightUnavailableError),
                retry_if_exception_type(FlightInternalError),
            ),
            stop=(stop_after_delay(10) | stop_after_attempt(5)),
            wait=wait_exponential(multiplier=1, min=1, max=10),
        )

    @staticmethod
    def create(
        arrow_info: ArrowInfo,
        auth: Optional[ArrowAuthentication] = None,
        encrypted: bool = False,
        disable_server_verification: bool = False,
        tls_root_certs: Optional[bytes] = None,
        connection_string_override: Optional[str] = None,
        retry_config: Optional[RetryConfig] = None,
    ) -> AuthenticatedArrowClient:
        """Creates a client from the Arrow server information.

        Parameters
        ----------
        arrow_info: ArrowInfo
            Server information; its ``listenAddress`` is used as ``host:port`` unless overridden
        auth: Optional[ArrowAuthentication]
            The authentication to use for the connection
        encrypted: bool
            A flag that indicates whether the connection should be encrypted (default is False)
        disable_server_verification: bool
            A flag that disables server verification for TLS connections (default is False)
        tls_root_certs: Optional[bytes]
            PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server
        connection_string_override: Optional[str]
            A ``host:port`` string that takes precedence over ``arrow_info.listenAddress``
        retry_config: Optional[RetryConfig]
            The retry configuration to use; a default policy is used when omitted

        Returns
        -------
        AuthenticatedArrowClient
            a client connected to the resolved host and port

        Raises
        ------
        ValueError
            if the connection string has no ``:port`` part or the port is not an integer
        """
        connection_string: str = (
            connection_string_override if connection_string_override is not None else arrow_info.listenAddress
        )

        # Split on the LAST colon so that hosts which themselves contain colons
        # (e.g. a bracketed IPv6 literal) are not torn apart.
        host, separator, port = connection_string.rpartition(":")
        if not separator:
            raise ValueError(f"Expected a connection string of the form `host:port`, got `{connection_string}`")

        return AuthenticatedArrowClient(
            host,
            retry_config if retry_config is not None else AuthenticatedArrowClient._default_retry_config(),
            int(port),
            auth,
            encrypted,
            disable_server_verification,
            tls_root_certs,
        )

    def __init__(
        self,
        host: str,
        retry_config: Optional[RetryConfig],
        port: int = 8491,
        auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] = None,
        encrypted: bool = False,
        disable_server_verification: bool = False,
        tls_root_certs: Optional[bytes] = None,
        user_agent: Optional[str] = None,
    ):
        """Creates a new AuthenticatedArrowClient instance.

        Parameters
        ----------
        host: str
            The host address of the GDS Arrow server
        retry_config: Optional[RetryConfig]
            The retry configuration to use for the Arrow requests sent by the client.
            When None, the default policy is used.
        port: int
            The host port of the GDS Arrow server (default is 8491)
        auth: Optional[Union[ArrowAuthentication, tuple[str, str]]]
            Either an implementation of ArrowAuthentication providing a pair to be used for basic authentication, or a username, password tuple
        encrypted: bool
            A flag that indicates whether the connection should be encrypted (default is False)
        disable_server_verification: bool
            A flag that disables server verification for TLS connections (default is False)
        tls_root_certs: Optional[bytes]
            PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server
        user_agent: Optional[str]
            The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION]`)
        """
        self._host = host
        self._port = port
        self._auth = None
        self._encrypted = encrypted
        self._disable_server_verification = disable_server_verification
        self._tls_root_certs = tls_root_certs
        self._user_agent = user_agent
        # Honour the caller's policy; previously it was always replaced by the default.
        self._retry_config = retry_config if retry_config is not None else self._default_retry_config()
        self._logger = logging.getLogger("gds_arrow_client")

        if auth:
            # NOTE(review): the signature also accepts a (username, password) tuple, but
            # both this class and AuthMiddleware call `auth.auth_pair()`, which a plain
            # tuple does not provide. Confirm whether tuples should be wrapped here.
            self._auth = auth
            self._auth_middleware = AuthMiddleware(auth)

        self._flight_client = self._instantiate_flight_client()

    def connection_info(self) -> ConnectionInfo:
        """
        Returns how this client reaches the GDS Arrow server.

        Returns
        -------
        ConnectionInfo
            the host, the port and whether the connection is encrypted
        """
        return ConnectionInfo(self._host, self._port, self._encrypted)

    def request_token(self) -> Optional[str]:
        """
        Authenticates against the server and returns the resulting token.

        Returns
        -------
        Optional[str]
            the token currently held by the auth middleware, or the literal
            "IGNORED" when the client was created without authentication
        """
        auth = self._auth
        if not auth:
            return "IGNORED"

        @retry(
            reraise=True,
            before=before_log("Request token", self._logger, logging.DEBUG),
            retry=self._retry_config.retry,
            stop=self._retry_config.stop,
            wait=self._retry_config.wait,
        )
        def auth_with_retry() -> None:
            username, password = auth.auth_pair()[0], auth.auth_pair()[1]
            self._flight_client.authenticate_basic_token(username, password)

        auth_with_retry()
        return self._auth_middleware.token()

    def get_stream(self, ticket: Ticket) -> FlightStreamReader:
        """Opens the stream identified by `ticket` via the Flight `do_get` call."""
        return self._flight_client.do_get(ticket)

    def do_action(self, endpoint: str, payload: bytes):
        """Runs the Flight action `endpoint` with `payload` once, without retrying."""
        return self._flight_client.do_action(Action(endpoint, payload))

    def do_action_with_retry(self, endpoint: str, payload: bytes):
        """Runs the Flight action `endpoint` with `payload` under the client's retry policy.

        The last exception is re-raised once the retry policy gives up.
        """

        @retry(
            reraise=True,
            before=before_log("Send action", self._logger, logging.DEBUG),
            retry=self._retry_config.retry,
            stop=self._retry_config.stop,
            wait=self._retry_config.wait,
        )
        def run_with_retry():
            return self.do_action(endpoint, payload)

        return run_with_retry()

    def _instantiate_flight_client(self) -> flight.FlightClient:
        """Builds the underlying Flight client from the stored connection settings."""
        if self._encrypted:
            location = flight.Location.for_grpc_tls(self._host, self._port)
        else:
            location = flight.Location.for_grpc_tcp(self._host, self._port)

        client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}

        # Middleware (auth and user agent) is only installed for authenticated clients.
        if self._auth:
            user_agent = self._user_agent or f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
            client_options["middleware"] = [
                AuthFactory(self._auth_middleware),
                UserAgentFactory(useragent=user_agent),
            ]

        if self._tls_root_certs:
            client_options["tls_root_certs"] = self._tls_root_certs

        return flight.FlightClient(location, **client_options)


@dataclass
class ConnectionInfo:
    """Mutable record describing how a client reaches the Arrow server."""

    # Host name or address of the server.
    host: str
    # Port of the server.
    port: int
    # Whether the connection is encrypted.
    encrypted: bool
55 changes: 55 additions & 0 deletions graphdatascience/arrow_client/data_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import dataclasses
import json
from dataclasses import fields
from typing import Any, Dict, Iterator, Type, TypeVar

from pyarrow._flight import Result


class DataMapper:
    """Turns JSON-encoded result rows into dataclass instances or plain dicts."""

    T = TypeVar("T")

    @staticmethod
    def deserialize_single(input_stream: "Iterator[Result]", cls: Type[T]) -> T:
        """Deserialize a stream that must contain exactly one row.

        Raises:
            ValueError: if the stream yields zero rows or more than one row.
        """
        rows = DataMapper.deserialize(input_stream, cls)

        if len(rows) != 1:
            raise ValueError(f"Expected exactly one row, got {len(rows)}")

        return rows[0]

    @staticmethod
    def deserialize(input_stream, cls: Type[T]) -> list[T]:
        """Decode every row of `input_stream` into `cls`.

        Each row must expose `row.body.to_pybytes()` returning UTF-8 encoded JSON.
        When `cls` is `dict`, `typing.Dict` or a parametrized form of either, the
        decoded JSON is returned unchanged; otherwise it is converted with
        `dict_to_dataclass`.

        Raises:
            ValueError: if `cls` is neither a dict type nor a dataclass.
        """
        import typing

        # `get_origin` maps Dict, Dict[K, V] and dict[K, V] to the builtin dict.
        wants_raw = cls is dict or typing.get_origin(cls) is dict

        results = []
        for row in input_stream:
            payload = json.loads(row.body.to_pybytes().decode())
            results.append(payload if wants_raw else DataMapper.dict_to_dataclass(payload, cls))
        return results

    @staticmethod
    def dict_to_dataclass(data: Dict[str, Any], cls: Type[T], strict: bool = False) -> T:
        """
        Convert a dictionary to a dataclass instance with nested dataclass support.

        Keys that do not match a field of `cls` are dropped, unless `strict` is
        set, in which case they are rejected. A dict value whose field is itself
        a dataclass is converted recursively with the same `strict` setting.

        Raises:
            ValueError: if `cls` is not a dataclass, or if `strict` is set and
                `data` contains a key that is not a field of `cls`.
        """
        if not dataclasses.is_dataclass(cls):
            raise ValueError(f"{cls} is not a dataclass")

        field_types = DataMapper._resolved_field_types(cls)
        filtered_data = {}

        for key, value in data.items():
            if key not in field_types:
                if strict:
                    raise ValueError(f"Extra field '{key}' not allowed in {cls.__name__}")
                continue

            field_type = field_types[key]
            if isinstance(value, dict) and dataclasses.is_dataclass(field_type):
                filtered_data[key] = DataMapper.dict_to_dataclass(value, field_type, strict)
            else:
                filtered_data[key] = value

        return cls(**filtered_data)

    @staticmethod
    def _resolved_field_types(cls: Type[Any]) -> Dict[str, Any]:
        """Map each field name of dataclass `cls` to its type, resolving string annotations."""
        import typing

        declared = {f.name: f.type for f in fields(cls)}

        # Annotations are strings when written as forward references or when the
        # defining module uses `from __future__ import annotations`; without
        # resolving them nested dataclasses would never be recognised.
        if not any(isinstance(declared_type, str) for declared_type in declared.values()):
            return declared

        try:
            hints = typing.get_type_hints(cls)
        except (NameError, TypeError, SyntaxError, AttributeError):
            # An unresolvable annotation must not break mapping of the other
            # fields: fall back to the raw annotations.
            return declared

        return {
            name: hints.get(name, declared_type) if isinstance(declared_type, str) else declared_type
            for name, declared_type in declared.items()
        }
63 changes: 63 additions & 0 deletions graphdatascience/arrow_client/middleware/AuthMiddleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

import base64
import time
from typing import Any, Optional

from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory

from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication


class AuthFactory(ClientMiddlewareFactory):  # type: ignore
    """Middleware factory that hands the same `AuthMiddleware` instance to every call."""

    def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None:
        """Remember `middleware`; remaining arguments are forwarded to the base factory."""
        super().__init__(*args, **kwargs)
        self._middleware = middleware

    def start_call(self, info: Any) -> AuthMiddleware:
        """Return the shared middleware; `info` is not inspected."""
        return self._middleware


class AuthMiddleware(ClientMiddleware):  # type: ignore
    """Client middleware that sends credentials and caches the bearer token it receives.

    Requests carry the cached bearer token while one is available; otherwise they
    carry HTTP basic credentials built from the configured authentication.
    """

    def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None:
        """Store `auth`; remaining arguments are forwarded to the base middleware."""
        super().__init__(*args, **kwargs)
        self._auth = auth
        self._token: Optional[str] = None
        self._token_timestamp = 0

    def token(self) -> Optional[str]:
        """Return the cached bearer token, or None if there is none or it has expired."""
        # check whether the token is older than 10 minutes. If so, reset it.
        if self._token and int(time.time()) - self._token_timestamp > 600:
            self._token = None

        return self._token

    def _set_token(self, token: str) -> None:
        """Cache `token` and record when it was received."""
        self._token = token
        self._token_timestamp = int(time.time())

    def received_headers(self, headers: dict[str, Any]) -> None:
        """Pick up a bearer token from the `authorization` response header, if present.

        Raises:
            ValueError: if the first `authorization` header value is not a string.
        """
        auth_header = headers.get("authorization", None)
        if not auth_header:
            return

        # the result is always a list
        header_value = auth_header[0]

        if not isinstance(header_value, str):
            raise ValueError(f"Incompatible header value received from server: `{header_value}`")

        # `partition` never fails, unlike unpacking `split(" ", 1)`, which raised an
        # opaque ValueError for a header value without a space.
        auth_type, separator, token = header_value.partition(" ")
        # Only a well-formed `Bearer <token>` value carries a usable token; other
        # schemes and empty tokens are ignored so the next request falls back to
        # basic credentials.
        if separator and auth_type == "Bearer" and token:
            self._set_token(token)

    def sending_headers(self) -> dict[str, str]:
        """Return the `authorization` header for the next request."""
        token = self.token()
        if token is not None:
            return {"authorization": "Bearer " + token}

        auth_pair = self._auth.auth_pair()
        auth_token = f"{auth_pair[0]}:{auth_pair[1]}"
        auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
        # There seems to be a bug, `authorization` must be lower key
        return {"authorization": auth_token}
26 changes: 26 additions & 0 deletions graphdatascience/arrow_client/middleware/UserAgentMiddleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from typing import Any

from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory


class UserAgentFactory(ClientMiddlewareFactory):  # type: ignore
    """Middleware factory that attaches one shared `UserAgentMiddleware` to every call."""

    def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
        """Build the middleware for `useragent`; remaining arguments go to the base factory."""
        super().__init__(*args, **kwargs)
        shared_middleware = UserAgentMiddleware(useragent)
        self._middleware = shared_middleware

    def start_call(self, info: Any) -> ClientMiddleware:
        """Return the shared middleware; `info` is not inspected."""
        return self._middleware


class UserAgentMiddleware(ClientMiddleware):  # type: ignore
    """Client middleware that sends a fixed user agent in the `x-gds-user-agent` header."""

    def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
        """Store `useragent`; remaining arguments are forwarded to the base middleware."""
        super().__init__(*args, **kwargs)
        self._useragent = useragent

    def received_headers(self, headers: dict[str, Any]) -> None:
        """Ignore response headers; this middleware only adds a request header."""
        return None

    def sending_headers(self) -> dict[str, str]:
        """Return the user-agent header for the next request."""
        outgoing = {"x-gds-user-agent": self._useragent}
        return outgoing
19 changes: 19 additions & 0 deletions graphdatascience/arrow_client/v2/api_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class JobIdConfig:
    """Immutable record holding a single job identifier."""

    # NOTE(review): camelCase presumably mirrors the key of the decoded payload - confirm.
    jobId: str


@dataclass(frozen=True)
class JobStatus:
    """Immutable record holding a job's identifier, status and progress."""

    jobId: str
    # Status as a free-form string; the set of valid values is not defined here.
    status: str
    # NOTE(review): range/unit of progress (fraction vs. percent) is not visible here - confirm.
    progress: float


@dataclass(frozen=True)
class MutateResult:
    """Immutable record holding the counters of a mutate operation."""

    # Number of node properties written.
    nodePropertiesWritten: int
    # Number of relationships written.
    relationshipsWritten: int
relationshipsWritten: int
Loading