-
Notifications
You must be signed in to change notification settings - Fork 54
Create explicit WCC endpoints #859
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
DarthMax
wants to merge
12
commits into
neo4j:main
Choose a base branch
from
DarthMax:create_explicit_procedure_endpoints_for_wcc
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
4246b30
Introduce explicit endpoints for Wcc
DarthMax c0edacc
Implement Wcc endpoints using Cypher procedures
DarthMax 50d88ab
Move arrow client related code into its own package
DarthMax 6c39bc0
Return custom type from wcc.mutate
DarthMax be2985e
Introduce AuthenticatedArrowClient
DarthMax 877427a
Implement Arrow based wcc endpoints
DarthMax 8554f63
Fix formatting
DarthMax 235fd9e
Fix Cypher wcc endpoint tests
DarthMax d8f221c
Fix type issues
DarthMax 03eb011
Generalize config extraction for arrow endpoints
DarthMax ed37a4d
Use config converter also for Cypher endpoints
DarthMax ae2f8cb
Add feature flag to enable explicit APIs
DarthMax File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
211 changes: 211 additions & 0 deletions
211
graphdatascience/arrow_client/authenticated_arrow_client.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from dataclasses import dataclass | ||
from typing import Any, Iterator, Optional | ||
|
||
from pyarrow import __version__ as arrow_version | ||
from pyarrow import flight | ||
from pyarrow._flight import ( | ||
Action, | ||
FlightInternalError, | ||
FlightStreamReader, | ||
FlightTimedOutError, | ||
FlightUnavailableError, | ||
Result, | ||
Ticket, | ||
) | ||
from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential | ||
|
||
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication | ||
from graphdatascience.arrow_client.arrow_info import ArrowInfo | ||
from graphdatascience.retry_utils.retry_config import RetryConfig | ||
|
||
from ..retry_utils.retry_utils import before_log | ||
from ..version import __version__ | ||
from .middleware.AuthMiddleware import AuthFactory, AuthMiddleware | ||
from .middleware.UserAgentMiddleware import UserAgentFactory | ||
|
||
|
||
class AuthenticatedArrowClient:
    """Arrow Flight client wrapper that optionally authenticates against a GDS Arrow server.

    When an ``ArrowAuthentication`` is supplied, an auth middleware and a
    user-agent middleware are installed on the underlying ``flight.FlightClient``.
    Selected calls are retried according to the configured ``RetryConfig``.
    """

    @staticmethod
    def _default_retry_config() -> RetryConfig:
        """Build the retry policy used when the caller does not supply one.

        Retries on Flight timeout, unavailable and internal errors, stopping after
        10 seconds or 5 attempts (whichever comes first), with exponential backoff
        between 1 and 10 seconds.
        """
        return RetryConfig(
            retry=retry_any(
                retry_if_exception_type(FlightTimedOutError),
                retry_if_exception_type(FlightUnavailableError),
                retry_if_exception_type(FlightInternalError),
            ),
            stop=(stop_after_delay(10) | stop_after_attempt(5)),
            wait=wait_exponential(multiplier=1, min=1, max=10),
        )

    @staticmethod
    def create(
        arrow_info: ArrowInfo,
        auth: Optional[ArrowAuthentication] = None,
        encrypted: bool = False,
        disable_server_verification: bool = False,
        tls_root_certs: Optional[bytes] = None,
        connection_string_override: Optional[str] = None,
        retry_config: Optional[RetryConfig] = None,
    ) -> AuthenticatedArrowClient:
        """Create a client from server-provided Arrow information.

        Parameters
        ----------
        arrow_info: ArrowInfo
            Arrow server information; its ``listenAddress`` is used as ``host:port``
            unless ``connection_string_override`` is given
        auth: Optional[ArrowAuthentication]
            Provider of the credentials pair used for authentication
        encrypted: bool
            A flag that indicates whether the connection should be encrypted (default is False)
        disable_server_verification: bool
            A flag that disables server verification for TLS connections (default is False)
        tls_root_certs: Optional[bytes]
            PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server
        connection_string_override: Optional[str]
            A ``host:port`` string that takes precedence over ``arrow_info.listenAddress``
        retry_config: Optional[RetryConfig]
            The retry configuration to use; a default policy is used when omitted

        Returns
        -------
        AuthenticatedArrowClient
            a client connected to the resolved host and port

        Raises
        ------
        ValueError
            if the connection string contains no ``:`` separator or the port is not an integer
        """
        connection_string: str = (
            connection_string_override if connection_string_override is not None else arrow_info.listenAddress
        )

        # Split on the last colon only, so a host that itself contains colons
        # (such as a bracketed IPv6 literal) does not break the unpacking.
        host, port = connection_string.rsplit(":", 1)

        if retry_config is None:
            retry_config = AuthenticatedArrowClient._default_retry_config()

        return AuthenticatedArrowClient(
            host,
            retry_config,
            int(port),
            auth,
            encrypted,
            disable_server_verification,
            tls_root_certs,
        )

    def __init__(
        self,
        host: str,
        retry_config: RetryConfig,
        port: int = 8491,
        auth: Optional[ArrowAuthentication] = None,
        encrypted: bool = False,
        disable_server_verification: bool = False,
        tls_root_certs: Optional[bytes] = None,
        user_agent: Optional[str] = None,
    ):
        """Creates a new AuthenticatedArrowClient instance.

        Parameters
        ----------
        host: str
            The host address of the GDS Arrow server
        retry_config: RetryConfig
            The retry configuration to use for the Arrow requests sent by the client.
            If ``None`` is passed, a default policy is used.
        port: int
            The host port of the GDS Arrow server (default is 8491)
        auth: Optional[ArrowAuthentication]
            An implementation of ArrowAuthentication providing the pair used for basic authentication.
            When omitted, no middleware is installed on the Flight client.
        encrypted: bool
            A flag that indicates whether the connection should be encrypted (default is False)
        disable_server_verification: bool
            A flag that disables server verification for TLS connections (default is False)
        tls_root_certs: Optional[bytes]
            PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server
        user_agent: Optional[str]
            The user agent string to use for the connection.
            (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION]`)
        """
        self._host = host
        self._port = port
        self._auth = None
        self._encrypted = encrypted
        self._disable_server_verification = disable_server_verification
        self._tls_root_certs = tls_root_certs
        self._user_agent = user_agent
        self._logger = logging.getLogger("gds_arrow_client")
        # Honour the caller's retry policy. Previously this attribute was always
        # overwritten with the hard-coded default, discarding the argument.
        # The None fallback keeps callers that passed None working.
        self._retry_config = (
            retry_config if retry_config is not None else AuthenticatedArrowClient._default_retry_config()
        )

        if auth:
            self._auth = auth
            self._auth_middleware = AuthMiddleware(auth)

        # Must run last: it reads the attributes assigned above.
        self._flight_client = self._instantiate_flight_client()

    def connection_info(self) -> ConnectionInfo:
        """
        Returns the connection details of the GDS Arrow server.

        Returns
        -------
        ConnectionInfo
            the host, port and encryption flag this client was configured with
        """
        return ConnectionInfo(self._host, self._port, self._encrypted)

    def request_token(self) -> Optional[str]:
        """
        Requests a token from the server and returns it.

        Returns
        -------
        Optional[str]
            the token currently held by the auth middleware after authenticating,
            or the literal string "IGNORED" when the client has no authentication configured
        """
        if not self._auth:
            return "IGNORED"

        auth = self._auth

        @retry(
            reraise=True,
            before=before_log("Request token", self._logger, logging.DEBUG),
            retry=self._retry_config.retry,
            stop=self._retry_config.stop,
            wait=self._retry_config.wait,
        )
        def auth_with_retry() -> None:
            username, password = auth.auth_pair()[0], auth.auth_pair()[1]
            self._flight_client.authenticate_basic_token(username, password)

        auth_with_retry()
        return self._auth_middleware.token()

    def get_stream(self, ticket: Ticket) -> FlightStreamReader:
        """Open a Flight ``do_get`` stream for ``ticket`` (no retry is applied)."""
        return self._flight_client.do_get(ticket)

    def do_action(self, endpoint: str, payload: bytes) -> Iterator[Result]:
        """Send a single Flight action named ``endpoint`` with ``payload`` (no retry is applied)."""
        return self._flight_client.do_action(Action(endpoint, payload))  # type: ignore

    def do_action_with_retry(self, endpoint: str, payload: bytes) -> Iterator[Result]:
        """Send a Flight action, retrying according to the client's retry configuration.

        The last exception is re-raised once the retry policy gives up.
        """

        @retry(
            reraise=True,
            before=before_log("Send action", self._logger, logging.DEBUG),
            retry=self._retry_config.retry,
            stop=self._retry_config.stop,
            wait=self._retry_config.wait,
        )
        def run_with_retry() -> Iterator[Result]:
            return self.do_action(endpoint, payload)

        return run_with_retry()

    def _instantiate_flight_client(self) -> flight.FlightClient:
        """Build the underlying ``flight.FlightClient`` from the stored connection settings."""
        if self._encrypted:
            location = flight.Location.for_grpc_tls(self._host, self._port)
        else:
            location = flight.Location.for_grpc_tcp(self._host, self._port)

        client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}

        # Middleware (including the user agent header) is only installed when
        # authentication is configured.
        if self._auth:
            user_agent = self._user_agent or f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
            client_options["middleware"] = [
                AuthFactory(self._auth_middleware),
                UserAgentFactory(useragent=user_agent),
            ]

        if self._tls_root_certs:
            client_options["tls_root_certs"] = self._tls_root_certs

        return flight.FlightClient(location, **client_options)
