Skip to content

Create explicit WCC endpoints #859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from dataclasses import dataclass

from ..query_runner.query_runner import QueryRunner
from ..server_version.server_version import ServerVersion
from graphdatascience.query_runner.query_runner import QueryRunner
from graphdatascience.server_version.server_version import ServerVersion


@dataclass(frozen=True)
Expand Down
211 changes: 211 additions & 0 deletions graphdatascience/arrow_client/authenticated_arrow_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any, Iterator, Optional

from pyarrow import __version__ as arrow_version
from pyarrow import flight
from pyarrow._flight import (
Action,
FlightInternalError,
FlightStreamReader,
FlightTimedOutError,
FlightUnavailableError,
Result,
Ticket,
)
from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential

from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
from graphdatascience.arrow_client.arrow_info import ArrowInfo
from graphdatascience.retry_utils.retry_config import RetryConfig

from ..retry_utils.retry_utils import before_log
from ..version import __version__
Comment on lines +20 to +25
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: mixing import styles

from .middleware.AuthMiddleware import AuthFactory, AuthMiddleware
from .middleware.UserAgentMiddleware import UserAgentFactory


class AuthenticatedArrowClient:
@staticmethod
def create(
arrow_info: ArrowInfo,
auth: Optional[ArrowAuthentication] = None,
encrypted: bool = False,
disable_server_verification: bool = False,
tls_root_certs: Optional[bytes] = None,
connection_string_override: Optional[str] = None,
retry_config: Optional[RetryConfig] = None,
) -> AuthenticatedArrowClient:
connection_string: str
if connection_string_override is not None:
connection_string = connection_string_override
else:
connection_string = arrow_info.listenAddress

host, port = connection_string.split(":")

if retry_config is None:
retry_config = RetryConfig(
retry=retry_any(
retry_if_exception_type(FlightTimedOutError),
retry_if_exception_type(FlightUnavailableError),
retry_if_exception_type(FlightInternalError),
),
stop=(stop_after_delay(10) | stop_after_attempt(5)),
wait=wait_exponential(multiplier=1, min=1, max=10),
)

return AuthenticatedArrowClient(
host,
retry_config,
int(port),
auth,
encrypted,
disable_server_verification,
tls_root_certs,
)

def __init__(
self,
host: str,
retry_config: RetryConfig,
port: int = 8491,
auth: Optional[ArrowAuthentication] = None,
encrypted: bool = False,
disable_server_verification: bool = False,
tls_root_certs: Optional[bytes] = None,
user_agent: Optional[str] = None,
):
"""Creates a new GdsArrowClient instance.

Parameters
----------
host: str
The host address of the GDS Arrow server
port: int
The host port of the GDS Arrow server (default is 8491)
auth: Optional[ArrowAuthentication]
Either an implementation of ArrowAuthentication providing a pair to be used for basic authentication, or a username, password tuple
encrypted: bool
A flag that indicates whether the connection should be encrypted (default is False)
disable_server_verification: bool
A flag that disables server verification for TLS connections (default is False)
tls_root_certs: Optional[bytes]
PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server
user_agent: Optional[str]
The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION])
retry_config: Optional[RetryConfig]
The retry configuration to use for the Arrow requests send by the client.
"""
self._host = host
self._port = port
self._auth = None
self._encrypted = encrypted
self._disable_server_verification = disable_server_verification
self._tls_root_certs = tls_root_certs
self._user_agent = user_agent
self._retry_config = retry_config
self._logger = logging.getLogger("gds_arrow_client")
self._retry_config = RetryConfig(
retry=retry_any(
retry_if_exception_type(FlightTimedOutError),
retry_if_exception_type(FlightUnavailableError),
retry_if_exception_type(FlightInternalError),
),
stop=(stop_after_delay(10) | stop_after_attempt(5)),
wait=wait_exponential(multiplier=1, min=1, max=10),
)

if auth:
self._auth = auth
self._auth_middleware = AuthMiddleware(auth)

self._flight_client = self._instantiate_flight_client()

def connection_info(self) -> ConnectionInfo:
"""
Returns the host and port of the GDS Arrow server.

Returns
-------
tuple[str, int]
the host and port of the GDS Arrow server
"""
return ConnectionInfo(self._host, self._port, self._encrypted)

def request_token(self) -> Optional[str]:
"""
Requests a token from the server and returns it.

Returns
-------
Optional[str]
a token from the server and returns it.
"""

@retry(
reraise=True,
before=before_log("Request token", self._logger, logging.DEBUG),
retry=self._retry_config.retry,
stop=self._retry_config.stop,
wait=self._retry_config.wait,
)
def auth_with_retry() -> None:
client = self._flight_client
if self._auth:
auth_pair = self._auth.auth_pair()
client.authenticate_basic_token(auth_pair[0], auth_pair[1])

if self._auth:
auth_with_retry()
return self._auth_middleware.token()
else:
return "IGNORED"

def get_stream(self, ticket: Ticket) -> FlightStreamReader:
return self._flight_client.do_get(ticket)

def do_action(self, endpoint: str, payload: bytes) -> Iterator[Result]:
return self._flight_client.do_action(Action(endpoint, payload)) # type: ignore

