diff --git a/graphdatascience/arrow_client/__init__.py b/graphdatascience/arrow_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/query_runner/arrow_authentication.py b/graphdatascience/arrow_client/arrow_authentication.py similarity index 100% rename from graphdatascience/query_runner/arrow_authentication.py rename to graphdatascience/arrow_client/arrow_authentication.py diff --git a/graphdatascience/query_runner/arrow_info.py b/graphdatascience/arrow_client/arrow_info.py similarity index 85% rename from graphdatascience/query_runner/arrow_info.py rename to graphdatascience/arrow_client/arrow_info.py index 8a2399182..a11b48c49 100644 --- a/graphdatascience/query_runner/arrow_info.py +++ b/graphdatascience/arrow_client/arrow_info.py @@ -2,8 +2,8 @@ from dataclasses import dataclass -from ..query_runner.query_runner import QueryRunner -from ..server_version.server_version import ServerVersion +from graphdatascience.query_runner.query_runner import QueryRunner +from graphdatascience.server_version.server_version import ServerVersion @dataclass(frozen=True) diff --git a/graphdatascience/arrow_client/authenticated_arrow_client.py b/graphdatascience/arrow_client/authenticated_arrow_client.py new file mode 100644 index 000000000..8fea1e3c9 --- /dev/null +++ b/graphdatascience/arrow_client/authenticated_arrow_client.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Iterator, Optional + +from pyarrow import __version__ as arrow_version +from pyarrow import flight +from pyarrow._flight import ( + Action, + FlightInternalError, + FlightStreamReader, + FlightTimedOutError, + FlightUnavailableError, + Result, + Ticket, +) +from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential + +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo +from graphdatascience.retry_utils.retry_config import RetryConfig + +from ..retry_utils.retry_utils import before_log +from ..version import __version__ +from .middleware.AuthMiddleware import AuthFactory, AuthMiddleware +from .middleware.UserAgentMiddleware import UserAgentFactory + + +class AuthenticatedArrowClient: + @staticmethod + def create( + arrow_info: ArrowInfo, + auth: Optional[ArrowAuthentication] = None, + encrypted: bool = False, + disable_server_verification: bool = False, + tls_root_certs: Optional[bytes] = None, + connection_string_override: Optional[str] = None, + retry_config: Optional[RetryConfig] = None, + ) -> AuthenticatedArrowClient: + connection_string: str + if connection_string_override is not None: + connection_string = connection_string_override + else: + connection_string = arrow_info.listenAddress + + host, port = connection_string.split(":") + + if retry_config is None: + retry_config = RetryConfig( + retry=retry_any( + retry_if_exception_type(FlightTimedOutError), + retry_if_exception_type(FlightUnavailableError), + retry_if_exception_type(FlightInternalError), + ), + stop=(stop_after_delay(10) | stop_after_attempt(5)), + wait=wait_exponential(multiplier=1, min=1, max=10), + ) + + return AuthenticatedArrowClient( + host, + retry_config, + int(port), + auth, + encrypted, + disable_server_verification, + tls_root_certs, + ) + + def __init__( + self, + host: str, + retry_config: RetryConfig, + port: int = 8491, + auth: Optional[ArrowAuthentication] = None, + encrypted: bool = False, + disable_server_verification: bool = False, + tls_root_certs: Optional[bytes] = None, + user_agent: Optional[str] = None, + ): + """Creates a new GdsArrowClient instance. + + Parameters + ---------- + host: str + The host address of the GDS Arrow server + port: int + The host port of the GDS Arrow server (default is 8491) + auth: Optional[ArrowAuthentication] + Either an implementation of ArrowAuthentication providing a pair to be used for basic authentication, or a username, password tuple + encrypted: bool + A flag that indicates whether the connection should be encrypted (default is False) + disable_server_verification: bool + A flag that disables server verification for TLS connections (default is False) + tls_root_certs: Optional[bytes] + PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server + user_agent: Optional[str] + The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION]) + retry_config: Optional[RetryConfig] + The retry configuration to use for the Arrow requests send by the client. + """ + self._host = host + self._port = port + self._auth = None + self._encrypted = encrypted + self._disable_server_verification = disable_server_verification + self._tls_root_certs = tls_root_certs + self._user_agent = user_agent + self._retry_config = retry_config + self._logger = logging.getLogger("gds_arrow_client") + self._retry_config = RetryConfig( + retry=retry_any( + retry_if_exception_type(FlightTimedOutError), + retry_if_exception_type(FlightUnavailableError), + retry_if_exception_type(FlightInternalError), + ), + stop=(stop_after_delay(10) | stop_after_attempt(5)), + wait=wait_exponential(multiplier=1, min=1, max=10), + ) + + if auth: + self._auth = auth + self._auth_middleware = AuthMiddleware(auth) + + self._flight_client = self._instantiate_flight_client() + + def connection_info(self) -> ConnectionInfo: + """ + Returns the host and port of the GDS Arrow server. + + Returns + ------- + tuple[str, int] + the host and port of the GDS Arrow server + """ + return ConnectionInfo(self._host, self._port, self._encrypted) + + def request_token(self) -> Optional[str]: + """ + Requests a token from the server and returns it. + + Returns + ------- + Optional[str] + a token from the server and returns it. + """ + + @retry( + reraise=True, + before=before_log("Request token", self._logger, logging.DEBUG), + retry=self._retry_config.retry, + stop=self._retry_config.stop, + wait=self._retry_config.wait, + ) + def auth_with_retry() -> None: + client = self._flight_client + if self._auth: + auth_pair = self._auth.auth_pair() + client.authenticate_basic_token(auth_pair[0], auth_pair[1]) + + if self._auth: + auth_with_retry() + return self._auth_middleware.token() + else: + return "IGNORED" + + def get_stream(self, ticket: Ticket) -> FlightStreamReader: + return self._flight_client.do_get(ticket) + + def do_action(self, endpoint: str, payload: bytes) -> Iterator[Result]: + return self._flight_client.do_action(Action(endpoint, payload)) # type: ignore + + def do_action_with_retry(self, endpoint: str, payload: bytes) -> Iterator[Result]: + @retry( + reraise=True, + before=before_log("Send action", self._logger, logging.DEBUG), + retry=self._retry_config.retry, + stop=self._retry_config.stop, + wait=self._retry_config.wait, + ) + def run_with_retry() -> Iterator[Result]: + return self.do_action(endpoint, payload) + + return run_with_retry() + + def _instantiate_flight_client(self) -> flight.FlightClient: + location = ( + flight.Location.for_grpc_tls(self._host, self._port) + if self._encrypted + else flight.Location.for_grpc_tcp(self._host, self._port) + ) + client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification} + if self._auth: + user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" + if self._user_agent: + user_agent = self._user_agent + + client_options["middleware"] = [ + AuthFactory(self._auth_middleware), + UserAgentFactory(useragent=user_agent), + ] + if self._tls_root_certs: + client_options["tls_root_certs"] = self._tls_root_certs + return flight.FlightClient(location, **client_options) + + +@dataclass +class ConnectionInfo: + host: str + port: int + encrypted: bool diff --git a/graphdatascience/arrow_client/data_mapper.py b/graphdatascience/arrow_client/data_mapper.py new file mode 100644 index 000000000..6e326af14 --- /dev/null +++ b/graphdatascience/arrow_client/data_mapper.py @@ -0,0 +1,55 @@ +import dataclasses +import json +from dataclasses import fields +from typing import Any, Dict, Iterator, Type, TypeVar + +from pyarrow._flight import Result + + +class DataMapper: + T = TypeVar("T") + + @staticmethod + def deserialize_single(input_stream: Iterator[Result], cls: Type[T]) -> T: + rows = DataMapper.deserialize(input_stream, cls) + + if len(rows) != 1: + raise ValueError(f"Expected exactly one row, got {len(rows)}") + + return rows[0] + + @staticmethod + def deserialize(input_stream: Iterator[Result], cls: Type[T]) -> list[T]: + def deserialize_row(row: Result): # type:ignore + result_dicts = json.loads(row.body.to_pybytes().decode()) + if cls == Dict: + return result_dicts + return DataMapper.dict_to_dataclass(result_dicts, cls) + + return [deserialize_row(row) for row in list(input_stream)] + + @staticmethod + def dict_to_dataclass(data: Dict[str, Any], cls: Type[T], strict: bool = False) -> T: + """ + Convert a dictionary to a dataclass instance with nested dataclass support. + """ + if not dataclasses.is_dataclass(cls): + raise ValueError(f"{cls} is not a dataclass") + + field_dict = {f.name: f for f in fields(cls)} + filtered_data = {} + + for key, value in data.items(): + if key in field_dict: + field = field_dict[key] + field_type = field.type + + # Handle nested dataclasses + if dataclasses.is_dataclass(field_type) and isinstance(value, dict): + filtered_data[key] = DataMapper.dict_to_dataclass(value, field_type, strict) # type:ignore + else: + filtered_data[key] = value + elif strict: + raise ValueError(f"Extra field '{key}' not allowed in {cls.__name__}") + + return cls(**filtered_data) # type: ignore diff --git a/graphdatascience/arrow_client/middleware/AuthMiddleware.py b/graphdatascience/arrow_client/middleware/AuthMiddleware.py new file mode 100644 index 000000000..0acb6e269 --- /dev/null +++ b/graphdatascience/arrow_client/middleware/AuthMiddleware.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import base64 +import time +from typing import Any, Optional + +from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory + +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication + + +class AuthFactory(ClientMiddlewareFactory): # type: ignore + def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._middleware = middleware + + def start_call(self, info: Any) -> AuthMiddleware: + return self._middleware + + +class AuthMiddleware(ClientMiddleware): # type: ignore + def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._auth = auth + self._token: Optional[str] = None + self._token_timestamp = 0 + + def token(self) -> Optional[str]: + # check whether the token is older than 10 minutes. If so, reset it. + if self._token and int(time.time()) - self._token_timestamp > 600: + self._token = None + + return self._token + + def _set_token(self, token: str) -> None: + self._token = token + self._token_timestamp = int(time.time()) + + def received_headers(self, headers: dict[str, Any]) -> None: + auth_header = headers.get("authorization", None) + if not auth_header: + return + + # the result is always a list + header_value = auth_header[0] + + if not isinstance(header_value, str): + raise ValueError(f"Incompatible header value received from server: `{header_value}`") + + auth_type, token = header_value.split(" ", 1) + if auth_type == "Bearer": + self._set_token(token) + + def sending_headers(self) -> dict[str, str]: + token = self.token() + if token is not None: + return {"authorization": "Bearer " + token} + + auth_pair = self._auth.auth_pair() + auth_token = f"{auth_pair[0]}:{auth_pair[1]}" + auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII") + # There seems to be a bug, `authorization` must be lower key + return {"authorization": auth_token} diff --git a/graphdatascience/arrow_client/middleware/UserAgentMiddleware.py b/graphdatascience/arrow_client/middleware/UserAgentMiddleware.py new file mode 100644 index 000000000..704713bb4 --- /dev/null +++ b/graphdatascience/arrow_client/middleware/UserAgentMiddleware.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import Any + +from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory + + +class UserAgentFactory(ClientMiddlewareFactory): # type: ignore + def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._middleware = UserAgentMiddleware(useragent) + + def start_call(self, info: Any) -> ClientMiddleware: + return self._middleware + + +class UserAgentMiddleware(ClientMiddleware): # type: ignore + def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._useragent = useragent + + def sending_headers(self) -> dict[str, str]: + return {"x-gds-user-agent": self._useragent} + + def received_headers(self, headers: dict[str, Any]) -> None: + pass diff --git a/graphdatascience/arrow_client/v2/api_types.py b/graphdatascience/arrow_client/v2/api_types.py new file mode 100644 index 000000000..4390ec57d --- /dev/null +++ b/graphdatascience/arrow_client/v2/api_types.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True, repr=True) +class JobIdConfig: + jobId: str + + +@dataclass(frozen=True, repr=True) +class JobStatus: + jobId: str + status: str + progress: float + + +@dataclass(frozen=True, repr=True) +class MutateResult: + nodePropertiesWritten: int + relationshipsWritten: int diff --git a/graphdatascience/arrow_client/v2/job_client.py b/graphdatascience/arrow_client/v2/job_client.py new file mode 100644 index 000000000..afa66a090 --- /dev/null +++ b/graphdatascience/arrow_client/v2/job_client.py @@ -0,0 +1,64 @@ +import json +from typing import Any, Dict + +from pandas import ArrowDtype, DataFrame +from pyarrow._flight import Ticket + +from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.data_mapper import DataMapper +from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus + +JOB_STATUS_ENDPOINT = "v2/jobs.status" +RESULTS_SUMMARY_ENDPOINT = "v2/results.summary" + + +class JobClient: + @staticmethod + def run_job_and_wait(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str: + job_id = JobClient.run_job(client, endpoint, config) + JobClient.wait_for_job(client, job_id) + return job_id + + @staticmethod + def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str: + encoded_config = json.dumps(config).encode("utf-8") + res = client.do_action_with_retry(endpoint, encoded_config) + return DataMapper.deserialize_single(res, JobIdConfig).jobId + + @staticmethod + def wait_for_job(client: AuthenticatedArrowClient, job_id: str) -> None: + while True: + job_id_config = {"jobId": job_id} + encoded_config = json.dumps(job_id_config).encode("utf-8") + + arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, encoded_config) + job_status = DataMapper.deserialize_single(arrow_res, JobStatus) + if job_status.status == "Done": + break + + @staticmethod + def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]: + job_id_config = {"jobId": job_id} + encoded_config = json.dumps(job_id_config).encode("utf-8") + + res = client.do_action_with_retry(RESULTS_SUMMARY_ENDPOINT, encoded_config) + return DataMapper.deserialize_single(res, Dict[str, Any]) + + @staticmethod + def stream_results(client: AuthenticatedArrowClient, job_id: str) -> DataFrame: + job_id_config = {"jobId": job_id} + encoded_config = json.dumps(job_id_config).encode("utf-8") + + res = client.do_action_with_retry("v2/results.stream", encoded_config) + export_job_id = DataMapper.deserialize_single(res, JobIdConfig).jobId + + payload = { + "name": export_job_id, + "version": 1, + } + + ticket = Ticket(json.dumps(payload).encode("utf-8")) + with client.get_stream(ticket) as get: + arrow_table = get.read_all() + + return arrow_table.to_pandas(types_mapper=ArrowDtype) # type: ignore diff --git a/graphdatascience/arrow_client/v2/mutation_client.py b/graphdatascience/arrow_client/v2/mutation_client.py new file mode 100644 index 000000000..92cdf2f54 --- /dev/null +++ b/graphdatascience/arrow_client/v2/mutation_client.py @@ -0,0 +1,16 @@ +import json + +from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.data_mapper import DataMapper +from graphdatascience.arrow_client.v2.api_types import MutateResult + + +class MutationClient: + MUTATE_ENDPOINT = "v2/results.mutate" + + @staticmethod + def mutate_node_property(client: AuthenticatedArrowClient, job_id: str, mutate_property: str) -> MutateResult: + mutate_config = {"jobId": job_id, "mutateProperty": mutate_property} + encoded_config = json.dumps(mutate_config).encode("utf-8") + mutate_arrow_res = client.do_action_with_retry(MutationClient.MUTATE_ENDPOINT, encoded_config) + return DataMapper.deserialize_single(mutate_arrow_res, MutateResult) diff --git a/graphdatascience/arrow_client/v2/write_back_client.py b/graphdatascience/arrow_client/v2/write_back_client.py new file mode 100644 index 000000000..fb5441db4 --- /dev/null +++ b/graphdatascience/arrow_client/v2/write_back_client.py @@ -0,0 +1,54 @@ +import time +from typing import Any, Optional + +from graphdatascience import QueryRunner +from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient +from graphdatascience.call_parameters import CallParameters +from graphdatascience.query_runner.protocol.write_protocols import WriteProtocol +from graphdatascience.query_runner.termination_flag import TerminationFlagNoop +from graphdatascience.session.dbms.protocol_resolver import ProtocolVersionResolver + + +class WriteBackClient: + def __init__(self, arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner): + self._arrow_client = arrow_client + self._query_runner = query_runner + + protocol_version = ProtocolVersionResolver(query_runner).resolve() + self._write_protocol = WriteProtocol.select(protocol_version) + + # TODO: Add progress logging + # TODO: Support setting custom writeProperties and relationshipTypes + def write(self, graph_name: str, job_id: str, concurrency: Optional[int]) -> int: + arrow_config = self._arrow_configuration() + + configuration = {} + if concurrency is not None: + configuration["concurrency"] = concurrency + + write_back_params = CallParameters( + graphName=graph_name, + jobId=job_id, + arrowConfiguration=arrow_config, + configuration=configuration, + ) + + start_time = time.time() + + self._write_protocol.run_write_back(self._query_runner, write_back_params, None, TerminationFlagNoop()) + + return int((time.time() - start_time) * 1000) + + def _arrow_configuration(self) -> dict[str, Any]: + connection_info = self._arrow_client.connection_info() + token = self._arrow_client.request_token() + if token is None: + token = "IGNORED" + arrow_config = { + "host": connection_info.host, + "port": connection_info.port, + "token": token, + "encrypted": connection_info.encrypted, + } + + return arrow_config diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index 56c91b57a..523cf6def 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import warnings from types import TracebackType from typing import Any, Optional, Type, Union @@ -8,13 +9,15 @@ from neo4j import Driver from pandas import DataFrame -from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo +from graphdatascience.procedure_surface.cypher.wcc_cypher_endpoints import WccCypherEndpoints from .call_builder import IndirectCallBuilder from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints from .error.uncallable_namespace import UncallableNamespace from .graph.graph_proc_runner import GraphProcRunner -from .query_runner.arrow_info import ArrowInfo +from .procedure_surface.api.wcc_endpoints import WccEndpoints from .query_runner.arrow_query_runner import ArrowQueryRunner from .query_runner.neo4j_query_runner import Neo4jQueryRunner from .query_runner.query_runner import QueryRunner @@ -117,10 +120,21 @@ def __init__( self._query_runner.set_show_progress(show_progress) super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) + self._wcc_endpoints: Optional[WccEndpoints] = None + if os.environ.get("ENABLE_EXPLICIT_ENDPOINTS") is not None: + self._wcc_endpoints = WccCypherEndpoints(self._query_runner) + @property def graph(self) -> GraphProcRunner: return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version) + @property + def wcc(self) -> Union[WccEndpoints, IndirectCallBuilder]: + if self._wcc_endpoints is None: + return IndirectCallBuilder(self._query_runner, f"gds.{self._namespace}.wcc", self._server_version) + + return self._wcc_endpoints + @property def util(self) -> UtilProcRunner: return UtilProcRunner(self._query_runner, f"{self._namespace}.util", self._server_version) diff --git a/graphdatascience/procedure_surface/__init__.py b/graphdatascience/procedure_surface/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/api/__init__.py b/graphdatascience/procedure_surface/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/api/wcc_endpoints.py b/graphdatascience/procedure_surface/api/wcc_endpoints.py new file mode 100644 index 000000000..2096e653f --- /dev/null +++ b/graphdatascience/procedure_surface/api/wcc_endpoints.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, List, Optional + +from pandas import DataFrame + +from ...graph.graph_object import Graph + + +class WccEndpoints(ABC): + """ + Abstract base class defining the API for the Weakly Connected Components (WCC) algorithm. + """ + + @abstractmethod + def mutate( + self, + G: Graph, + mutate_property: str, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> WccMutateResult: + """ + Executes the WCC algorithm and writes the results to the in-memory graph as node properties. + + Parameters + ---------- + G : Graph + The graph to run the algorithm on + mutate_property : str + The property name to store the component ID for each node + threshold : Optional[float], default=None + The minimum required weight to consider a relationship during traversal + relationship_types : Optional[List[str]], default=None + The relationship types to project + node_labels : Optional[List[str]], default=None + The node labels to project + sudo : Optional[bool], default=None + Run analysis with admin permission + log_progress : Optional[bool], default=None + Whether to log progress + username : Optional[str], default=None + The username to attribute the procedure run to + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + seed_property : Optional[str], default=None + Defines node properties that are used as initial component identifiers + consecutive_ids : Optional[bool], default=None + Flag to decide whether component identifiers are mapped into a consecutive id space + relationship_weight_property : Optional[str], default=None + The property name that contains weight + + Returns + ------- + WccMutateResult + Algorithm metrics and statistics + """ + pass + + @abstractmethod + def stats( + self, + G: Graph, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> WccStatsResult: + """ + Executes the WCC algorithm and returns statistics. + + Parameters + ---------- + G : Graph + The graph to run the algorithm on + threshold : Optional[float], default=None + The minimum required weight to consider a relationship during traversal + relationship_types : Optional[List[str]], default=None + The relationship types to project + node_labels : Optional[List[str]], default=None + The node labels to project + sudo : Optional[bool], default=None + Run analysis with admin permission + log_progress : Optional[bool], default=None + Whether to log progress + username : Optional[str], default=None + The username to attribute the procedure run to + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + seed_property : Optional[str], default=None + Defines node properties that are used as initial component identifiers + consecutive_ids : Optional[bool], default=None + Flag to decide whether component identifiers are mapped into a consecutive id space + relationship_weight_property : Optional[str], default=None + The property name that contains weight + + Returns + ------- + WccStatsResult + Algorithm metrics and statistics + """ + pass + + @abstractmethod + def stream( + self, + G: Graph, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> DataFrame: + """ + Executes the WCC algorithm and returns a stream of results. + + Parameters + ---------- + G : Graph + The graph to run the algorithm on + min_component_size : Optional[int], default=None + Don't stream components with fewer nodes than this + threshold : Optional[float], default=None + The minimum required weight to consider a relationship during traversal + relationship_types : Optional[List[str]], default=None + The relationship types to project + node_labels : Optional[List[str]], default=None + The node labels to project + sudo : Optional[bool], default=None + Run analysis with admin permission + log_progress : Optional[bool], default=None + Whether to log progress + username : Optional[str], default=None + The username to attribute the procedure run to + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + seed_property : Optional[str], default=None + Defines node properties that are used as initial component identifiers + consecutive_ids : Optional[bool], default=None + Flag to decide whether component identifiers are mapped into a consecutive id space + relationship_weight_property : Optional[str], default=None + The property name that contains weight + + Returns + ------- + DataFrame + DataFrame with the algorithm results + """ + pass + + @abstractmethod + def write( + self, + G: Graph, + write_property: str, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + write_concurrency: Optional[Any] = None, + ) -> WccWriteResult: + """ + Executes the WCC algorithm and writes the results to the Neo4j database. + + Parameters + ---------- + G : Graph + The graph to run the algorithm on + write_property : str + The property name to write component IDs to + min_component_size : Optional[int], default=None + Don't write components with fewer nodes than this + threshold : Optional[float], default=None + The minimum required weight to consider a relationship during traversal + relationship_types : Optional[List[str]], default=None + The relationship types to project + node_labels : Optional[List[str]], default=None + The node labels to project + sudo : Optional[bool], default=None + Run analysis with admin permission + log_progress : Optional[bool], default=None + Whether to log progress + username : Optional[str], default=None + The username to attribute the procedure run to + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + seed_property : Optional[str], default=None + Defines node properties that are used as initial component identifiers + consecutive_ids : Optional[bool], default=None + Flag to decide whether component identifiers are mapped into a consecutive id space + relationship_weight_property : Optional[str], default=None + The property name that contains weight + write_concurrency : Optional[Any], default=None + The number of concurrent threads during the write phase + + Returns + ------- + WccWriteResult + Algorithm metrics and statistics + """ + pass + + +@dataclass(frozen=True, repr=True) +class WccMutateResult: + component_count: int + component_distribution: dict[str, Any] + pre_processing_millis: int + compute_millis: int + post_processing_millis: int + mutate_millis: int + node_properties_written: int + configuration: dict[str, Any] + + def __getitem__(self, item: str) -> Any: + return getattr(self, item) + + +@dataclass(frozen=True, repr=True) +class WccStatsResult: + component_count: int + component_distribution: dict[str, Any] + pre_processing_millis: int + compute_millis: int + post_processing_millis: int + configuration: dict[str, Any] + + def __getitem__(self, item: str) -> Any: + return getattr(self, item) + + +@dataclass(frozen=True, repr=True) +class WccWriteResult: + component_count: int + component_distribution: dict[str, Any] + pre_processing_millis: int + compute_millis: int + write_millis: int + post_processing_millis: int + node_properties_written: int + configuration: dict[str, Any] + + def __getitem__(self, item: str) -> Any: + return getattr(self, item) diff --git a/graphdatascience/procedure_surface/arrow/__init__.py b/graphdatascience/procedure_surface/arrow/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py new file mode 100644 index 000000000..8261f56ba --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py @@ -0,0 +1,195 @@ +from typing import List, Optional + +from pandas import DataFrame + +from graphdatascience.procedure_surface.config_converter import ConfigConverter + +from ...arrow_client.authenticated_arrow_client import AuthenticatedArrowClient +from ...arrow_client.v2.job_client import JobClient +from ...arrow_client.v2.mutation_client import MutationClient +from ...arrow_client.v2.write_back_client import WriteBackClient +from ...graph.graph_object import Graph +from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult + +WCC_ENDPOINT = "v2/community.wcc" + + +class WccArrowEndpoints(WccEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[WriteBackClient]): + self._arrow_client = arrow_client + self._write_back_client = write_back_client + + def mutate( + self, + G: Graph, + mutate_property: str, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> WccMutateResult: + config = ConfigConverter.convert_to_gds_config( + graph_name=G.name(), + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) + + mutate_result = MutationClient.mutate_node_property(self._arrow_client, job_id, mutate_property) + computation_result = JobClient.get_summary(self._arrow_client, job_id) + + return WccMutateResult( + computation_result["componentCount"], + computation_result["componentDistribution"], + computation_result["preProcessingMillis"], + computation_result["computeMillis"], + computation_result["postProcessingMillis"], + 0, + mutate_result.nodePropertiesWritten, + computation_result["configuration"], + ) + + def stats( + self, + G: Graph, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> WccStatsResult: + config = ConfigConverter.convert_to_gds_config( + graph_name=G.name(), + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) + computation_result = JobClient.get_summary(self._arrow_client, job_id) + + return WccStatsResult( + computation_result["componentCount"], + computation_result["componentDistribution"], + computation_result["preProcessingMillis"], + computation_result["computeMillis"], + computation_result["postProcessingMillis"], + computation_result["configuration"], + ) + + def stream( + self, + G: Graph, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> DataFrame: + config = ConfigConverter.convert_to_gds_config( + graph_name=G.name(), + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + min_component_size=min_component_size, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) + return JobClient.stream_results(self._arrow_client, job_id) + + def write( + self, + G: Graph, + write_property: str, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + write_concurrency: Optional[int] = None, + ) -> WccWriteResult: + config = ConfigConverter.convert_to_gds_config( + graph_name=G.name(), + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + min_component_size=min_component_size, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) + computation_result = JobClient.get_summary(self._arrow_client, job_id) + + if self._write_back_client is None: + raise Exception("Write back client is not initialized") + + write_millis = self._write_back_client.write( + G.name(), job_id, write_concurrency if write_concurrency is not None else concurrency + ) + + return WccWriteResult( + computation_result["componentCount"], + computation_result["componentDistribution"], + computation_result["preProcessingMillis"], + computation_result["computeMillis"], + write_millis, + computation_result["postProcessingMillis"], + computation_result["nodePropertiesWritten"], + computation_result["configuration"], + ) diff --git a/graphdatascience/procedure_surface/config_converter.py b/graphdatascience/procedure_surface/config_converter.py new file mode 100644 index 000000000..c6da60355 --- /dev/null +++ b/graphdatascience/procedure_surface/config_converter.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, Optional + + +class ConfigConverter: + @staticmethod + def convert_to_gds_config(**kwargs: Optional[Any]) -> dict[str, Any]: + config: dict[str, Any] = {} + + # Process kwargs + processed_kwargs = ConfigConverter._process_dict_values(kwargs) + config.update(processed_kwargs) + + return config + + @staticmethod + def _convert_to_camel_case(name: str) -> str: + """Convert a snake_case string to camelCase.""" + parts = name.split("_") + return "".join([word.capitalize() if i > 0 else word.lower() for i, word in enumerate(parts)]) + + @staticmethod + def _process_dict_values(input_dict: Dict[str, Any]) -> Dict[str, Any]: + """Process dictionary values, converting keys to camelCase and handling nested dictionaries.""" + result = {} + for key, value in input_dict.items(): + if value is not None: + camel_key = ConfigConverter._convert_to_camel_case(key) + # Recursively process nested dictionaries + if isinstance(value, dict): + result[camel_key] = ConfigConverter._process_dict_values(value) + else: + result[camel_key] = value + return result diff --git a/graphdatascience/procedure_surface/cypher/__init__.py b/graphdatascience/procedure_surface/cypher/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/cypher/wcc_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/wcc_cypher_endpoints.py new file mode 100644 index 000000000..7bebd3788 --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/wcc_cypher_endpoints.py @@ -0,0 +1,201 @@ +from typing import List, Optional + +from pandas import DataFrame + +from ...call_parameters import CallParameters +from ...graph.graph_object import Graph +from ...query_runner.query_runner import QueryRunner +from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult +from ..config_converter import ConfigConverter + + +class WccCypherEndpoints(WccEndpoints): + """ + Implementation of the WCC algorithm endpoints. + This class handles the actual execution by forwarding calls to the query runner. + """ + + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def mutate( + self, + G: Graph, + mutate_property: str, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> WccMutateResult: + config = ConfigConverter.convert_to_gds_config( + mutate_property=mutate_property, + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + username=username, + ) + + # Run procedure and return results + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + cypher_result = self._query_runner.call_procedure(endpoint="gds.wcc.mutate", params=params).squeeze() + + return WccMutateResult( + cypher_result["componentCount"], + cypher_result["componentDistribution"], + cypher_result["preProcessingMillis"], + cypher_result["computeMillis"], + cypher_result["postProcessingMillis"], + cypher_result["mutateMillis"], + cypher_result["nodePropertiesWritten"], + cypher_result["configuration"], + ) + + def stats( + self, + G: Graph, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> WccStatsResult: + config = ConfigConverter.convert_to_gds_config( + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + username=username, + ) + + # Run procedure and return results + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + cypher_result = self._query_runner.call_procedure(endpoint="gds.wcc.stats", params=params).squeeze() # type: ignore + + return WccStatsResult( + cypher_result["componentCount"], + cypher_result["componentDistribution"], + cypher_result["preProcessingMillis"], + cypher_result["computeMillis"], + cypher_result["postProcessingMillis"], + cypher_result["configuration"], + ) + + def stream( + self, + G: Graph, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> DataFrame: + config = ConfigConverter.convert_to_gds_config( + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + min_component_size=min_component_size, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + username=username, + ) + + # Run procedure and return results + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + return self._query_runner.call_procedure(endpoint="gds.wcc.stream", params=params) + + def write( + self, + G: Graph, + write_property: str, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + write_concurrency: Optional[int] = None, + ) -> WccWriteResult: + config = ConfigConverter.convert_to_gds_config( + write_property=write_property, + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + min_component_size=min_component_size, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + username=username, + ) + + if write_concurrency is not None: + config["writeConcurrency"] = write_concurrency + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure(endpoint="gds.wcc.write", params=params).squeeze() # type: ignore + + return WccWriteResult( + result["componentCount"], + result["componentDistribution"], + result["preProcessingMillis"], + result["computeMillis"], + result["writeMillis"], + result["postProcessingMillis"], + result["nodePropertiesWritten"], + result["configuration"], + ) diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index 36214fb97..c3ed819a2 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -5,12 +5,12 @@ from pandas import DataFrame -from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.query_runner.query_mode import QueryMode from graphdatascience.retry_utils.retry_config import RetryConfig from ..call_parameters import CallParameters -from ..query_runner.arrow_info import ArrowInfo from ..server_version.server_version import ServerVersion from .arrow_graph_constructor import ArrowGraphConstructor from .gds_arrow_client import GdsArrowClient diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 64f1988ad..cd60e2a20 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -35,14 +35,14 @@ wait_exponential, ) -from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication, UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication, UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.retry_utils.retry_config import RetryConfig from graphdatascience.retry_utils.retry_utils import before_log from ..semantic_version.semantic_version import SemanticVersion from ..version import __version__ from .arrow_endpoint_version import ArrowEndpointVersion -from .arrow_info import ArrowInfo class GdsArrowClient: diff --git a/graphdatascience/query_runner/protocol/project_protocols.py b/graphdatascience/query_runner/protocol/project_protocols.py index 46eb802b8..6f84f1ea3 100644 --- a/graphdatascience/query_runner/protocol/project_protocols.py +++ b/graphdatascience/query_runner/protocol/project_protocols.py @@ -25,7 +25,7 @@ def project_params( def run_projection( self, query_runner: QueryRunner, - endpoint: str, + query: str, params: CallParameters, terminationFlag: TerminationFlag, yields: Optional[list[str]] = None, diff --git a/graphdatascience/session/aura_api_token_authentication.py b/graphdatascience/session/aura_api_token_authentication.py index 37e51bfed..44e5780b7 100644 --- a/graphdatascience/session/aura_api_token_authentication.py +++ b/graphdatascience/session/aura_api_token_authentication.py @@ -1,4 +1,4 @@ -from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication from graphdatascience.session.aura_api import AuraApi diff --git a/graphdatascience/session/aura_graph_data_science.py b/graphdatascience/session/aura_graph_data_science.py index 565df1266..c30b7ffdd 100644 --- a/graphdatascience/session/aura_graph_data_science.py +++ b/graphdatascience/session/aura_graph_data_science.py @@ -1,10 +1,15 @@ from __future__ import annotations +import os from typing import Any, Callable, Optional, Union from pandas import DataFrame from graphdatascience import QueryRunner, ServerVersion +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo +from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.write_back_client import WriteBackClient from graphdatascience.call_builder import IndirectCallBuilder from graphdatascience.endpoints import ( AlphaRemoteEndpoints, @@ -13,8 +18,8 @@ ) from graphdatascience.error.uncallable_namespace import UncallableNamespace from graphdatascience.graph.graph_remote_proc_runner import GraphRemoteProcRunner -from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication -from graphdatascience.query_runner.arrow_info import ArrowInfo +from graphdatascience.procedure_surface.api.wcc_endpoints import WccEndpoints +from graphdatascience.procedure_surface.arrow.wcc_arrow_endpoints import WccArrowEndpoints from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner @@ -24,15 +29,11 @@ from graphdatascience.utils.util_remote_proc_runner import UtilRemoteProcRunner -class AuraGraphDataScience(DirectEndpoints, UncallableNamespace): - """ - Primary API class for interacting with Neo4j database + Graph Data Science Session. - Always bind this object to a variable called `gds`. - """ +class AuraGraphDataScienceFactory: + """Factory class for creating AuraGraphDataScience instances with all required components.""" - @classmethod - def create( - cls, + def __init__( + self, session_bolt_connection_info: DbmsConnectionInfo, arrow_authentication: Optional[ArrowAuthentication], db_endpoint: Optional[Union[Neo4jQueryRunner, DbmsConnectionInfo]], @@ -41,74 +42,114 @@ def create( arrow_tls_root_certs: Optional[bytes] = None, bookmarks: Optional[Any] = None, show_progress: bool = True, - ) -> AuraGraphDataScience: - session_bolt_query_runner = Neo4jQueryRunner.create_for_session( - endpoint=session_bolt_connection_info.uri, - auth=session_bolt_connection_info.get_auth(), - show_progress=show_progress, + ): + self.session_bolt_connection_info = session_bolt_connection_info + self.arrow_authentication = arrow_authentication + self.db_endpoint = db_endpoint + self.delete_fn = delete_fn + self.arrow_disable_server_verification = arrow_disable_server_verification + self.arrow_tls_root_certs = arrow_tls_root_certs + self.bookmarks = bookmarks + self.show_progress = show_progress + + def create(self) -> AuraGraphDataScience: + """Create and configure an AuraGraphDataScience instance.""" + session_bolt_query_runner = self._create_session_bolt_query_runner() + arrow_info = ArrowInfo.create(session_bolt_query_runner) + session_arrow_query_runner = self._create_session_arrow_query_runner(session_bolt_query_runner, arrow_info) + session_arrow_client = self._create_session_arrow_client(arrow_info, session_bolt_query_runner) + gds_version = session_bolt_query_runner.server_version() + + session_query_runner: QueryRunner + + if self.db_endpoint is not None: + db_bolt_query_runner = self._create_db_bolt_query_runner() + session_query_runner = SessionQueryRunner.create( + session_arrow_query_runner, db_bolt_query_runner, session_arrow_client, self.show_progress + ) + wcc_endpoints = self._create_wcc_endpoints(arrow_info, session_bolt_query_runner, db_bolt_query_runner) + else: + session_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner) + wcc_endpoints = self._create_wcc_endpoints(arrow_info, session_bolt_query_runner, None) + + return AuraGraphDataScience( + query_runner=session_query_runner, + wcc_endpoints=wcc_endpoints, + delete_fn=self.delete_fn, + gds_version=gds_version, ) - arrow_info = ArrowInfo.create(session_bolt_query_runner) - session_arrow_query_runner = ArrowQueryRunner.create( + def _create_session_bolt_query_runner(self) -> Neo4jQueryRunner: + return Neo4jQueryRunner.create_for_session( + endpoint=self.session_bolt_connection_info.uri, + auth=self.session_bolt_connection_info.get_auth(), + show_progress=self.show_progress, + ) + + def _create_session_arrow_query_runner( + self, session_bolt_query_runner: Neo4jQueryRunner, arrow_info: ArrowInfo + ) -> ArrowQueryRunner: + return ArrowQueryRunner.create( fallback_query_runner=session_bolt_query_runner, arrow_info=arrow_info, - arrow_authentication=arrow_authentication, + arrow_authentication=self.arrow_authentication, encrypted=session_bolt_query_runner.encrypted(), - disable_server_verification=arrow_disable_server_verification, - tls_root_certs=arrow_tls_root_certs, + disable_server_verification=self.arrow_disable_server_verification, + tls_root_certs=self.arrow_tls_root_certs, ) - # TODO: merge with the gds_arrow_client created inside ArrowQueryRunner - session_arrow_client = GdsArrowClient.create( + def _create_session_arrow_client( + self, arrow_info: ArrowInfo, session_bolt_query_runner: Neo4jQueryRunner + ) -> GdsArrowClient: + return GdsArrowClient.create( arrow_info, - arrow_authentication, + self.arrow_authentication, session_bolt_query_runner.encrypted(), - arrow_disable_server_verification, - arrow_tls_root_certs, + self.arrow_disable_server_verification, + self.arrow_tls_root_certs, ) - gds_version = session_bolt_query_runner.server_version() - - if db_endpoint is not None: - if isinstance(db_endpoint, Neo4jQueryRunner): - db_bolt_query_runner = db_endpoint - else: - db_bolt_query_runner = Neo4jQueryRunner.create_for_db( - db_endpoint.uri, - db_endpoint.get_auth(), - aura_ds=True, - show_progress=False, - database=db_endpoint.database, - ) - db_bolt_query_runner.set_bookmarks(bookmarks) - - session_query_runner = SessionQueryRunner.create( - session_arrow_query_runner, db_bolt_query_runner, session_arrow_client, show_progress - ) - return cls( - query_runner=session_query_runner, - delete_fn=delete_fn, - gds_version=gds_version, + def _create_db_bolt_query_runner(self) -> Neo4jQueryRunner: + if isinstance(self.db_endpoint, Neo4jQueryRunner): + db_bolt_query_runner = self.db_endpoint + elif isinstance(self.db_endpoint, DbmsConnectionInfo): + db_bolt_query_runner = Neo4jQueryRunner.create_for_db( + self.db_endpoint.uri, + self.db_endpoint.get_auth(), + aura_ds=True, + show_progress=False, + database=self.db_endpoint.database, ) else: - standalone_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner) - return cls( - query_runner=standalone_query_runner, - delete_fn=delete_fn, - gds_version=gds_version, + raise ValueError("db_endpoint must be a Neo4jQueryRunner or a DbmsConnectionInfo") + + db_bolt_query_runner.set_bookmarks(self.bookmarks) + return db_bolt_query_runner + + def _create_wcc_endpoints( + self, arrow_info: ArrowInfo, session_bolt_query_runner: Neo4jQueryRunner, db_query_runner: Optional[QueryRunner] + ) -> Optional[WccEndpoints]: + wcc_endpoints: Optional[WccEndpoints] = None + if os.environ.get("ENABLE_EXPLICIT_ENDPOINTS") is not None: + arrow_client = AuthenticatedArrowClient.create( + arrow_info, + self.arrow_authentication, + session_bolt_query_runner.encrypted(), + self.arrow_disable_server_verification, + self.arrow_tls_root_certs, ) - def __init__( - self, - query_runner: QueryRunner, - delete_fn: Callable[[], bool], - gds_version: ServerVersion, - ): - self._query_runner = query_runner - self._delete_fn = delete_fn - self._server_version = gds_version + write_back_client = WriteBackClient(arrow_client, db_query_runner) if db_query_runner is not None else None - super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) + wcc_endpoints = WccArrowEndpoints(arrow_client, write_back_client) + return wcc_endpoints + + +class AuraGraphDataScience(DirectEndpoints, UncallableNamespace): + """ + Primary API class for interacting with Neo4j database + Graph Data Science Session. + Always bind this object to a variable called `gds`. + """ def run_cypher( self, @@ -133,6 +174,20 @@ def run_cypher( """ return self._query_runner.run_cypher(query, params, database, False) + def __init__( + self, + query_runner: QueryRunner, + delete_fn: Callable[[], bool], + gds_version: ServerVersion, + wcc_endpoints: Optional[WccEndpoints] = None, + ): + self._query_runner = query_runner + self._delete_fn = delete_fn + self._server_version = gds_version + self._wcc_endpoints = wcc_endpoints + + super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) + @property def graph(self) -> GraphRemoteProcRunner: return GraphRemoteProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version) @@ -149,6 +204,13 @@ def alpha(self) -> AlphaRemoteEndpoints: def beta(self) -> BetaEndpoints: return BetaEndpoints(self._query_runner, "gds.beta", self._server_version) + @property + def wcc(self) -> Union[WccEndpoints, IndirectCallBuilder]: + if self._wcc_endpoints is None: + return IndirectCallBuilder(self._query_runner, f"gds.{self._namespace}.wcc", self._server_version) + + return self._wcc_endpoints + def __getattr__(self, attr: str) -> IndirectCallBuilder: return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version) diff --git a/graphdatascience/session/dedicated_sessions.py b/graphdatascience/session/dedicated_sessions.py index 8babac7a0..b95399c2c 100644 --- a/graphdatascience/session/dedicated_sessions.py +++ b/graphdatascience/session/dedicated_sessions.py @@ -5,13 +5,13 @@ from datetime import datetime, timedelta, timezone from typing import Any, Optional -from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner from graphdatascience.session.algorithm_category import AlgorithmCategory from graphdatascience.session.aura_api import AuraApi from graphdatascience.session.aura_api_responses import SessionDetails from graphdatascience.session.aura_api_token_authentication import AuraApiTokenAuthentication -from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience +from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience, AuraGraphDataScienceFactory from graphdatascience.session.cloud_location import CloudLocation from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo from graphdatascience.session.session_info import SessionInfo @@ -210,9 +210,9 @@ def _construct_client( arrow_authentication: ArrowAuthentication, db_runner: Optional[Neo4jQueryRunner], ) -> AuraGraphDataScience: - return AuraGraphDataScience.create( + return AuraGraphDataScienceFactory( session_bolt_connection_info=session_bolt_connection_info, arrow_authentication=arrow_authentication, db_endpoint=db_runner, delete_fn=lambda: self._aura_api.delete_session(session_id=session_id), - ) + ).create() diff --git a/graphdatascience/tests/integration/conftest.py b/graphdatascience/tests/integration/conftest.py index c520992f7..ab4ad6822 100644 --- a/graphdatascience/tests/integration/conftest.py +++ b/graphdatascience/tests/integration/conftest.py @@ -6,11 +6,11 @@ import pytest from neo4j import Driver, GraphDatabase +from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.graph_data_science import GraphDataScience -from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner from graphdatascience.server_version.server_version import ServerVersion -from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience +from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience, AuraGraphDataScienceFactory from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687") @@ -92,12 +92,13 @@ def gds_without_arrow() -> Generator[GraphDataScience, None, None]: @pytest.fixture(scope="package", autouse=False) def gds_with_cloud_setup(request: pytest.FixtureRequest) -> Generator[AuraGraphDataScience, None, None]: - _gds = AuraGraphDataScience.create( + _gds = AuraGraphDataScienceFactory( session_bolt_connection_info=DbmsConnectionInfo(URI, AUTH[0], AUTH[1]), arrow_authentication=UsernamePasswordAuthentication(AUTH[0], AUTH[1]), db_endpoint=DbmsConnectionInfo(AURA_DB_URI, AURA_DB_AUTH[0], AURA_DB_AUTH[1]), delete_fn=lambda: True, - ) + ).create() + _gds.set_database(DB) yield _gds @@ -107,12 +108,12 @@ def gds_with_cloud_setup(request: pytest.FixtureRequest) -> Generator[AuraGraphD @pytest.fixture(scope="package", autouse=False) def standalone_aura_gds() -> Generator[AuraGraphDataScience, None, None]: - _gds = AuraGraphDataScience.create( + _gds = AuraGraphDataScienceFactory( session_bolt_connection_info=DbmsConnectionInfo(URI, AUTH[0], AUTH[1]), arrow_authentication=UsernamePasswordAuthentication(AUTH[0], AUTH[1]), db_endpoint=None, delete_fn=lambda: True, - ) + ).create() yield _gds diff --git a/graphdatascience/tests/integration/test_remote_graph_ops.py b/graphdatascience/tests/integration/test_remote_graph_ops.py index 7eba37c64..ad5ea078f 100644 --- a/graphdatascience/tests/integration/test_remote_graph_ops.py +++ b/graphdatascience/tests/integration/test_remote_graph_ops.py @@ -62,7 +62,7 @@ def test_remote_projection_and_writeback_custom_database_name(gds_with_cloud_set assert projection_result["nodeCount"] == 2 assert projection_result["relationshipCount"] == 1 - write_result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc") + write_result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc") # type: ignore assert write_result["nodePropertiesWritten"] == 2 count_wcc_nodes_query = "MATCH (n WHERE n.wcc IS NOT NULL) RETURN count(*) AS c" @@ -234,6 +234,6 @@ def test_empty_graph_write_back( assert G.node_count() == 0 - result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc") + result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc") # type: ignore assert result["nodePropertiesWritten"] == 0 diff --git a/graphdatascience/tests/integration/test_simple_algo.py b/graphdatascience/tests/integration/test_simple_algo.py index c7fbbba1d..b9fd0297d 100644 --- a/graphdatascience/tests/integration/test_simple_algo.py +++ b/graphdatascience/tests/integration/test_simple_algo.py @@ -57,7 +57,7 @@ def test_wcc_stats(gds: GraphDataScience) -> None: def test_wcc_stats_estimate(gds: GraphDataScience) -> None: G, _ = gds.graph.project(GRAPH_NAME, "*", "*") - result = gds.wcc.stats.estimate(G) + result = gds.wcc.stats.estimate(G) # type: ignore assert result["requiredMemory"] diff --git a/graphdatascience/tests/unit/arrow_client/__init__.py b/graphdatascience/tests/unit/arrow_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/tests/unit/arrow_client/test_data_mapper.py b/graphdatascience/tests/unit/arrow_client/test_data_mapper.py new file mode 100644 index 000000000..e4e172597 --- /dev/null +++ b/graphdatascience/tests/unit/arrow_client/test_data_mapper.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass +from typing import Any, Dict + +import pytest + +from graphdatascience.arrow_client.data_mapper import DataMapper + + +@dataclass +class NestedDataclass: + nested_field: int + + +@dataclass +class ExampleDataclass: + field_one: str + field_two: int + nested: NestedDataclass + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + ( + {"field_one": "test", "field_two": 123, "nested": {"nested_field": 456}}, + ExampleDataclass("test", 123, NestedDataclass(456)), + ), + ], +) +def test_dict_to_dataclass(input_data: Dict[str, Any], expected_output: ExampleDataclass) -> None: + result = DataMapper.dict_to_dataclass(input_data, ExampleDataclass) + assert result == expected_output + + +def test_dict_to_dataclass_strict_mode_rejects_extra_fields() -> None: + input_data = {"field_one": "test", "field_two": 123, "nested": {"nested_field": 456}, "extra_field": "not_allowed"} + + with pytest.raises(ValueError, match="Extra field 'extra_field' not allowed in ExampleDataclass"): + DataMapper.dict_to_dataclass(input_data, ExampleDataclass, strict=True) + + +def test_dict_to_dataclass_non_dataclass_error() -> None: + with pytest.raises(ValueError, match="is not a dataclass"): + DataMapper.dict_to_dataclass({"key": "value"}, int) diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py index 7a8964c5c..01bd31301 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -8,17 +8,17 @@ from pytest_mock import MockerFixture from graphdatascience import QueryRunner +from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.call_parameters import CallParameters from graphdatascience.graph_data_science import GraphDataScience -from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication -from graphdatascience.query_runner.arrow_info import ArrowInfo from graphdatascience.query_runner.cypher_graph_constructor import ( CypherGraphConstructor, ) from graphdatascience.query_runner.graph_constructor import GraphConstructor from graphdatascience.query_runner.query_mode import QueryMode from graphdatascience.server_version.server_version import ServerVersion -from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience +from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience, AuraGraphDataScienceFactory from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo # Should mirror the latest GDS server version under development. @@ -194,12 +194,12 @@ def aura_gds(runner: CollectingQueryRunner, mocker: MockerFixture) -> Generator[ mocker.patch("graphdatascience.query_runner.arrow_query_runner.ArrowQueryRunner.create", return_value=runner) mocker.patch("graphdatascience.query_runner.gds_arrow_client.GdsArrowClient.create", return_value=None) - aura_gds = AuraGraphDataScience.create( + aura_gds = AuraGraphDataScienceFactory( session_bolt_connection_info=DbmsConnectionInfo("address", "some", "auth"), arrow_authentication=UsernamePasswordAuthentication("some", "auth"), db_endpoint=DbmsConnectionInfo("address", "some", "auth"), delete_fn=lambda: True, - ) + ).create() yield aura_gds aura_gds.close() diff --git a/graphdatascience/tests/unit/procedure_surface/__init__.py b/graphdatascience/tests/unit/procedure_surface/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/tests/unit/procedure_surface/arrow/__init__.py b/graphdatascience/tests/unit/procedure_surface/arrow/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py b/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py new file mode 100644 index 000000000..519825390 --- /dev/null +++ b/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py @@ -0,0 +1,322 @@ +import pandas as pd +import pytest + +from graphdatascience.graph.graph_object import Graph +from graphdatascience.procedure_surface.api.wcc_endpoints import WccMutateResult, WccStatsResult, WccWriteResult +from graphdatascience.procedure_surface.cypher.wcc_cypher_endpoints import WccCypherEndpoints +from graphdatascience.tests.unit.conftest import DEFAULT_SERVER_VERSION, CollectingQueryRunner + + +@pytest.fixture +def query_runner() -> CollectingQueryRunner: + return CollectingQueryRunner(DEFAULT_SERVER_VERSION) + + +@pytest.fixture +def wcc_endpoints(query_runner: CollectingQueryRunner) -> WccCypherEndpoints: + return WccCypherEndpoints(query_runner) + + +@pytest.fixture +def graph(query_runner: CollectingQueryRunner) -> Graph: + return Graph("test_graph", query_runner) + + +def test_mutate_basic(graph: Graph) -> None: + result = { + "nodePropertiesWritten": 5, + "mutateMillis": 42, + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "postProcessingMillis": 12, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337}, + } + + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.mutate": pd.DataFrame([result])}) + + result_obj = WccCypherEndpoints(query_runner).mutate(graph, "componentId") + + assert len(query_runner.queries) == 1 + assert "gds.wcc.mutate" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + config = params["config"] + assert config["mutateProperty"] == "componentId" + assert "jobId" in config + + assert isinstance(result_obj, WccMutateResult) + assert result_obj.node_properties_written == 5 + assert result_obj.mutate_millis == 42 + assert result_obj.component_count == 3 + assert result_obj.pre_processing_millis == 10 + assert result_obj.compute_millis == 20 + assert result_obj.post_processing_millis == 12 + assert result_obj.component_distribution == {"foo": 42} + assert result_obj.configuration == {"bar": 1337} + + +def test_mutate_with_optional_params(graph: Graph) -> None: + result = { + "nodePropertiesWritten": 5, + "mutateMillis": 42, + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "postProcessingMillis": 12, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337}, + } + + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.mutate": pd.DataFrame([result])}) + + WccCypherEndpoints(query_runner).mutate( + graph, + "componentId", + threshold=0.5, + relationship_types=["REL"], + node_labels=["Person"], + sudo=True, + log_progress=True, + username="neo4j", + concurrency=4, + job_id="test-job", + seed_property="seed", + consecutive_ids=True, + relationship_weight_property="weight", + ) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.mutate" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + assert params["config"] == { + "mutateProperty": "componentId", + "threshold": 0.5, + "relationshipTypes": ["REL"], + "nodeLabels": ["Person"], + "sudo": True, + "logProgress": True, + "username": "neo4j", + "concurrency": 4, + "jobId": "test-job", + "seedProperty": "seed", + "consecutiveIds": True, + "relationshipWeightProperty": "weight", + } + + +def test_stats_basic(graph: Graph) -> None: + result = { + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "postProcessingMillis": 12, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337}, + } + + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.stats": pd.DataFrame([result])}) + + result_obj = WccCypherEndpoints(query_runner).stats(graph) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.stats" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + config = params["config"] + assert "jobId" in config + + assert isinstance(result_obj, WccStatsResult) + assert result_obj.component_count == 3 + assert result_obj.pre_processing_millis == 10 + assert result_obj.compute_millis == 20 + assert result_obj.post_processing_millis == 12 + assert result_obj.component_distribution == {"foo": 42} + assert result_obj.configuration == {"bar": 1337} + + +def test_stats_with_optional_params(graph: Graph) -> None: + result = { + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "postProcessingMillis": 12, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337}, + } + + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.stats": pd.DataFrame([result])}) + + WccCypherEndpoints(query_runner).stats( + graph, + threshold=0.5, + relationship_types=["REL"], + node_labels=["Person"], + sudo=True, + log_progress=True, + username="neo4j", + concurrency=4, + job_id="test-job", + seed_property="seed", + consecutive_ids=True, + relationship_weight_property="weight", + ) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.stats" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + assert params["config"] == { + "threshold": 0.5, + "relationshipTypes": ["REL"], + "nodeLabels": ["Person"], + "sudo": True, + "logProgress": True, + "username": "neo4j", + "concurrency": 4, + "jobId": "test-job", + "seedProperty": "seed", + "consecutiveIds": True, + "relationshipWeightProperty": "weight", + } + + +def test_stream_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner) -> None: + wcc_endpoints.stream(graph) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.stream" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + config = params["config"] + assert "jobId" in config + + +def test_stream_with_optional_params( + wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner +) -> None: + wcc_endpoints.stream( + graph, + min_component_size=2, + threshold=0.5, + relationship_types=["REL"], + node_labels=["Person"], + sudo=True, + log_progress=True, + username="neo4j", + concurrency=4, + job_id="test-job", + seed_property="seed", + consecutive_ids=True, + relationship_weight_property="weight", + ) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.stream" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + assert params["config"] == { + "minComponentSize": 2, + "threshold": 0.5, + "relationshipTypes": ["REL"], + "nodeLabels": ["Person"], + "sudo": True, + "logProgress": True, + "username": "neo4j", + "concurrency": 4, + "jobId": "test-job", + "seedProperty": "seed", + "consecutiveIds": True, + "relationshipWeightProperty": "weight", + } + + +def test_write_basic(graph: Graph) -> None: + result = { + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "writeMillis": 15, + "postProcessingMillis": 12, + "nodePropertiesWritten": 5, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337}, + } + + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.write": pd.DataFrame([result])}) + + result_obj = WccCypherEndpoints(query_runner).write(graph, "componentId") + + assert len(query_runner.queries) == 1 + assert "gds.wcc.write" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + config = params["config"] + assert config["writeProperty"] == "componentId" + assert "jobId" in config + + assert isinstance(result_obj, WccWriteResult) + assert result_obj.component_count == 3 + assert result_obj.pre_processing_millis == 10 + assert result_obj.compute_millis == 20 + assert result_obj.write_millis == 15 + assert result_obj.post_processing_millis == 12 + assert result_obj.node_properties_written == 5 + assert result_obj.component_distribution == {"foo": 42} + assert result_obj.configuration == {"bar": 1337} + + +def test_write_with_optional_params(graph: Graph) -> None: + result = { + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "writeMillis": 15, + "postProcessingMillis": 12, + "nodePropertiesWritten": 5, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337}, + } + + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.write": pd.DataFrame([result])}) + + WccCypherEndpoints(query_runner).write( + graph, + "componentId", + min_component_size=2, + threshold=0.5, + relationship_types=["REL"], + node_labels=["Person"], + sudo=True, + log_progress=True, + username="neo4j", + concurrency=4, + job_id="test-job", + seed_property="seed", + consecutive_ids=True, + relationship_weight_property="weight", + write_concurrency=4, + ) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.write" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + assert params["config"] == { + "writeProperty": "componentId", + "minComponentSize": 2, + "threshold": 0.5, + "relationshipTypes": ["REL"], + "nodeLabels": ["Person"], + "sudo": True, + "logProgress": True, + "username": "neo4j", + "concurrency": 4, + "jobId": "test-job", + "seedProperty": "seed", + "consecutiveIds": True, + "relationshipWeightProperty": "weight", + "writeConcurrency": 4, + } diff --git a/graphdatascience/tests/unit/procedure_surface/test_config_converter.py b/graphdatascience/tests/unit/procedure_surface/test_config_converter.py new file mode 100644 index 000000000..a01dd86e6 --- /dev/null +++ b/graphdatascience/tests/unit/procedure_surface/test_config_converter.py @@ -0,0 +1,39 @@ +from graphdatascience.procedure_surface.config_converter import ConfigConverter + + +def test_build_configuration_with_no_additional_args() -> None: + config = ConfigConverter.convert_to_gds_config() + assert config == {} + + +def test_build_configuration_with_additional_args() -> None: + config = ConfigConverter.convert_to_gds_config(some_property="value", another_property=42) + assert config["someProperty"] == "value" + assert config["anotherProperty"] == 42 + + +def test_build_configuration_ignores_none_values() -> None: + config = ConfigConverter.convert_to_gds_config(included_property="present", excluded_property=None) + assert "includedProperty" in config + assert "excludedProperty" not in config + assert config["includedProperty"] == "present" + + +def test_build_configuration_with_nested_dict() -> None: + config = ConfigConverter.convert_to_gds_config(foo_bar={"bar_baz": 42, "another_key": "value"}) + assert "fooBar" in config + assert isinstance(config["fooBar"], dict) + assert "barBaz" in config["fooBar"] + assert config["fooBar"]["barBaz"] == 42 + assert "anotherKey" in config["fooBar"] + assert config["fooBar"]["anotherKey"] == "value" + + +def test_build_configuration_with_deeply_nested_dict() -> None: + config = ConfigConverter.convert_to_gds_config(level_one={"level_two": {"level_three": 42}}) + assert "levelOne" in config + assert isinstance(config["levelOne"], dict) + assert "levelTwo" in config["levelOne"] + assert isinstance(config["levelOne"]["levelTwo"], dict) + assert "levelThree" in config["levelOne"]["levelTwo"] + assert config["levelOne"]["levelTwo"]["levelThree"] == 42 diff --git a/graphdatascience/tests/unit/test_arrow_runner.py b/graphdatascience/tests/unit/test_arrow_runner.py index dafa4f3a6..8cda64e9b 100644 --- a/graphdatascience/tests/unit/test_arrow_runner.py +++ b/graphdatascience/tests/unit/test_arrow_runner.py @@ -2,7 +2,7 @@ from pyarrow.flight import FlightUnavailableError from tenacity import retry_any, stop_after_attempt, wait_fixed -from graphdatascience.query_runner.arrow_info import ArrowInfo +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner from graphdatascience.retry_utils.retry_config import RetryConfig from graphdatascience.server_version.server_version import ServerVersion diff --git a/graphdatascience/tests/unit/test_gds_arrow_client.py b/graphdatascience/tests/unit/test_gds_arrow_client.py index 80be7b80a..1145a5a6a 100644 --- a/graphdatascience/tests/unit/test_gds_arrow_client.py +++ b/graphdatascience/tests/unit/test_gds_arrow_client.py @@ -14,8 +14,8 @@ Ticket, ) -from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication -from graphdatascience.query_runner.arrow_info import ArrowInfo +from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware, GdsArrowClient ActionParam = Union[str, tuple[str, Any], Action] diff --git a/graphdatascience/tests/unit/test_init.py b/graphdatascience/tests/unit/test_init.py index b431bb41c..e95ac67fa 100644 --- a/graphdatascience/tests/unit/test_init.py +++ b/graphdatascience/tests/unit/test_init.py @@ -3,8 +3,8 @@ import pytest from pandas import DataFrame +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.graph_data_science import GraphDataScience -from graphdatascience.query_runner.arrow_info import ArrowInfo from graphdatascience.server_version.server_version import ServerVersion from graphdatascience.tests.unit.conftest import CollectingQueryRunner