|
||
|
||
@dataclass
class ConnectionInfo:
    """Host, port and encryption flag describing an Arrow server connection."""

    # Host address of the server.
    host: str
    # Port of the server.
    port: int
    # Whether the connection is encrypted.
    encrypted: bool
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import dataclasses | ||
import json | ||
from dataclasses import fields | ||
from typing import Any, Dict, Iterator, Type, TypeVar | ||
|
||
from pyarrow._flight import Result | ||
|
||
|
||
class DataMapper:
    """Helpers that decode JSON result bodies into dictionaries or dataclass instances."""

    T = TypeVar("T")

    @staticmethod
    def deserialize_single(input_stream: Iterator[Result], cls: Type[T]) -> T:
        """Deserialize a stream that must contain exactly one row.

        Raises
        ------
        ValueError
            if the stream yields zero rows or more than one row
        """
        rows = DataMapper.deserialize(input_stream, cls)

        if len(rows) != 1:
            raise ValueError(f"Expected exactly one row, got {len(rows)}")

        return rows[0]

    @staticmethod
    def deserialize(input_stream: Iterator[Result], cls: Type[T]) -> list[T]:
        """Deserialize every row of ``input_stream`` into ``cls``.

        Each row's ``body`` is decoded as JSON. When ``cls`` is ``dict`` or
        ``typing.Dict`` the parsed JSON is returned unchanged; otherwise it is
        converted with :meth:`dict_to_dataclass`.
        """
        # Accept the builtin as well as the typing alias; `typing.Dict == dict` is False.
        wants_plain_dict = cls is dict or cls == Dict

        def deserialize_row(row: Result):  # type:ignore
            payload = json.loads(row.body.to_pybytes().decode())
            if wants_plain_dict:
                return payload
            return DataMapper.dict_to_dataclass(payload, cls)

        # Iterate the stream directly; there is no need to materialise it first.
        return [deserialize_row(row) for row in input_stream]

    @staticmethod
    def dict_to_dataclass(data: Dict[str, Any], cls: Type[T], strict: bool = False) -> T:
        """
        Convert a dictionary to a dataclass instance with nested dataclass support.

        Keys that are not fields of ``cls`` are dropped, or rejected when ``strict`` is True.

        Raises
        ------
        ValueError
            if ``cls`` is not a dataclass, or if ``strict`` is True and ``data`` has an unknown key
        """
        from typing import get_type_hints

        if not dataclasses.is_dataclass(cls):
            raise ValueError(f"{cls} is not a dataclass")

        # With postponed annotations `field.type` is a plain string, which would
        # hide nested dataclasses. Resolve the hints when possible and fall back
        # to the raw field type otherwise.
        try:
            resolved_types = get_type_hints(cls)
        except (NameError, TypeError, AttributeError):
            resolved_types = {}

        known_fields = {f.name: f for f in fields(cls)}
        kwargs: Dict[str, Any] = {}

        for key, value in data.items():
            field = known_fields.get(key)
            if field is None:
                if strict:
                    raise ValueError(f"Extra field '{key}' not allowed in {cls.__name__}")
                continue

            field_type = resolved_types.get(key, field.type)

            # Handle nested dataclasses
            if isinstance(value, dict) and dataclasses.is_dataclass(field_type):
                kwargs[key] = DataMapper.dict_to_dataclass(value, field_type, strict)  # type:ignore
            else:
                kwargs[key] = value

        return cls(**kwargs)  # type: ignore