def do_action_with_retry(self, endpoint: str, payload: bytes) -> Iterator[Result]:
@retry(
reraise=True,
before=before_log("Send action", self._logger, logging.DEBUG),
retry=self._retry_config.retry,
stop=self._retry_config.stop,
wait=self._retry_config.wait,
)
def run_with_retry() -> Iterator[Result]:
return self.do_action(endpoint, payload)

return run_with_retry()

def _instantiate_flight_client(self) -> flight.FlightClient:
location = (
flight.Location.for_grpc_tls(self._host, self._port)
if self._encrypted
else flight.Location.for_grpc_tcp(self._host, self._port)
)
client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}
if self._auth:
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
if self._user_agent:
user_agent = self._user_agent

client_options["middleware"] = [
AuthFactory(self._auth_middleware),
UserAgentFactory(useragent=user_agent),
]
if self._tls_root_certs:
client_options["tls_root_certs"] = self._tls_root_certs
return flight.FlightClient(location, **client_options)


@dataclass
class ConnectionInfo:
host: str
port: int
encrypted: bool
55 changes: 55 additions & 0 deletions graphdatascience/arrow_client/data_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import dataclasses
import json
from dataclasses import fields
from typing import Any, Dict, Iterator, Type, TypeVar

from pyarrow._flight import Result


class DataMapper:
T = TypeVar("T")

@staticmethod
def deserialize_single(input_stream: Iterator[Result], cls: Type[T]) -> T:
rows = DataMapper.deserialize(input_stream, cls)

if len(rows) != 1:
raise ValueError(f"Expected exactly one row, got {len(rows)}")

return rows[0]

@staticmethod
def deserialize(input_stream: Iterator[Result], cls: Type[T]) -> list[T]:
def deserialize_row(row: Result): # type:ignore
result_dicts = json.loads(row.body.to_pybytes().decode())
if cls == Dict:
return result_dicts
return DataMapper.dict_to_dataclass(result_dicts, cls)

return [deserialize_row(row) for row in list(input_stream)]

@staticmethod
def dict_to_dataclass(data: Dict[str, Any], cls: Type[T], strict: bool = False) -> T:
"""
Convert a dictionary to a dataclass instance with nested dataclass support.
"""
if not dataclasses.is_dataclass(cls):
raise ValueError(f"{cls} is not a dataclass")

field_dict = {f.name: f for f in fields(cls)}
filtered_data = {}

for key, value in data.items():
if key in field_dict:
field = field_dict[key]
field_type = field.type

# Handle nested dataclasses
if dataclasses.is_dataclass(field_type) and isinstance(value, dict):
filtered_data[key] = DataMapper.dict_to_dataclass(value, field_type, strict) # type:ignore
else:
filtered_data[key] = value
elif strict:
raise ValueError(f"Extra field '{key}' not allowed in {cls.__name__}")

return cls(**filtered_data) # type: ignore
Comment on lines +32 to +55
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for this validation we could think of also using pydantic instead of rolling our custom validation

63 changes: 63 additions & 0 deletions graphdatascience/arrow_client/middleware/AuthMiddleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

import base64
import time
from typing import Any, Optional

from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory

from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication


class AuthFactory(ClientMiddlewareFactory): # type: ignore
def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._middleware = middleware

def start_call(self, info: Any) -> AuthMiddleware:
return self._middleware


class AuthMiddleware(ClientMiddleware): # type: ignore
def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._auth = auth
self._token: Optional[str] = None
self._token_timestamp = 0

def token(self) -> Optional[str]:
# check whether the token is older than 10 minutes. If so, reset it.
if self._token and int(time.time()) - self._token_timestamp > 600:
self._token = None

return self._token

def _set_token(self, token: str) -> None:
self._token = token
self._token_timestamp = int(time.time())

def received_headers(self, headers: dict[str, Any]) -> None:
auth_header = headers.get("authorization", None)
if not auth_header:
return

# the result is always a list
header_value = auth_header[0]

if not isinstance(header_value, str):
raise ValueError(f"Incompatible header value received from server: `{header_value}`")

auth_type, token = header_value.split(" ", 1)
if auth_type == "Bearer":
self._set_token(token)

def sending_headers(self) -> dict[str, str]:
token = self.token()
if token is not None:
return {"authorization": "Bearer " + token}

auth_pair = self._auth.auth_pair()
auth_token = f"{auth_pair[0]}:{auth_pair[1]}"
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
# There seems to be a bug, `authorization` must be lower key
return {"authorization": auth_token}
26 changes: 26 additions & 0 deletions graphdatascience/arrow_client/middleware/UserAgentMiddleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from typing import Any

from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory


class UserAgentFactory(ClientMiddlewareFactory): # type: ignore
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._middleware = UserAgentMiddleware(useragent)

def start_call(self, info: Any) -> ClientMiddleware:
return self._middleware


class UserAgentMiddleware(ClientMiddleware): # type: ignore
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._useragent = useragent

def sending_headers(self) -> dict[str, str]:
return {"x-gds-user-agent": self._useragent}

def received_headers(self, headers: dict[str, Any]) -> None:
pass
19 changes: 19 additions & 0 deletions graphdatascience/arrow_client/v2/api_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dataclasses import dataclass


@dataclass(frozen=True, repr=True)
class JobIdConfig:
jobId: str


@dataclass(frozen=True, repr=True)
class JobStatus:
jobId: str
status: str
progress: float


@dataclass(frozen=True, repr=True)
class MutateResult:
nodePropertiesWritten: int
relationshipsWritten: int
Loading