diff --git a/graphdatascience/arrow_client/arrow_authentication.py b/graphdatascience/arrow_client/arrow_authentication.py
new file mode 100644
index 000000000..948c2fe6a
--- /dev/null
+++ b/graphdatascience/arrow_client/arrow_authentication.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+
+class ArrowAuthentication(ABC):
+    @abstractmethod
+    def auth_pair(self) -> tuple[str, str]:
+        """Returns the auth pair used for authentication."""
+        pass
+
+
+@dataclass
+class UsernamePasswordAuthentication(ArrowAuthentication):
+    username: str
+    password: str
+
+    def auth_pair(self) -> tuple[str, str]:
+        return self.username, self.password
diff --git a/graphdatascience/arrow_client/arrow_base_model.py b/graphdatascience/arrow_client/arrow_base_model.py
new file mode 100644
index 000000000..75a5e34fd
--- /dev/null
+++ b/graphdatascience/arrow_client/arrow_base_model.py
@@ -0,0 +1,14 @@
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict
+from pydantic.alias_generators import to_camel
+
+
+class ArrowBaseModel(BaseModel):
+    model_config = ConfigDict(alias_generator=to_camel)
+
+    def dump_camel(self) -> dict[str, Any]:
+        return self.model_dump(by_alias=True)
+
+    def dump_json(self) -> str:
+        return self.model_dump_json(by_alias=True)
diff --git a/graphdatascience/arrow_client/arrow_info.py b/graphdatascience/arrow_client/arrow_info.py
new file mode 100644
index 000000000..8a2399182
--- /dev/null
+++ b/graphdatascience/arrow_client/arrow_info.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from ..query_runner.query_runner import QueryRunner
+from ..server_version.server_version import ServerVersion
+
+
+@dataclass(frozen=True)
+class ArrowInfo:
+    listenAddress: str
+    enabled: bool
+    running: bool
+    versions: list[str]
+
+    @staticmethod
+    def create(query_runner: QueryRunner) -> ArrowInfo:
+        debugYields = ["listenAddress", "enabled", "running"]
+        if query_runner.server_version() > ServerVersion(2, 6, 0):
+            debugYields.append("versions")
+
+        procResult = query_runner.call_procedure(
+            endpoint="gds.debug.arrow", custom_error=False, yields=debugYields
+        ).iloc[0]
+
+        return ArrowInfo(
+            listenAddress=procResult["listenAddress"],
+            enabled=procResult["enabled"],
+            running=procResult["running"],
+            versions=procResult.get("versions", []),
+        )
diff --git a/graphdatascience/arrow_client/authenticated_flight_client.py b/graphdatascience/arrow_client/authenticated_flight_client.py
new file mode 100644
index 000000000..816aa2c11
--- /dev/null
+++ b/graphdatascience/arrow_client/authenticated_flight_client.py
@@ -0,0 +1,195 @@
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import Any, Iterator, Optional
+
+from pyarrow import __version__ as arrow_version
+from pyarrow import flight
+from pyarrow._flight import (
+    Action,
+    FlightInternalError,
+    FlightStreamReader,
+    FlightTimedOutError,
+    FlightUnavailableError,
+    Result,
+    Ticket,
+)
+from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential
+
+from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
+from graphdatascience.arrow_client.arrow_info import ArrowInfo
+from graphdatascience.retry_utils.retry_config import RetryConfig
+
+from ..retry_utils.retry_utils import before_log
+from ..version import __version__
+from .middleware.auth_middleware import AuthFactory, AuthMiddleware
+from .middleware.user_agent_middleware import UserAgentFactory
+
+
+class AuthenticatedArrowClient:
+    @staticmethod
+    def create(
+        arrow_info: ArrowInfo,
+        auth: Optional[ArrowAuthentication] = None,
+        encrypted: bool = False,
+        arrow_client_options: Optional[dict[str, Any]] = None,
+        connection_string_override: Optional[str] = None,
+        retry_config: Optional[RetryConfig] = None,
+    ) -> AuthenticatedArrowClient:
+        connection_string: str
+        if connection_string_override is not None:
+            connection_string = connection_string_override
+        else:
+            connection_string = arrow_info.listenAddress
+
+        host, port = connection_string.split(":")
+
+        if retry_config is None:
+            retry_config = RetryConfig(
+                retry=retry_any(
+                    retry_if_exception_type(FlightTimedOutError),
+                    retry_if_exception_type(FlightUnavailableError),
+                    retry_if_exception_type(FlightInternalError),
+                ),
+                stop=(stop_after_delay(10) | stop_after_attempt(5)),
+                wait=wait_exponential(multiplier=1, min=1, max=10),
+            )
+
+        return AuthenticatedArrowClient(
+            host=host,
+            retry_config=retry_config,
+            port=int(port),
+            auth=auth,
+            encrypted=encrypted,
+            arrow_client_options=arrow_client_options,
+        )
+
+    def __init__(
+        self,
+        host: str,
+        retry_config: RetryConfig,
+        port: int = 8491,
+        auth: Optional[ArrowAuthentication] = None,
+        encrypted: bool = False,
+        arrow_client_options: Optional[dict[str, Any]] = None,
+        user_agent: Optional[str] = None,
+    ):
+        """Creates a new AuthenticatedArrowClient instance.
+
+        Parameters
+        ----------
+        host: str
+            The host address of the GDS Arrow server
+        port: int
+            The host port of the GDS Arrow server (default is 8491)
+        auth: Optional[ArrowAuthentication]
+            An implementation of ArrowAuthentication providing a pair to be used for basic authentication
+        encrypted: bool
+            A flag that indicates whether the connection should be encrypted (default is False)
+        arrow_client_options: Optional[dict[str, Any]]
+            Additional options to be passed to the Arrow Flight client.
+        user_agent: Optional[str]
+            The user agent string to use for the connection (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION]`)
+        retry_config: Optional[RetryConfig]
+            The retry configuration to use for the Arrow requests sent by the client.
+        """
+        self._host = host
+        self._port = port
+        self._auth = None
+        self._encrypted = encrypted
+        self._arrow_client_options = arrow_client_options
+        self._user_agent = user_agent
+        self._retry_config = retry_config
+        self._logger = logging.getLogger("gds_arrow_client")
+        self._retry_config = retry_config
+        if auth:
+            self._auth = auth
+            self._auth_middleware = AuthMiddleware(auth)
+
+        self._flight_client = self._instantiate_flight_client()
+
+    def connection_info(self) -> ConnectionInfo:
+        """
+        Returns the host and port of the GDS Arrow server.
+
+        Returns
+        -------
+        ConnectionInfo
+            the host and port of the GDS Arrow server
+        """
+        return ConnectionInfo(self._host, self._port, self._encrypted)
+
+    def request_token(self) -> Optional[str]:
+        """
+        Requests a token from the server and returns it.
+
+        Returns
+        -------
+        Optional[str]
+            a token from the server
+ """ + + @retry( + reraise=True, + before=before_log("Request token", self._logger, logging.DEBUG), + retry=self._retry_config.retry, + stop=self._retry_config.stop, + wait=self._retry_config.wait, + ) + def auth_with_retry() -> None: + client = self._flight_client + if self._auth: + auth_pair = self._auth.auth_pair() + client.authenticate_basic_token(auth_pair[0], auth_pair[1]) + + if self._auth: + auth_with_retry() + return self._auth_middleware.token() + else: + return "IGNORED" + + def get_stream(self, ticket: Ticket) -> FlightStreamReader: + return self._flight_client.do_get(ticket) + + def do_action(self, endpoint: str, payload: bytes) -> Iterator[Result]: + return self._flight_client.do_action(Action(endpoint, payload)) # type: ignore + + def do_action_with_retry(self, endpoint: str, payload: bytes) -> Iterator[Result]: + @retry( + reraise=True, + before=before_log("Send action", self._logger, logging.DEBUG), + retry=self._retry_config.retry, + stop=self._retry_config.stop, + wait=self._retry_config.wait, + ) + def run_with_retry() -> Iterator[Result]: + return self.do_action(endpoint, payload) + + return run_with_retry() + + def _instantiate_flight_client(self) -> flight.FlightClient: + location = ( + flight.Location.for_grpc_tls(self._host, self._port) + if self._encrypted + else flight.Location.for_grpc_tcp(self._host, self._port) + ) + client_options: dict[str, Any] = (self._arrow_client_options or {}).copy() + if self._auth: + user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" + if self._user_agent: + user_agent = self._user_agent + + client_options["middleware"] = [ + AuthFactory(self._auth_middleware), + UserAgentFactory(useragent=user_agent), + ] + + return flight.FlightClient(location, **client_options) + + +@dataclass +class ConnectionInfo: + host: str + port: int + encrypted: bool diff --git a/graphdatascience/arrow_client/middleware/auth_middleware.py b/graphdatascience/arrow_client/middleware/auth_middleware.py new file mode 100644 index 000000000..0acb6e269 --- /dev/null +++ b/graphdatascience/arrow_client/middleware/auth_middleware.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import base64 +import time +from typing import Any, Optional + +from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory + +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication + + +class AuthFactory(ClientMiddlewareFactory): # type: ignore + def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._middleware = middleware + + def start_call(self, info: Any) -> AuthMiddleware: + return self._middleware + + +class AuthMiddleware(ClientMiddleware): # type: ignore + def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._auth = auth + self._token: Optional[str] = None + self._token_timestamp = 0 + + def token(self) -> Optional[str]: + # check whether the token is older than 10 minutes. If so, reset it. 
+        if self._token and int(time.time()) - self._token_timestamp > 600:
+            self._token = None
+
+        return self._token
+
+    def _set_token(self, token: str) -> None:
+        self._token = token
+        self._token_timestamp = int(time.time())
+
+    def received_headers(self, headers: dict[str, Any]) -> None:
+        auth_header = headers.get("authorization", None)
+        if not auth_header:
+            return
+
+        # the result is always a list
+        header_value = auth_header[0]
+
+        if not isinstance(header_value, str):
+            raise ValueError(f"Incompatible header value received from server: `{header_value}`")
+
+        auth_type, token = header_value.split(" ", 1)
+        if auth_type == "Bearer":
+            self._set_token(token)
+
+    def sending_headers(self) -> dict[str, str]:
+        token = self.token()
+        if token is not None:
+            return {"authorization": "Bearer " + token}
+
+        auth_pair = self._auth.auth_pair()
+        auth_token = f"{auth_pair[0]}:{auth_pair[1]}"
+        auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
+        # There seems to be a bug, `authorization` must be lowercase
+        return {"authorization": auth_token}
diff --git a/graphdatascience/arrow_client/middleware/user_agent_middleware.py b/graphdatascience/arrow_client/middleware/user_agent_middleware.py
new file mode 100644
index 000000000..704713bb4
--- /dev/null
+++ b/graphdatascience/arrow_client/middleware/user_agent_middleware.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+from typing import Any
+
+from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory
+
+
+class UserAgentFactory(ClientMiddlewareFactory):  # type: ignore
+    def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self._middleware = UserAgentMiddleware(useragent)
+
+    def start_call(self, info: Any) -> ClientMiddleware:
+        return self._middleware
+
+
+class UserAgentMiddleware(ClientMiddleware):  # type: ignore
+    def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self._useragent = useragent
+
+    def sending_headers(self) -> dict[str, str]:
+        return {"x-gds-user-agent": self._useragent}
+
+    def received_headers(self, headers: dict[str, Any]) -> None:
+        pass
diff --git a/graphdatascience/arrow_client/v2/api_types.py b/graphdatascience/arrow_client/v2/api_types.py
new file mode 100644
index 000000000..9af67ab0d
--- /dev/null
+++ b/graphdatascience/arrow_client/v2/api_types.py
@@ -0,0 +1,16 @@
+from graphdatascience.arrow_client.arrow_base_model import ArrowBaseModel
+
+
+class JobIdConfig(ArrowBaseModel):
+    job_id: str
+
+
+class JobStatus(ArrowBaseModel):
+    job_id: str
+    status: str
+    progress: float
+
+
+class MutateResult(ArrowBaseModel):
+    node_properties_written: int
+    relationships_written: int
diff --git a/graphdatascience/arrow_client/v2/data_mapper_utils.py b/graphdatascience/arrow_client/v2/data_mapper_utils.py
new file mode 100644
index 000000000..6cd3ebc1f
--- /dev/null
+++ b/graphdatascience/arrow_client/v2/data_mapper_utils.py
@@ -0,0 +1,19 @@
+import json
+from typing import Any, Iterator
+
+from pyarrow._flight import Result
+
+
+def deserialize_single(input_stream: Iterator[Result]) -> dict[str, Any]:
+    rows = deserialize(input_stream)
+    if len(rows) != 1:
+        raise ValueError(f"Expected exactly one result, got {len(rows)}")
+
+    return rows[0]
+
+
+def deserialize(input_stream: Iterator[Result]) -> list[dict[str, Any]]:
+    def deserialize_row(row: Result):  # type:ignore
+        return json.loads(row.body.to_pybytes().decode())
+
+    return [deserialize_row(row) for row in list(input_stream)]
diff --git a/graphdatascience/arrow_client/v2/job_client.py b/graphdatascience/arrow_client/v2/job_client.py
new file mode 100644
index 000000000..f42480652
--- /dev/null
+++ b/graphdatascience/arrow_client/v2/job_client.py
@@ -0,0 +1,63 @@
+import json
+from typing import Any
+
+from pandas import ArrowDtype, DataFrame
+from pyarrow._flight import Ticket
+
+from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus
+from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
+
+JOB_STATUS_ENDPOINT = "v2/jobs.status"
+RESULTS_SUMMARY_ENDPOINT = "v2/results.summary"
+
+
+class JobClient:
+    @staticmethod
+    def run_job_and_wait(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str:
+        job_id = JobClient.run_job(client, endpoint, config)
+        JobClient.wait_for_job(client, job_id)
+        return job_id
+
+    @staticmethod
+    def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str:
+        encoded_config = json.dumps(config).encode("utf-8")
+        res = client.do_action_with_retry(endpoint, encoded_config)
+
+        single = deserialize_single(res)
+        return JobIdConfig(**single).job_id
+
+    @staticmethod
+    def wait_for_job(client: AuthenticatedArrowClient, job_id: str) -> None:
+        while True:
+            encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")
+
+            arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, encoded_config)
+            job_status = JobStatus(**deserialize_single(arrow_res))
+            if job_status.status == "Done":
+                break
+
+    @staticmethod
+    def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]:
+        encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")
+
+        res = client.do_action_with_retry(RESULTS_SUMMARY_ENDPOINT, encoded_config)
+        return deserialize_single(res)
+
+    @staticmethod
+    def stream_results(client: AuthenticatedArrowClient, job_id: str) -> DataFrame:
+        encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")
+
+        res = client.do_action_with_retry("v2/results.stream", encoded_config)
+        export_job_id = JobIdConfig(**deserialize_single(res)).job_id
+
+        payload = {
+            "name": export_job_id,
+            "version": 1,
+        }
+
+        ticket = Ticket(json.dumps(payload).encode("utf-8"))
+        with client.get_stream(ticket) as get:
+            arrow_table = get.read_all()
+
+        return arrow_table.to_pandas(types_mapper=ArrowDtype)  # type: ignore
diff --git a/graphdatascience/arrow_client/v2/mutation_client.py b/graphdatascience/arrow_client/v2/mutation_client.py
new file mode 100644
index 000000000..67677bf7f
--- /dev/null
+++ b/graphdatascience/arrow_client/v2/mutation_client.py
@@ -0,0 +1,16 @@
+import json
+
+from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.arrow_client.v2.api_types import MutateResult
+from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
+
+
+class MutationClient:
+    MUTATE_ENDPOINT = "v2/results.mutate"
+
+    @staticmethod
+    def mutate_node_property(client: AuthenticatedArrowClient, job_id: str, mutate_property: str) -> MutateResult:
+        mutate_config = {"jobId": job_id, "mutateProperty": mutate_property}
+        encoded_config = json.dumps(mutate_config).encode("utf-8")
+        mutate_arrow_res = client.do_action_with_retry(MutationClient.MUTATE_ENDPOINT, encoded_config)
+        return MutateResult(**deserialize_single(mutate_arrow_res))
diff --git a/graphdatascience/arrow_client/v2/write_back_client.py b/graphdatascience/arrow_client/v2/write_back_client.py
new file mode 100644
index 000000000..e74d4ccfd
--- /dev/null
+++ b/graphdatascience/arrow_client/v2/write_back_client.py
@@ -0,0 +1,54 @@
+import time
+from typing import Any, Optional
+
+from graphdatascience import QueryRunner
+from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.call_parameters import CallParameters
+from graphdatascience.query_runner.protocol.write_protocols import WriteProtocol
+from graphdatascience.query_runner.termination_flag import TerminationFlagNoop
+from graphdatascience.session.dbms.protocol_resolver import ProtocolVersionResolver
+
+
+class WriteBackClient:
+    def __init__(self, arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner):
+        self._arrow_client = arrow_client
+        self._query_runner = query_runner
+
+        protocol_version = ProtocolVersionResolver(query_runner).resolve()
+        self._write_protocol = WriteProtocol.select(protocol_version)
+
+    # TODO: Add progress logging
+    # TODO: Support setting custom writeProperties and relationshipTypes
+    def write(self, graph_name: str, job_id: str, concurrency: Optional[int]) -> int:
+        arrow_config = self._arrow_configuration()
+
+        configuration = {}
+        if concurrency is not None:
+            configuration["concurrency"] = concurrency
+
+        write_back_params = CallParameters(
+            graphName=graph_name,
+            jobId=job_id,
+            arrowConfiguration=arrow_config,
+            configuration=configuration,
+        )
+
+        start_time = time.time()
+
+        self._write_protocol.run_write_back(self._query_runner, write_back_params, None, TerminationFlagNoop())
+
+        return int((time.time() - start_time) * 1000)
+
+    def _arrow_configuration(self) -> dict[str, Any]:
+        connection_info = self._arrow_client.connection_info()
+        token = self._arrow_client.request_token()
+        if token is None:
+            token = "IGNORED"
+        arrow_config = {
+            "host": connection_info.host,
+            "port": connection_info.port,
+            "token": token,
+            "encrypted": connection_info.encrypted,
+        }
+
+        return arrow_config
diff --git a/graphdatascience/procedure_surface/__init__.py b/graphdatascience/procedure_surface/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/graphdatascience/procedure_surface/arrow/__init__.py b/graphdatascience/procedure_surface/arrow/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/graphdatascience/tests/unit/arrow_client/V2/__init__.py b/graphdatascience/tests/unit/arrow_client/V2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/graphdatascience/tests/unit/arrow_client/V2/test_data_mapper_utils.py b/graphdatascience/tests/unit/arrow_client/V2/test_data_mapper_utils.py
new file mode 100644
index 000000000..0465dc565
--- /dev/null
+++ b/graphdatascience/tests/unit/arrow_client/V2/test_data_mapper_utils.py
@@ -0,0 +1,26 @@
+from typing import Iterator
+
+import pytest
+from pyarrow._flight import Result
+
+from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
+from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult
+
+
+def test_deserialize_single_success() -> None:
+    input_stream = iter([ArrowTestResult({"key": "value"})])
+    expected = {"key": "value"}
+    actual = deserialize_single(input_stream)
+    assert expected == actual
+
+
+def test_deserialize_single_raises_on_empty_stream() -> None:
+    input_stream: Iterator[Result] = iter([])
+    with pytest.raises(ValueError, match="Expected exactly one result, got 0"):
result, got 0"): + deserialize_single(input_stream) + + +def test_deserialize_single_raises_on_multiple_results() -> None: + input_stream = iter([ArrowTestResult({"key1": "value1"}), ArrowTestResult({"key2": "value2"})]) + with pytest.raises(ValueError, match="Expected exactly one result, got 2"): + deserialize_single(input_stream) diff --git a/graphdatascience/tests/unit/arrow_client/V2/test_job_client.py b/graphdatascience/tests/unit/arrow_client/V2/test_job_client.py new file mode 100644 index 000000000..5d9458a1b --- /dev/null +++ b/graphdatascience/tests/unit/arrow_client/V2/test_job_client.py @@ -0,0 +1,109 @@ +import json +import unittest +from unittest.mock import MagicMock + +from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus +from graphdatascience.arrow_client.v2.job_client import JobClient +from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult + + +class TestJobClient(unittest.TestCase): + def setUp(self) -> None: + self.mock_client = MagicMock() + + def test_run_job(self) -> None: + job_id = "test-job-123" + endpoint = "v2/test.endpoint" + config = {"param1": "value1", "param2": 42} + + self.mock_client.do_action_with_retry.return_value = iter([ArrowTestResult({"jobId": job_id})]) + + result = JobClient.run_job(self.mock_client, endpoint, config) + + expected_config = json.dumps(config).encode("utf-8") + self.mock_client.do_action_with_retry.assert_called_once_with(endpoint, expected_config) + self.assertEqual(result, job_id) + + def test_run_job_and_wait( + self, + ) -> None: + job_id = "test-job-456" + endpoint = "v2/test.endpoint" + config = {"param": "value"} + + job_id_config = JobIdConfig(jobId=job_id) + + status = JobStatus( + jobId=job_id, + progress=1.0, + status="Done", + ) + + do_action_with_retry = MagicMock() + do_action_with_retry.side_effect = [ + iter([ArrowTestResult(job_id_config.dump_camel())]), + iter([ArrowTestResult(status.dump_camel())]), + ] + + self.mock_client.do_action_with_retry = do_action_with_retry + + result = JobClient.run_job_and_wait(self.mock_client, endpoint, config) + + do_action_with_retry.assert_called_with("v2/jobs.status", job_id_config.dump_json().encode("utf-8")) + self.assertEqual(result, job_id) + + def test_wait_for_job_completes_immediately(self) -> None: + job_id = "test-job-789" + + status = JobStatus( + jobId=job_id, + progress=1.0, + status="Done", + ) + + self.mock_client.do_action_with_retry.return_value = iter([ArrowTestResult(status.dump_camel())]) + + JobClient.wait_for_job(self.mock_client, job_id) + + self.mock_client.do_action_with_retry.assert_called_once_with( + "v2/jobs.status", JobIdConfig(jobId=job_id).dump_json().encode("utf-8") + ) + + def test_wait_for_job_waits_for_completion(self) -> None: + job_id = "test-job-waiting" + status_running = JobStatus( + jobId=job_id, + progress=0.5, + status="RUNNING", + ) + status_done = JobStatus( + jobId=job_id, + progress=1.0, + status="Done", + ) + + do_action_with_retry = MagicMock() + do_action_with_retry.side_effect = [ + iter([ArrowTestResult(status_running.dump_camel())]), + iter([ArrowTestResult(status_done.dump_camel())]), + ] + + self.mock_client.do_action_with_retry = do_action_with_retry + + JobClient.wait_for_job(self.mock_client, job_id) + + self.assertEqual(self.mock_client.do_action_with_retry.call_count, 2) + + def test_get_summary(self) -> None: + # Setup + job_id = "summary-job-123" + expected_summary = {"nodeCount": 100, "relationshipCount": 200, "requiredMemory": "1GB"} + + 
+        self.mock_client.do_action_with_retry.return_value = iter([ArrowTestResult(expected_summary)])
+
+        result = JobClient.get_summary(self.mock_client, job_id)
+
+        self.mock_client.do_action_with_retry.assert_called_once_with(
+            "v2/results.summary", JobIdConfig(jobId=job_id).dump_json().encode("utf-8")
+        )
+        self.assertEqual(result, expected_summary)
diff --git a/graphdatascience/tests/unit/arrow_client/V2/test_mutation_client.py b/graphdatascience/tests/unit/arrow_client/V2/test_mutation_client.py
new file mode 100644
index 000000000..3f86a9eda
--- /dev/null
+++ b/graphdatascience/tests/unit/arrow_client/V2/test_mutation_client.py
@@ -0,0 +1,27 @@
+import unittest
+from unittest.mock import MagicMock
+
+from graphdatascience.arrow_client.v2.api_types import MutateResult
+from graphdatascience.arrow_client.v2.mutation_client import MutationClient
+from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult
+
+
+class TestMutationClient(unittest.TestCase):
+    def setUp(self) -> None:
+        self.mock_client = MagicMock()
+
+    def test_mutate_node_property_success(self) -> None:
+        job_id = "test-job-123"
+        expected_mutation_result = MutateResult(nodePropertiesWritten=42, relationshipsWritten=1337)
+
+        self.mock_client.do_action_with_retry.return_value = iter(
+            [ArrowTestResult(expected_mutation_result.dump_camel())]
+        )
+
+        result = MutationClient.mutate_node_property(self.mock_client, job_id, "propertyName")
+
+        assert result == expected_mutation_result
+
+        self.mock_client.do_action_with_retry.assert_called_once_with(
+            MutationClient.MUTATE_ENDPOINT, b'{"jobId": "test-job-123", "mutateProperty": "propertyName"}'
+        )
diff --git a/graphdatascience/tests/unit/arrow_client/V2/test_write_back_client.py b/graphdatascience/tests/unit/arrow_client/V2/test_write_back_client.py
new file mode 100644
index 000000000..0b414e982
--- /dev/null
+++ b/graphdatascience/tests/unit/arrow_client/V2/test_write_back_client.py
@@ -0,0 +1,57 @@
+from typing import Optional
+from unittest.mock import Mock
+
+import pytest
+from pandas import DataFrame
+
+from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.arrow_client.v2.write_back_client import WriteBackClient
+from graphdatascience.tests.unit.conftest import DEFAULT_SERVER_VERSION, CollectingQueryRunner
+
+
+@pytest.fixture
+def mock_arrow_client() -> AuthenticatedArrowClient:
+    client = Mock(spec=AuthenticatedArrowClient)
+    client.connection_info.return_value = Mock(host="localhost", port=8080, encrypted=False)
+    client.request_token.return_value = "test_token"
+    return client
+
+
+@pytest.fixture
+def write_back_client(mock_arrow_client: AuthenticatedArrowClient) -> WriteBackClient:
+    query_runner = CollectingQueryRunner(
+        DEFAULT_SERVER_VERSION,
+        {
+            "protocol.version": DataFrame([{"version": "v3"}]),
+        },
+    )
+    return WriteBackClient(mock_arrow_client, query_runner)
+
+
+def test_write_back_client_initialization(write_back_client: WriteBackClient) -> None:
+    assert isinstance(write_back_client, WriteBackClient)
+
+
+def test_arrow_configuration(write_back_client: WriteBackClient, mock_arrow_client: AuthenticatedArrowClient) -> None:
+    expected_config = {
+        "host": "localhost",
+        "port": 8080,
+        "token": "test_token",
+        "encrypted": False,
+    }
+
+    config = write_back_client._arrow_configuration()
+    assert config == expected_config
+
+
+def test_write_calls_run_write_back(write_back_client: WriteBackClient) -> None:
+    graph_name = "test_graph"
+    job_id = "123"
+    concurrency: Optional[int] = 4
+
+    write_back_client._write_protocol.run_write_back = Mock()  # type: ignore
+
+    duration = write_back_client.write(graph_name, job_id, concurrency)
+
+    write_back_client._write_protocol.run_write_back.assert_called_once()
+    assert duration >= 0
diff --git a/graphdatascience/tests/unit/arrow_client/__init__.py b/graphdatascience/tests/unit/arrow_client/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/graphdatascience/tests/unit/arrow_client/arrow_test_utils.py b/graphdatascience/tests/unit/arrow_client/arrow_test_utils.py
new file mode 100644
index 000000000..674da4c20
--- /dev/null
+++ b/graphdatascience/tests/unit/arrow_client/arrow_test_utils.py
@@ -0,0 +1,20 @@
+import json
+from typing import Any
+
+from pyarrow._flight import Result
+
+
+class ArrowTestResult(Result):  # type:ignore
+    def __init__(self, body: dict[str, Any]):
+        self._body = json.dumps(body).encode()
+
+    @property
+    def body(self) -> Any:
+        class MockBody:
+            def __init__(self, data: bytes):
+                self._data = data
+
+            def to_pybytes(self) -> bytes:
+                return self._data
+
+        return MockBody(self._body)
diff --git a/graphdatascience/tests/unit/arrow_client/middleware/__init__.py b/graphdatascience/tests/unit/arrow_client/middleware/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/graphdatascience/tests/unit/arrow_client/middleware/test_auth_middleware.py b/graphdatascience/tests/unit/arrow_client/middleware/test_auth_middleware.py
new file mode 100644
index 000000000..c40e9e7eb
--- /dev/null
+++ b/graphdatascience/tests/unit/arrow_client/middleware/test_auth_middleware.py
@@ -0,0 +1,61 @@
+import base64
+from unittest.mock import Mock, patch
+
+import pytest
+
+from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
+from graphdatascience.arrow_client.middleware.auth_middleware import AuthMiddleware
+
+
+@pytest.fixture
+def mock_auth() -> ArrowAuthentication:
+    class MockAuth(ArrowAuthentication):
+        def auth_pair(self) -> tuple[str, str]:
+            return "username", "password"
+
+    return MockAuth()
+
+
+@pytest.fixture
+def auth_middleware(mock_auth: Mock) -> AuthMiddleware:
+    return AuthMiddleware(mock_auth)
+
+
+def test_token_initially_none(auth_middleware: AuthMiddleware) -> None:
+    assert auth_middleware.token() is None
+
+
+def test_set_token_updates_token_and_timestamp(auth_middleware: AuthMiddleware) -> None:
+    with patch("time.time", return_value=1000000):
+        auth_middleware._set_token("test_token")
+        assert auth_middleware.token() == "test_token"
+
+
+def test_token_expires_after_10_minutes(auth_middleware: AuthMiddleware) -> None:
+    with patch("time.time", side_effect=[1000000, 1000601]):
+        auth_middleware._set_token("test_token")
+        assert auth_middleware.token() is None
+
+
+def test_received_headers_sets_bearer_token(auth_middleware: AuthMiddleware) -> None:
+    headers = {"authorization": ["Bearer test_token"]}
+    auth_middleware.received_headers(headers)
+    assert auth_middleware.token() == "test_token"
+
+
+def test_received_headers_raises_error_with_invalid_header(auth_middleware: AuthMiddleware) -> None:
+    headers = {"authorization": [12345]}  # Invalid header value type
+    with pytest.raises(ValueError, match="Incompatible header value received from server: `12345`"):
+        auth_middleware.received_headers(headers)
+
+
+def test_sending_headers_with_existing_token(auth_middleware: AuthMiddleware) -> None:
+    auth_middleware._set_token("test_token")
+    headers = auth_middleware.sending_headers()
+    assert headers == {"authorization": "Bearer test_token"}
"Bearer test_token"} + + +def test_sending_headers_with_new_basic_auth(auth_middleware: AuthMiddleware, mock_auth: ArrowAuthentication) -> None: + headers = auth_middleware.sending_headers() + expected_auth_token = "Basic " + base64.b64encode(b"username:password").decode("ASCII") + assert headers == {"authorization": expected_auth_token} diff --git a/graphdatascience/tests/unit/arrow_client/test_authenticated_flight_client.py b/graphdatascience/tests/unit/arrow_client/test_authenticated_flight_client.py new file mode 100644 index 000000000..cb3ed216c --- /dev/null +++ b/graphdatascience/tests/unit/arrow_client/test_authenticated_flight_client.py @@ -0,0 +1,52 @@ +# graphdatascience/tests/test_authenticated_flight_client.py +import pytest +from pyarrow._flight import FlightInternalError, FlightTimedOutError, FlightUnavailableError +from tenacity import retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential + +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo +from graphdatascience.retry_utils.retry_config import RetryConfig + + +@pytest.fixture +def arrow_info() -> ArrowInfo: + return ArrowInfo(listenAddress="localhost:8491", enabled=True, running=True, versions=["1.0.0"]) + + +@pytest.fixture +def retry_config() -> RetryConfig: + return RetryConfig( + retry=retry_any( + retry_if_exception_type(FlightTimedOutError), + retry_if_exception_type(FlightUnavailableError), + retry_if_exception_type(FlightInternalError), + ), + stop=(stop_after_delay(10) | stop_after_attempt(5)), + wait=wait_exponential(multiplier=1, min=1, max=10), + ) + + +@pytest.fixture +def mock_auth() -> ArrowAuthentication: + class MockAuthentication(ArrowAuthentication): + def auth_pair(self) -> tuple[str, str]: + return ("mock_user", "mock_password") + + return MockAuthentication() + + +def test_create_authenticated_arrow_client( + arrow_info: ArrowInfo, retry_config: RetryConfig, mock_auth: ArrowAuthentication +) -> None: + client = AuthenticatedArrowClient.create( + arrow_info=arrow_info, auth=mock_auth, encrypted=True, retry_config=retry_config + ) + assert isinstance(client, AuthenticatedArrowClient) + assert client.connection_info() == ConnectionInfo("localhost", 8491, encrypted=True) + + +def test_connection_info(arrow_info: ArrowInfo, retry_config: RetryConfig) -> None: + client = AuthenticatedArrowClient(host="localhost", port=8491, retry_config=retry_config) + connection_info = client.connection_info() + assert connection_info == ConnectionInfo("localhost", 8491, encrypted=False) diff --git a/mypy.ini b/mypy.ini index 63c3644e7..fbaad268b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -4,6 +4,7 @@ exclude = (^build|^\.?venv) untyped_calls_exclude=nbconvert # numpy 2.x needs some type-ignore previous versions dont disable_error_code=unused-ignore +plugins = pydantic.mypy [mypy-pyarrow] ignore_missing_imports = True