Skip to content

V2 Arrow Clients #918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions graphdatascience/arrow_client/arrow_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass


class ArrowAuthentication(ABC):
@abstractmethod
def auth_pair(self) -> tuple[str, str]:
"""Returns the auth pair used for authentication."""
pass


@dataclass
class UsernamePasswordAuthentication(ArrowAuthentication):
username: str
password: str

def auth_pair(self) -> tuple[str, str]:
return self.username, self.password
14 changes: 14 additions & 0 deletions graphdatascience/arrow_client/arrow_base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


class ArrowBaseModel(BaseModel):
model_config = ConfigDict(alias_generator=to_camel)

def dump_camel(self) -> dict[str, Any]:
return self.model_dump(by_alias=True)

def dump_json(self) -> str:
return self.model_dump_json(by_alias=True)
31 changes: 31 additions & 0 deletions graphdatascience/arrow_client/arrow_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from dataclasses import dataclass

from ..query_runner.query_runner import QueryRunner
from ..server_version.server_version import ServerVersion


@dataclass(frozen=True)
class ArrowInfo:
listenAddress: str
enabled: bool
running: bool
versions: list[str]

@staticmethod
def create(query_runner: QueryRunner) -> ArrowInfo:
debugYields = ["listenAddress", "enabled", "running"]
if query_runner.server_version() > ServerVersion(2, 6, 0):
debugYields.append("versions")

procResult = query_runner.call_procedure(
endpoint="gds.debug.arrow", custom_error=False, yields=debugYields
).iloc[0]

return ArrowInfo(
listenAddress=procResult["listenAddress"],
enabled=procResult["enabled"],
running=procResult["running"],
versions=procResult.get("versions", []),
)
195 changes: 195 additions & 0 deletions graphdatascience/arrow_client/authenticated_flight_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any, Iterator, Optional

from pyarrow import __version__ as arrow_version
from pyarrow import flight
from pyarrow._flight import (
Action,
FlightInternalError,
FlightStreamReader,
FlightTimedOutError,
FlightUnavailableError,
Result,
Ticket,
)
from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential

from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
from graphdatascience.arrow_client.arrow_info import ArrowInfo
from graphdatascience.retry_utils.retry_config import RetryConfig

from ..retry_utils.retry_utils import before_log
from ..version import __version__
from .middleware.auth_middleware import AuthFactory, AuthMiddleware
from .middleware.user_agent_middleware import UserAgentFactory


class AuthenticatedArrowClient:
@staticmethod
def create(
arrow_info: ArrowInfo,
auth: Optional[ArrowAuthentication] = None,
encrypted: bool = False,
arrow_client_options: Optional[dict[str, Any]] = None,
connection_string_override: Optional[str] = None,
retry_config: Optional[RetryConfig] = None,
) -> AuthenticatedArrowClient:
connection_string: str
if connection_string_override is not None:
connection_string = connection_string_override
else:
connection_string = arrow_info.listenAddress

host, port = connection_string.split(":")

if retry_config is None:
retry_config = RetryConfig(
retry=retry_any(
retry_if_exception_type(FlightTimedOutError),
retry_if_exception_type(FlightUnavailableError),
retry_if_exception_type(FlightInternalError),
),
stop=(stop_after_delay(10) | stop_after_attempt(5)),
wait=wait_exponential(multiplier=1, min=1, max=10),
)

return AuthenticatedArrowClient(
host=host,
retry_config=retry_config,
port=int(port),
auth=auth,
encrypted=encrypted,
arrow_client_options=arrow_client_options,
)

def __init__(
self,
host: str,
retry_config: RetryConfig,
port: int = 8491,
auth: Optional[ArrowAuthentication] = None,
encrypted: bool = False,
arrow_client_options: Optional[dict[str, Any]] = None,
user_agent: Optional[str] = None,
):
"""Creates a new GdsArrowClient instance.

Parameters
----------
host: str
The host address of the GDS Arrow server
port: int
The host port of the GDS Arrow server (default is 8491)
auth: Optional[ArrowAuthentication]
An implementation of ArrowAuthentication providing a pair to be used for basic authentication
encrypted: bool
A flag that indicates whether the connection should be encrypted (default is False)
arrow_client_options: Optional[dict[str, Any]]
Additional options to be passed to the Arrow Flight client.
user_agent: Optional[str]
The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION])
retry_config: Optional[RetryConfig]
The retry configuration to use for the Arrow requests send by the client.
"""
self._host = host
self._port = port
self._auth = None
self._encrypted = encrypted
self._arrow_client_options = arrow_client_options
self._user_agent = user_agent
self._retry_config = retry_config
self._logger = logging.getLogger("gds_arrow_client")
self._retry_config = retry_config
if auth:
self._auth = auth
self._auth_middleware = AuthMiddleware(auth)

self._flight_client = self._instantiate_flight_client()

def connection_info(self) -> ConnectionInfo:
"""
Returns the host and port of the GDS Arrow server.

Returns
-------
tuple[str, int]
the host and port of the GDS Arrow server
"""
return ConnectionInfo(self._host, self._port, self._encrypted)