Comment on lines
+32
to
+55
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For this validation, we could also consider using pydantic instead of rolling our own custom validation. |
63 changes: 63 additions & 0 deletions
63
graphdatascience/arrow_client/middleware/AuthMiddleware.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from __future__ import annotations | ||
|
||
import base64 | ||
import time | ||
from typing import Any, Optional | ||
|
||
from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory | ||
|
||
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication | ||
|
||
|
||
class AuthFactory(ClientMiddlewareFactory):  # type: ignore
    """Middleware factory that hands out one shared ``AuthMiddleware`` instance."""

    def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None:
        """Remember ``middleware``; remaining arguments are forwarded to the base class."""
        super().__init__(*args, **kwargs)
        self._middleware = middleware

    def start_call(self, info: Any) -> AuthMiddleware:
        """Return the shared middleware; ``info`` is not inspected."""
        shared = self._middleware
        return shared
|
||
|
||
class AuthMiddleware(ClientMiddleware):  # type: ignore
    """Client middleware that sends credentials and caches bearer tokens from responses.

    Until a bearer token has been received, requests carry a basic
    ``authorization`` header built from the credentials pair. Once a bearer token
    is received it is sent instead, until it is considered expired.
    """

    # A cached token older than this many seconds is discarded.
    _TOKEN_TTL_SECONDS = 600

    def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None:
        """Store the credentials provider; remaining arguments go to the base class."""
        super().__init__(*args, **kwargs)
        self._auth = auth
        self._token: Optional[str] = None
        self._token_timestamp = 0

    def token(self) -> Optional[str]:
        """Return the cached bearer token, or ``None`` if there is none or it has expired."""
        # check whether the token is older than 10 minutes. If so, reset it.
        if self._token and int(time.time()) - self._token_timestamp > self._TOKEN_TTL_SECONDS:
            self._token = None

        return self._token

    def _set_token(self, token: str) -> None:
        """Cache ``token`` and record the time it was received."""
        self._token = token
        self._token_timestamp = int(time.time())

    def received_headers(self, headers: dict[str, Any]) -> None:
        """Pick up a bearer token from the ``authorization`` response header, if present.

        Raises
        ------
        ValueError
            if the first ``authorization`` header value is not a string
        """
        auth_header = headers.get("authorization", None)
        if not auth_header:
            return

        # the result is always a list
        header_value = auth_header[0]

        if not isinstance(header_value, str):
            raise ValueError(f"Incompatible header value received from server: `{header_value}`")

        # partition never fails on a header without a space, unlike unpacking
        # the result of split(" ", 1).
        auth_type, _, token = header_value.partition(" ")
        # Ignore an empty token: caching it would later send a bare "Bearer " header.
        if auth_type == "Bearer" and token:
            self._set_token(token)

    def sending_headers(self) -> dict[str, str]:
        """Return the ``authorization`` header for the next request.

        Uses the cached bearer token when available, otherwise basic
        authentication built from the credentials pair.
        """
        token = self.token()
        if token is not None:
            return {"authorization": "Bearer " + token}

        auth_pair = self._auth.auth_pair()
        credentials = f"{auth_pair[0]}:{auth_pair[1]}"
        encoded = base64.b64encode(credentials.encode("utf-8")).decode("ASCII")
        # The header name `authorization` must be lower case; other casings seem not to work.
        return {"authorization": "Basic " + encoded}
26 changes: 26 additions & 0 deletions
26
graphdatascience/arrow_client/middleware/UserAgentMiddleware.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory | ||
|
||
|
||
class UserAgentFactory(ClientMiddlewareFactory):  # type: ignore
    """Middleware factory that hands out one shared ``UserAgentMiddleware``."""

    def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
        """Create the shared middleware for ``useragent``; other arguments go to the base class."""
        super().__init__(*args, **kwargs)
        shared = UserAgentMiddleware(useragent)
        self._middleware = shared

    def start_call(self, info: Any) -> ClientMiddleware:
        """Return the shared middleware; ``info`` is not inspected."""
        return self._middleware
|
||
|
||
class UserAgentMiddleware(ClientMiddleware):  # type: ignore
    """Client middleware that adds an ``x-gds-user-agent`` header to outgoing requests."""

    def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
        """Store the user agent string; remaining arguments go to the base class."""
        super().__init__(*args, **kwargs)
        self._useragent = useragent

    def sending_headers(self) -> dict[str, str]:
        """Return the user agent header to attach to the request."""
        outgoing = {"x-gds-user-agent": self._useragent}
        return outgoing

    def received_headers(self, headers: dict[str, Any]) -> None:
        """Response headers are ignored by this middleware."""
        return None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass(frozen=True, repr=True)
class JobIdConfig:
    """Immutable record holding a single job identifier."""

    # Identifier of the job.
    jobId: str
|
||
|
||
@dataclass(frozen=True, repr=True)
class JobStatus:
    """Immutable record holding a job's identifier, status string and progress value."""

    # Identifier of the job.
    jobId: str
    # Status of the job, as a string.
    status: str
    # Progress of the job, as a float.
    progress: float
|
||
|
||
@dataclass(frozen=True, repr=True)
class MutateResult:
    """Immutable record holding the counters reported by a mutate operation."""

    # Number of node properties written.
    nodePropertiesWritten: int
    # Number of relationships written.
    relationshipsWritten: int
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: mixing import styles