def request_token(self) -> Optional[str]:
"""
Requests a token from the server and returns it.

Returns
-------
Optional[str]
a token from the server and returns it.
"""

@retry(
reraise=True,
before=before_log("Request token", self._logger, logging.DEBUG),
retry=self._retry_config.retry,
stop=self._retry_config.stop,
wait=self._retry_config.wait,
)
def auth_with_retry() -> None:
client = self._flight_client
if self._auth:
auth_pair = self._auth.auth_pair()
client.authenticate_basic_token(auth_pair[0], auth_pair[1])

if self._auth:
auth_with_retry()
return self._auth_middleware.token()
else:
return "IGNORED"

def get_stream(self, ticket: Ticket) -> FlightStreamReader:
return self._flight_client.do_get(ticket)

def do_action(self, endpoint: str, payload: bytes) -> Iterator[Result]:
return self._flight_client.do_action(Action(endpoint, payload)) # type: ignore

def do_action_with_retry(self, endpoint: str, payload: bytes) -> Iterator[Result]:
@retry(
reraise=True,
before=before_log("Send action", self._logger, logging.DEBUG),
retry=self._retry_config.retry,
stop=self._retry_config.stop,
wait=self._retry_config.wait,
)
def run_with_retry() -> Iterator[Result]:
return self.do_action(endpoint, payload)

return run_with_retry()

def _instantiate_flight_client(self) -> flight.FlightClient:
location = (
flight.Location.for_grpc_tls(self._host, self._port)
if self._encrypted
else flight.Location.for_grpc_tcp(self._host, self._port)
)
client_options: dict[str, Any] = (self._arrow_client_options or {}).copy()
if self._auth:
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
if self._user_agent:
user_agent = self._user_agent

client_options["middleware"] = [
AuthFactory(self._auth_middleware),
UserAgentFactory(useragent=user_agent),
]

return flight.FlightClient(location, **client_options)


@dataclass
class ConnectionInfo:
host: str
port: int
encrypted: bool
63 changes: 63 additions & 0 deletions graphdatascience/arrow_client/middleware/auth_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

import base64
import time
from typing import Any, Optional

from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory

from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication


class AuthFactory(ClientMiddlewareFactory): # type: ignore
def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._middleware = middleware

def start_call(self, info: Any) -> AuthMiddleware:
return self._middleware


class AuthMiddleware(ClientMiddleware): # type: ignore
def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._auth = auth
self._token: Optional[str] = None
self._token_timestamp = 0

def token(self) -> Optional[str]:
# check whether the token is older than 10 minutes. If so, reset it.
if self._token and int(time.time()) - self._token_timestamp > 600:
self._token = None

return self._token

def _set_token(self, token: str) -> None:
self._token = token
self._token_timestamp = int(time.time())

def received_headers(self, headers: dict[str, Any]) -> None:
auth_header = headers.get("authorization", None)
if not auth_header:
return

# the result is always a list
header_value = auth_header[0]

if not isinstance(header_value, str):
raise ValueError(f"Incompatible header value received from server: `{header_value}`")

auth_type, token = header_value.split(" ", 1)
if auth_type == "Bearer":
self._set_token(token)

def sending_headers(self) -> dict[str, str]:
token = self.token()
if token is not None:
return {"authorization": "Bearer " + token}

auth_pair = self._auth.auth_pair()
auth_token = f"{auth_pair[0]}:{auth_pair[1]}"
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
# There seems to be a bug, `authorization` must be lower key
return {"authorization": auth_token}
26 changes: 26 additions & 0 deletions graphdatascience/arrow_client/middleware/user_agent_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from typing import Any

from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory


class UserAgentFactory(ClientMiddlewareFactory): # type: ignore
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._middleware = UserAgentMiddleware(useragent)

def start_call(self, info: Any) -> ClientMiddleware:
return self._middleware


class UserAgentMiddleware(ClientMiddleware): # type: ignore
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._useragent = useragent

def sending_headers(self) -> dict[str, str]:
return {"x-gds-user-agent": self._useragent}

def received_headers(self, headers: dict[str, Any]) -> None:
pass
16 changes: 16 additions & 0 deletions graphdatascience/arrow_client/v2/api_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from graphdatascience.arrow_client.arrow_base_model import ArrowBaseModel


class JobIdConfig(ArrowBaseModel):
job_id: str


class JobStatus(ArrowBaseModel):
job_id: str
status: str
progress: float


class MutateResult(ArrowBaseModel):
node_properties_written: int
relationships_written: int
19 changes: 19 additions & 0 deletions graphdatascience/arrow_client/v2/data_mapper_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import json
from typing import Any, Iterator

from pyarrow._flight import Result


def deserialize_single(input_stream: Iterator[Result]) -> dict[str, Any]:
rows = deserialize(input_stream)
if len(rows) != 1:
raise ValueError(f"Expected exactly one result, got {len(rows)}")

return rows[0]


def deserialize(input_stream: Iterator[Result]) -> list[dict[str, Any]]:
def deserialize_row(row: Result): # type:ignore
return json.loads(row.body.to_pybytes().decode())

return [deserialize_row(row) for row in list(input_stream)]
Loading