From 4246b304fecc9b1a427e3a98720cfa83270660c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Fri, 28 Mar 2025 17:18:09 +0100 Subject: [PATCH 01/12] Introduce explicit endpoints for Wcc --- .../procedure_surface/api/__init__.py | 0 .../procedure_surface/api/wcc_endpoints.py | 242 ++++++++++++++++++ 2 files changed, 242 insertions(+) create mode 100644 graphdatascience/procedure_surface/api/__init__.py create mode 100644 graphdatascience/procedure_surface/api/wcc_endpoints.py diff --git a/graphdatascience/procedure_surface/api/__init__.py b/graphdatascience/procedure_surface/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/api/wcc_endpoints.py b/graphdatascience/procedure_surface/api/wcc_endpoints.py new file mode 100644 index 000000000..f6bec4679 --- /dev/null +++ b/graphdatascience/procedure_surface/api/wcc_endpoints.py @@ -0,0 +1,242 @@ +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +from pandas import DataFrame, Series + +from ...graph.graph_object import Graph + + +class WccEndpoints(ABC): + """ + Abstract base class defining the API for the Weakly Connected Components (WCC) algorithm. + """ + + @abstractmethod + def mutate( + self, + G: Graph, + mutate_property: str, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> Series[Any]: + """ + Executes the WCC algorithm and writes the results to the in-memory graph as node properties.
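For illustration, a minimal usage sketch of this mutate mode, assuming the `gds.wcc` property wired up in patch 02 and a throwaway projection; the server URI, credentials, and graph contents are hypothetical:

from graphdatascience import GraphDataScience

gds = GraphDataScience("bolt://localhost:7687", auth=("neo4j", "password"))
G, _ = gds.graph.project("example-graph", "Person", "KNOWS")  # hypothetical projection

# Store each node's component id as a node property on the in-memory graph.
result = gds.wcc.mutate(G, mutate_property="componentId", concurrency=4)
print(result["componentCount"], result["mutateMillis"])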
+ + Parameters + ---------- + G : Graph + The graph to run the algorithm on + mutate_property : str + The property name to store the component ID for each node + threshold : Optional[float], default=None + The minimum required weight to consider a relationship during traversal + relationship_types : Optional[List[str]], default=None + The relationship types to project + node_labels : Optional[List[str]], default=None + The node labels to project + sudo : Optional[bool], default=None + Run analysis with admin permission + log_progress : Optional[bool], default=None + Whether to log progress + username : Optional[str], default=None + The username to attribute the procedure run to + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + seed_property : Optional[str], default=None + Defines node properties that are used as initial component identifiers + consecutive_ids : Optional[bool], default=None + Flag to decide whether component identifiers are mapped into a consecutive id space + relationship_weight_property : Optional[str], default=None + The property name that contains weight + + Returns + ------- + Series + Algorithm metrics and statistics + """ + pass + + @abstractmethod + def stats( + self, + G: Graph, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> Series[Any]: + """ + Executes the WCC algorithm and returns statistics. 
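Continuing the sketch above, the stats mode is a cheap way to inspect the component structure before committing to a mutate or write; the field names follow the standard gds.wcc.stats procedure output:

stats = gds.wcc.stats(G, threshold=0.5)
print(stats["componentCount"])         # number of weakly connected components
print(stats["componentDistribution"])  # size distribution (min, max, mean, percentiles)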
+ + Parameters + ---------- + G : Graph + The graph to run the algorithm on + threshold : Optional[float], default=None + The minimum required weight to consider a relationship during traversal + relationship_types : Optional[List[str]], default=None + The relationship types to project + node_labels : Optional[List[str]], default=None + The node labels to project + sudo : Optional[bool], default=None + Run analysis with admin permission + log_progress : Optional[bool], default=None + Whether to log progress + username : Optional[str], default=None + The username to attribute the procedure run to + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + seed_property : Optional[str], default=None + Defines node properties that are used as initial component identifiers + consecutive_ids : Optional[bool], default=None + Flag to decide whether component identifiers are mapped into a consecutive id space + relationship_weight_property : Optional[str], default=None + The property name that contains weight + + Returns + ------- + Series + Algorithm metrics and statistics + """ + pass + + @abstractmethod + def stream( + self, + G: Graph, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> DataFrame: + """ + Executes the WCC algorithm and returns a stream of results. 
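The stream mode returns per-node rows rather than a summary, so results can be post-processed client-side with regular pandas. Continuing the sketch above, with nodeId/componentId as the standard gds.wcc.stream columns:

df = gds.wcc.stream(G, min_component_size=2)
sizes = df.groupby("componentId").size().sort_values(ascending=False)
print(sizes.head())  # largest components first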
+ + Parameters + ---------- + G : Graph + The graph to run the algorithm on + min_component_size : Optional[int], default=None + Don't stream components with fewer nodes than this + threshold : Optional[float], default=None + The minimum required weight to consider a relationship during traversal + relationship_types : Optional[List[str]], default=None + The relationship types to project + node_labels : Optional[List[str]], default=None + The node labels to project + sudo : Optional[bool], default=None + Run analysis with admin permission + log_progress : Optional[bool], default=None + Whether to log progress + username : Optional[str], default=None + The username to attribute the procedure run to + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + seed_property : Optional[str], default=None + Defines node properties that are used as initial component identifiers + consecutive_ids : Optional[bool], default=None + Flag to decide whether component identifiers are mapped into a consecutive id space + relationship_weight_property : Optional[str], default=None + The property name that contains weight + + Returns + ------- + DataFrame + DataFrame with the algorithm results + """ + pass + + @abstractmethod + def write( + self, + G: Graph, + write_property: str, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + write_concurrency: Optional[Any] = None, + write_to_result_store: Optional[bool] = None, + ) -> Series[Any]: + """ + Executes the WCC algorithm and writes the results to the Neo4j database. 
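And the write mode persists component ids back to the Neo4j database instead of the in-memory projection. Continuing the sketch above, with illustrative property and concurrency values:

summary = gds.wcc.write(G, write_property="componentId", write_concurrency=4)
print(summary["nodePropertiesWritten"])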
+ + Parameters + ---------- + G : Graph + The graph to run the algorithm on + write_property : str + The property name to write component IDs to + min_component_size : Optional[int], default=None + Don't write components with fewer nodes than this + threshold : Optional[float], default=None + The minimum required weight to consider a relationship during traversal + relationship_types : Optional[List[str]], default=None + The relationship types to project + node_labels : Optional[List[str]], default=None + The node labels to project + sudo : Optional[bool], default=None + Run analysis with admin permission + log_progress : Optional[bool], default=None + Whether to log progress + username : Optional[str], default=None + The username to attribute the procedure run to + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + seed_property : Optional[str], default=None + Defines node properties that are used as initial component identifiers + consecutive_ids : Optional[bool], default=None + Flag to decide whether component identifiers are mapped into a consecutive id space + relationship_weight_property : Optional[str], default=None + The property name that contains weight + write_concurrency : Optional[Any], default=None + The number of concurrent threads during the write phase + write_to_result_store : Optional[bool], default=None + Whether to write the results to the result store + + Returns + ------- + Series + Algorithm metrics and statistics + """ + pass From c0edacceb3324ede559310af4dc735524bbc5a2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Fri, 28 Mar 2025 17:18:38 +0100 Subject: [PATCH 02/12] Implement Wcc endpoints using Cypher procedures --- graphdatascience/graph_data_science.py | 8 + .../procedure_surface/__init__.py | 0 .../procedure_surface/cypher/__init__.py | 0 .../cypher/wcc_proc_runner.py | 228 ++++++++++++++++++ .../cypher/test_wcc_cypher_endpoints.py | 227 +++++++++++++++++ 5 files changed, 463 insertions(+) create mode 100644 graphdatascience/procedure_surface/__init__.py create mode 100644 graphdatascience/procedure_surface/cypher/__init__.py create mode 100644 graphdatascience/procedure_surface/cypher/wcc_proc_runner.py create mode 100644 graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index 56c91b57a..c94eea85d 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -9,6 +9,8 @@ from pandas import DataFrame from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.procedure_surface.api.wcc_endpoints import WccEndpoints +from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints from .call_builder import IndirectCallBuilder from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints @@ -117,10 +119,16 @@ def __init__( self._query_runner.set_show_progress(show_progress) super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) + self._wcc_endpoints = WccCypherEndpoints(self._query_runner) + @property def graph(self) -> GraphProcRunner: return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version) + @property + def wcc(self) -> WccEndpoints: + return self._wcc_endpoints + @property def util(self) -> UtilProcRunner: return 
UtilProcRunner(self._query_runner, f"{self._namespace}.util", self._server_version) diff --git a/graphdatascience/procedure_surface/__init__.py b/graphdatascience/procedure_surface/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/cypher/__init__.py b/graphdatascience/procedure_surface/cypher/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py b/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py new file mode 100644 index 000000000..eb02ff741 --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py @@ -0,0 +1,228 @@ +from typing import Any, List, Optional + +from pandas import DataFrame, Series + +from ...call_parameters import CallParameters +from ...graph.graph_object import Graph +from ...query_runner.query_runner import QueryRunner +from ..api.wcc_endpoints import WccEndpoints + + +class WccCypherEndpoints(WccEndpoints): + """ + Implementation of the WCC algorithm endpoints. + This class handles the actual execution by forwarding calls to the query runner. + """ + + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def mutate( + self, + G: Graph, + mutate_property: str, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> Series[Any]: + # Build configuration dictionary from parameters + config: dict[str, Any] = { + "mutateProperty": mutate_property, + } + + # Add optional parameters + if threshold is not None: + config["threshold"] = threshold + if relationship_types is not None: + config["relationshipTypes"] = relationship_types + if node_labels is not None: + config["nodeLabels"] = node_labels + if sudo is not None: + config["sudo"] = sudo + if log_progress is not None: + config["logProgress"] = log_progress + if username is not None: + config["username"] = username + if concurrency is not None: + config["concurrency"] = concurrency + if job_id is not None: + config["jobId"] = job_id + if seed_property is not None: + config["seedProperty"] = seed_property + if consecutive_ids is not None: + config["consecutiveIds"] = consecutive_ids + if relationship_weight_property is not None: + config["relationshipWeightProperty"] = relationship_weight_property + + # Run procedure and return results + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + return self._query_runner.call_procedure(endpoint="gds.wcc.mutate", params=params).squeeze() # type: ignore + + def stats( + self, + G: Graph, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> Series[Any]: + # Build configuration dictionary from parameters + config: dict[str, Any] = {} + + # Add optional parameters + if 
threshold is not None: + config["threshold"] = threshold + if relationship_types is not None: + config["relationshipTypes"] = relationship_types + if node_labels is not None: + config["nodeLabels"] = node_labels + if sudo is not None: + config["sudo"] = sudo + if log_progress is not None: + config["logProgress"] = log_progress + if username is not None: + config["username"] = username + if concurrency is not None: + config["concurrency"] = concurrency + if job_id is not None: + config["jobId"] = job_id + if seed_property is not None: + config["seedProperty"] = seed_property + if consecutive_ids is not None: + config["consecutiveIds"] = consecutive_ids + if relationship_weight_property is not None: + config["relationshipWeightProperty"] = relationship_weight_property + + # Run procedure and return results + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + return self._query_runner.call_procedure(endpoint="gds.wcc.stats", params=params).squeeze() # type: ignore + + def stream( + self, + G: Graph, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> DataFrame: + # Build configuration dictionary from parameters + config: dict[str, Any] = {} + + # Add optional parameters + if min_component_size is not None: + config["minComponentSize"] = min_component_size + if threshold is not None: + config["threshold"] = threshold + if relationship_types is not None: + config["relationshipTypes"] = relationship_types + if node_labels is not None: + config["nodeLabels"] = node_labels + if sudo is not None: + config["sudo"] = sudo + if log_progress is not None: + config["logProgress"] = log_progress + if username is not None: + config["username"] = username + if concurrency is not None: + config["concurrency"] = concurrency + if job_id is not None: + config["jobId"] = job_id + if seed_property is not None: + config["seedProperty"] = seed_property + if consecutive_ids is not None: + config["consecutiveIds"] = consecutive_ids + if relationship_weight_property is not None: + config["relationshipWeightProperty"] = relationship_weight_property + + # Run procedure and return results + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + return self._query_runner.call_procedure(endpoint="gds.wcc.stream", params=params) + + def write( + self, + G: Graph, + write_property: str, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + write_concurrency: Optional[int] = None, + write_to_result_store: Optional[bool] = None, + ) -> Series[Any]: + # Build configuration dictionary from parameters + config: dict[str, Any] = { + "writeProperty": write_property, + } + + # Add optional 
parameters + if min_component_size is not None: + config["minComponentSize"] = min_component_size + if threshold is not None: + config["threshold"] = threshold + if relationship_types is not None: + config["relationshipTypes"] = relationship_types + if node_labels is not None: + config["nodeLabels"] = node_labels + if sudo is not None: + config["sudo"] = sudo + if log_progress is not None: + config["logProgress"] = log_progress + if username is not None: + config["username"] = username + if concurrency is not None: + config["concurrency"] = concurrency + if job_id is not None: + config["jobId"] = job_id + if seed_property is not None: + config["seedProperty"] = seed_property + if consecutive_ids is not None: + config["consecutiveIds"] = consecutive_ids + if relationship_weight_property is not None: + config["relationshipWeightProperty"] = relationship_weight_property + if write_concurrency is not None: + config["writeConcurrency"] = write_concurrency + if write_to_result_store is not None: + config["writeToResultStore"] = write_to_result_store + + # Run procedure and return results + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + return self._query_runner.call_procedure(endpoint="gds.wcc.write", params=params).squeeze() # type: ignore diff --git a/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py b/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py new file mode 100644 index 000000000..a15204629 --- /dev/null +++ b/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py @@ -0,0 +1,227 @@ +import pytest + +from graphdatascience.graph.graph_object import Graph +from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints +from graphdatascience.server_version.server_version import ServerVersion +from graphdatascience.tests.unit.conftest import CollectingQueryRunner + + +@pytest.fixture +def query_runner() -> CollectingQueryRunner: + return CollectingQueryRunner(ServerVersion(2, 16, 0)) + + +@pytest.fixture +def wcc_endpoints(query_runner: CollectingQueryRunner) -> WccCypherEndpoints: + return WccCypherEndpoints(query_runner) + + +@pytest.fixture +def graph(query_runner: CollectingQueryRunner) -> Graph: + return Graph("test_graph", query_runner) + + +def test_mutate_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner) -> None: + wcc_endpoints.mutate(graph, "componentId") + + assert len(query_runner.queries) == 1 + assert "gds.wcc.mutate" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + config = params["config"] + assert config["mutateProperty"] == "componentId" + assert "jobId" in config + + +def test_mutate_with_optional_params( + wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner +) -> None: + wcc_endpoints.mutate( + graph, + "componentId", + threshold=0.5, + relationship_types=["REL"], + node_labels=["Person"], + sudo=True, + log_progress=True, + username="neo4j", + concurrency=4, + job_id="test-job", + seed_property="seed", + consecutive_ids=True, + relationship_weight_property="weight", + ) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.mutate" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + assert params["config"] == { + "mutateProperty": "componentId", + "threshold": 0.5, + "relationshipTypes": ["REL"], + 
"nodeLabels": ["Person"], + "sudo": True, + "logProgress": True, + "username": "neo4j", + "concurrency": 4, + "jobId": "test-job", + "seedProperty": "seed", + "consecutiveIds": True, + "relationshipWeightProperty": "weight", + } + + +def test_stats_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner) -> None: + wcc_endpoints.stats(graph) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.stats" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + config = params["config"] + assert "jobId" in config + + +def test_stats_with_optional_params( + wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner +) -> None: + wcc_endpoints.stats( + graph, + threshold=0.5, + relationship_types=["REL"], + node_labels=["Person"], + sudo=True, + log_progress=True, + username="neo4j", + concurrency=4, + job_id="test-job", + seed_property="seed", + consecutive_ids=True, + relationship_weight_property="weight", + ) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.stats" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + assert params["config"] == { + "threshold": 0.5, + "relationshipTypes": ["REL"], + "nodeLabels": ["Person"], + "sudo": True, + "logProgress": True, + "username": "neo4j", + "concurrency": 4, + "jobId": "test-job", + "seedProperty": "seed", + "consecutiveIds": True, + "relationshipWeightProperty": "weight", + } + + +def test_stream_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner) -> None: + wcc_endpoints.stream(graph) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.stream" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + config = params["config"] + assert "jobId" in config + + +def test_stream_with_optional_params( + wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner +) -> None: + wcc_endpoints.stream( + graph, + min_component_size=2, + threshold=0.5, + relationship_types=["REL"], + node_labels=["Person"], + sudo=True, + log_progress=True, + username="neo4j", + concurrency=4, + job_id="test-job", + seed_property="seed", + consecutive_ids=True, + relationship_weight_property="weight", + ) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.stream" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + assert params["config"] == { + "minComponentSize": 2, + "threshold": 0.5, + "relationshipTypes": ["REL"], + "nodeLabels": ["Person"], + "sudo": True, + "logProgress": True, + "username": "neo4j", + "concurrency": 4, + "jobId": "test-job", + "seedProperty": "seed", + "consecutiveIds": True, + "relationshipWeightProperty": "weight", + } + + +def test_write_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner) -> None: + wcc_endpoints.write(graph, "componentId") + + assert len(query_runner.queries) == 1 + assert "gds.wcc.write" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + config = params["config"] + assert config["writeProperty"] == "componentId" + assert "jobId" in config + + +def test_write_with_optional_params( + wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner +) -> None: + wcc_endpoints.write( + graph, + "componentId", + min_component_size=2, + threshold=0.5, 
+ relationship_types=["REL"], + node_labels=["Person"], + sudo=True, + log_progress=True, + username="neo4j", + concurrency=4, + job_id="test-job", + seed_property="seed", + consecutive_ids=True, + relationship_weight_property="weight", + write_concurrency=4, + write_to_result_store=True, + ) + + assert len(query_runner.queries) == 1 + assert "gds.wcc.write" in query_runner.queries[0] + params = query_runner.params[0] + assert params["graph_name"] == "test_graph" + assert params["config"] == { + "writeProperty": "componentId", + "minComponentSize": 2, + "threshold": 0.5, + "relationshipTypes": ["REL"], + "nodeLabels": ["Person"], + "sudo": True, + "logProgress": True, + "username": "neo4j", + "concurrency": 4, + "jobId": "test-job", + "seedProperty": "seed", + "consecutiveIds": True, + "relationshipWeightProperty": "weight", + "writeConcurrency": 4, + "writeToResultStore": True, + } From 50d88ab0881a20cd669771362781e80b35393a17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Mon, 23 Jun 2025 17:11:03 +0200 Subject: [PATCH 03/12] Move arrow client related code into its own package --- graphdatascience/arrow_client/__init__.py | 0 .../arrow_authentication.py | 0 .../arrow_info.py | 4 +- .../arrow_client/middleware/AuthMiddleware.py | 63 +++++++++++++++++++ .../middleware/UserAgentMiddleware.py | 24 +++++++ graphdatascience/graph_data_science.py | 4 +- .../query_runner/arrow_query_runner.py | 4 +- .../query_runner/gds_arrow_client.py | 4 +- .../session/aura_api_token_authentication.py | 2 +- .../session/aura_graph_data_science.py | 4 +- .../session/dedicated_sessions.py | 2 +- .../tests/integration/conftest.py | 2 +- graphdatascience/tests/unit/conftest.py | 4 +- .../tests/unit/test_arrow_runner.py | 2 +- .../tests/unit/test_gds_arrow_client.py | 4 +- graphdatascience/tests/unit/test_init.py | 2 +- 16 files changed, 106 insertions(+), 19 deletions(-) create mode 100644 graphdatascience/arrow_client/__init__.py rename graphdatascience/{query_runner => arrow_client}/arrow_authentication.py (100%) rename graphdatascience/{query_runner => arrow_client}/arrow_info.py (85%) create mode 100644 graphdatascience/arrow_client/middleware/AuthMiddleware.py create mode 100644 graphdatascience/arrow_client/middleware/UserAgentMiddleware.py diff --git a/graphdatascience/arrow_client/__init__.py b/graphdatascience/arrow_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/query_runner/arrow_authentication.py b/graphdatascience/arrow_client/arrow_authentication.py similarity index 100% rename from graphdatascience/query_runner/arrow_authentication.py rename to graphdatascience/arrow_client/arrow_authentication.py diff --git a/graphdatascience/query_runner/arrow_info.py b/graphdatascience/arrow_client/arrow_info.py similarity index 85% rename from graphdatascience/query_runner/arrow_info.py rename to graphdatascience/arrow_client/arrow_info.py index 8a2399182..a11b48c49 100644 --- a/graphdatascience/query_runner/arrow_info.py +++ b/graphdatascience/arrow_client/arrow_info.py @@ -2,8 +2,8 @@ from dataclasses import dataclass -from ..query_runner.query_runner import QueryRunner -from ..server_version.server_version import ServerVersion +from graphdatascience.query_runner.query_runner import QueryRunner +from graphdatascience.server_version.server_version import ServerVersion @dataclass(frozen=True) diff --git a/graphdatascience/arrow_client/middleware/AuthMiddleware.py b/graphdatascience/arrow_client/middleware/AuthMiddleware.py new file mode 
100644 index 000000000..e350f7d38 --- /dev/null +++ b/graphdatascience/arrow_client/middleware/AuthMiddleware.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import base64 +import time +from typing import Optional, Any + +from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory + +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication + + +class AuthFactory(ClientMiddlewareFactory): # type: ignore + def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._middleware = middleware + + def start_call(self, info: Any) -> AuthMiddleware: + return self._middleware + + +class AuthMiddleware(ClientMiddleware): # type: ignore + def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._auth = auth + self._token: Optional[str] = None + self._token_timestamp = 0 + + def token(self) -> Optional[str]: + # check whether the token is older than 10 minutes. If so, reset it. + if self._token and int(time.time()) - self._token_timestamp > 600: + self._token = None + + return self._token + + def _set_token(self, token: str) -> None: + self._token = token + self._token_timestamp = int(time.time()) + + def received_headers(self, headers: dict[str, Any]) -> None: + auth_header = headers.get("authorization", None) + if not auth_header: + return + + # the result is always a list + header_value = auth_header[0] + + if not isinstance(header_value, str): + raise ValueError(f"Incompatible header value received from server: `{header_value}`") + + auth_type, token = header_value.split(" ", 1) + if auth_type == "Bearer": + self._set_token(token) + + def sending_headers(self) -> dict[str, str]: + token = self.token() + if token is not None: + return {"authorization": "Bearer " + token} + + auth_pair = self._auth.auth_pair() + auth_token = f"{auth_pair[0]}:{auth_pair[1]}" + auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII") + # There seems to be a bug, `authorization` must be lower key + return {"authorization": auth_token} \ No newline at end of file diff --git a/graphdatascience/arrow_client/middleware/UserAgentMiddleware.py b/graphdatascience/arrow_client/middleware/UserAgentMiddleware.py new file mode 100644 index 000000000..b6313502f --- /dev/null +++ b/graphdatascience/arrow_client/middleware/UserAgentMiddleware.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import Any + +from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory + +class UserAgentFactory(ClientMiddlewareFactory): # type: ignore + def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._middleware = UserAgentMiddleware(useragent) + + def start_call(self, info: Any) -> ClientMiddleware: + return self._middleware + +class UserAgentMiddleware(ClientMiddleware): # type: ignore + def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._useragent = useragent + + def sending_headers(self) -> dict[str, str]: + return {"x-gds-user-agent": self._useragent} + + def received_headers(self, headers: dict[str, Any]) -> None: + pass \ No newline at end of file diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index c94eea85d..f8e2e0eec 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -8,7 +8,7 @@ from 
neo4j import Driver from pandas import DataFrame -from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.procedure_surface.api.wcc_endpoints import WccEndpoints from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints @@ -16,7 +16,7 @@ from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints from .error.uncallable_namespace import UncallableNamespace from .graph.graph_proc_runner import GraphProcRunner -from .query_runner.arrow_info import ArrowInfo +from graphdatascience.arrow_client.arrow_info import ArrowInfo from .query_runner.arrow_query_runner import ArrowQueryRunner from .query_runner.neo4j_query_runner import Neo4jQueryRunner from .query_runner.query_runner import QueryRunner diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index 36214fb97..975eaed25 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -5,12 +5,12 @@ from pandas import DataFrame -from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication from graphdatascience.query_runner.query_mode import QueryMode from graphdatascience.retry_utils.retry_config import RetryConfig from ..call_parameters import CallParameters -from ..query_runner.arrow_info import ArrowInfo +from graphdatascience.arrow_client.arrow_info import ArrowInfo from ..server_version.server_version import ServerVersion from .arrow_graph_constructor import ArrowGraphConstructor from .gds_arrow_client import GdsArrowClient diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index 64f1988ad..c25efdf2e 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -35,14 +35,14 @@ wait_exponential, ) -from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication, UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication, UsernamePasswordAuthentication from graphdatascience.retry_utils.retry_config import RetryConfig from graphdatascience.retry_utils.retry_utils import before_log from ..semantic_version.semantic_version import SemanticVersion from ..version import __version__ from .arrow_endpoint_version import ArrowEndpointVersion -from .arrow_info import ArrowInfo +from graphdatascience.arrow_client.arrow_info import ArrowInfo class GdsArrowClient: diff --git a/graphdatascience/session/aura_api_token_authentication.py b/graphdatascience/session/aura_api_token_authentication.py index 37e51bfed..44e5780b7 100644 --- a/graphdatascience/session/aura_api_token_authentication.py +++ b/graphdatascience/session/aura_api_token_authentication.py @@ -1,4 +1,4 @@ -from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication from graphdatascience.session.aura_api import AuraApi diff --git a/graphdatascience/session/aura_graph_data_science.py b/graphdatascience/session/aura_graph_data_science.py index 565df1266..51a4ffde8 100644 --- a/graphdatascience/session/aura_graph_data_science.py +++ 
b/graphdatascience/session/aura_graph_data_science.py @@ -13,8 +13,8 @@ ) from graphdatascience.error.uncallable_namespace import UncallableNamespace from graphdatascience.graph.graph_remote_proc_runner import GraphRemoteProcRunner -from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication -from graphdatascience.query_runner.arrow_info import ArrowInfo +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner diff --git a/graphdatascience/session/dedicated_sessions.py b/graphdatascience/session/dedicated_sessions.py index 8babac7a0..f39a133c8 100644 --- a/graphdatascience/session/dedicated_sessions.py +++ b/graphdatascience/session/dedicated_sessions.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta, timezone from typing import Any, Optional -from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner from graphdatascience.session.algorithm_category import AlgorithmCategory from graphdatascience.session.aura_api import AuraApi diff --git a/graphdatascience/tests/integration/conftest.py b/graphdatascience/tests/integration/conftest.py index c520992f7..e48d5450b 100644 --- a/graphdatascience/tests/integration/conftest.py +++ b/graphdatascience/tests/integration/conftest.py @@ -7,7 +7,7 @@ from neo4j import Driver, GraphDatabase from graphdatascience.graph_data_science import GraphDataScience -from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner from graphdatascience.server_version.server_version import ServerVersion from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py index 7a8964c5c..9ce82b4e2 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -10,8 +10,8 @@ from graphdatascience import QueryRunner from graphdatascience.call_parameters import CallParameters from graphdatascience.graph_data_science import GraphDataScience -from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication -from graphdatascience.query_runner.arrow_info import ArrowInfo +from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.query_runner.cypher_graph_constructor import ( CypherGraphConstructor, ) diff --git a/graphdatascience/tests/unit/test_arrow_runner.py b/graphdatascience/tests/unit/test_arrow_runner.py index dafa4f3a6..8cda64e9b 100644 --- a/graphdatascience/tests/unit/test_arrow_runner.py +++ b/graphdatascience/tests/unit/test_arrow_runner.py @@ -2,7 +2,7 @@ from pyarrow.flight import FlightUnavailableError from tenacity import retry_any, stop_after_attempt, wait_fixed -from graphdatascience.query_runner.arrow_info import ArrowInfo +from 
graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner from graphdatascience.retry_utils.retry_config import RetryConfig from graphdatascience.server_version.server_version import ServerVersion diff --git a/graphdatascience/tests/unit/test_gds_arrow_client.py b/graphdatascience/tests/unit/test_gds_arrow_client.py index 80be7b80a..1145a5a6a 100644 --- a/graphdatascience/tests/unit/test_gds_arrow_client.py +++ b/graphdatascience/tests/unit/test_gds_arrow_client.py @@ -14,8 +14,8 @@ Ticket, ) -from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication -from graphdatascience.query_runner.arrow_info import ArrowInfo +from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware, GdsArrowClient ActionParam = Union[str, tuple[str, Any], Action] diff --git a/graphdatascience/tests/unit/test_init.py b/graphdatascience/tests/unit/test_init.py index b431bb41c..a51beb03a 100644 --- a/graphdatascience/tests/unit/test_init.py +++ b/graphdatascience/tests/unit/test_init.py @@ -4,7 +4,7 @@ from pandas import DataFrame from graphdatascience.graph_data_science import GraphDataScience -from graphdatascience.query_runner.arrow_info import ArrowInfo +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.server_version.server_version import ServerVersion from graphdatascience.tests.unit.conftest import CollectingQueryRunner From 6c39bc0b58504f3a10b7af27f1d45ad5b86554ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Mon, 23 Jun 2025 17:14:05 +0200 Subject: [PATCH 04/12] Return custom type from wcc.mutate --- .../procedure_surface/api/wcc_endpoints.py | 21 ++++++++++++++++--- .../cypher/wcc_proc_runner.py | 20 +++++++++++++----- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/graphdatascience/procedure_surface/api/wcc_endpoints.py b/graphdatascience/procedure_surface/api/wcc_endpoints.py index f6bec4679..a1cb7a51d 100644 --- a/graphdatascience/procedure_surface/api/wcc_endpoints.py +++ b/graphdatascience/procedure_surface/api/wcc_endpoints.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Any, List, Optional from pandas import DataFrame, Series @@ -27,7 +30,7 @@ def mutate( seed_property: Optional[str] = None, consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, - ) -> Series[Any]: + ) -> WccMutateResult: """ Executes the WCC algorithm and writes the results to the in-memory graph as node properties. @@ -82,7 +85,7 @@ def stats( seed_property: Optional[str] = None, consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, - ) -> Series[Any]: + ) -> Series: """ Executes the WCC algorithm and returns statistics. @@ -195,7 +198,7 @@ def write( relationship_weight_property: Optional[str] = None, write_concurrency: Optional[Any] = None, write_to_result_store: Optional[bool] = None, - ) -> Series[Any]: + ) -> Series: """ Executes the WCC algorithm and writes the results to the Neo4j database. 
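With this patch, mutate returns the frozen WccMutateResult dataclass defined below instead of an untyped Series. A sketch of what that buys the caller, reusing the gds/G handles from the earlier sketches and assuming the runner populates all eight fields, including configuration:

result = gds.wcc.mutate(G, mutate_property="componentId")
# Typed attribute access instead of string-keyed Series lookups:
print(result.component_count, result.mutate_millis)
print(result.configuration)  # the resolved algorithm configuration, as a dict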
@@ -240,3 +243,15 @@ def write( Algorithm metrics and statistics """ pass + + +@dataclass(frozen=True, repr=True) +class WccMutateResult: + component_count: int + component_distribution: dict[str, Any] + pre_processing_millis: int + compute_millis: int + post_processing_millis: int + mutate_millis: int + node_properties_written: int + configuration: dict[str, Any] diff --git a/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py b/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py index eb02ff741..767d8f436 100644 --- a/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py +++ b/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py @@ -5,7 +5,7 @@ from ...call_parameters import CallParameters from ...graph.graph_object import Graph from ...query_runner.query_runner import QueryRunner -from ..api.wcc_endpoints import WccEndpoints +from ..api.wcc_endpoints import WccEndpoints, WccMutateResult class WccCypherEndpoints(WccEndpoints): @@ -32,7 +32,7 @@ def mutate( seed_property: Optional[str] = None, consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, - ) -> Series[Any]: + ) -> WccMutateResult: # Build configuration dictionary from parameters config: dict[str, Any] = { "mutateProperty": mutate_property, @@ -66,7 +66,17 @@ def mutate( params = CallParameters(graph_name=G.name(), config=config) params.ensure_job_id_in_config() - return self._query_runner.call_procedure(endpoint="gds.wcc.mutate", params=params).squeeze() # type: ignore + cypher_result = self._query_runner.call_procedure(endpoint="gds.wcc.mutate", params=params).squeeze() + + return WccMutateResult( + cypher_result["componentCount"], + cypher_result["componentDistribution"], + cypher_result["preProcessingMillis"], + cypher_result["computeMillis"], + cypher_result["postProcessingMillis"], + cypher_result["mutateMillis"], + cypher_result["nodePropertiesWritten"], + ) def stats( self, @@ -82,7 +92,7 @@ def stats( seed_property: Optional[str] = None, consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, - ) -> Series[Any]: + ) -> Series: # Build configuration dictionary from parameters config: dict[str, Any] = {} @@ -185,7 +195,7 @@ def write( relationship_weight_property: Optional[str] = None, write_concurrency: Optional[int] = None, write_to_result_store: Optional[bool] = None, - ) -> Series[Any]: + ) -> Series: # Build configuration dictionary from parameters config: dict[str, Any] = { "writeProperty": write_property, From be2985e8e1f8c3954120e3349c153e565d8f8438 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Mon, 23 Jun 2025 17:17:54 +0200 Subject: [PATCH 05/12] Introduce AuthenticatedArrowClient --- .../authenticated_arrow_client.py | 157 ++++++++++++++++++ .../tests/unit/arrow_client/__init__.py | 0 2 files changed, 157 insertions(+) create mode 100644 graphdatascience/arrow_client/authenticated_arrow_client.py create mode 100644 graphdatascience/tests/unit/arrow_client/__init__.py diff --git a/graphdatascience/arrow_client/authenticated_arrow_client.py b/graphdatascience/arrow_client/authenticated_arrow_client.py new file mode 100644 index 000000000..f7792f14e --- /dev/null +++ b/graphdatascience/arrow_client/authenticated_arrow_client.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import logging +from typing import Optional, Union, Any + +from pyarrow import __version__ as arrow_version +from pyarrow import flight +from pyarrow._flight import FlightTimedOutError, 
FlightUnavailableError, FlightInternalError, Action +from tenacity import retry_any, retry_if_exception_type, stop_after_delay, stop_after_attempt, wait_exponential, retry + +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo +from graphdatascience.retry_utils.retry_config import RetryConfig +from .middleware.AuthMiddleware import AuthMiddleware, AuthFactory +from .middleware.UserAgentMiddleware import UserAgentFactory +from ..retry_utils.retry_utils import before_log +from ..version import __version__ + + +class AuthenticatedArrowClient: + + @staticmethod + def create( + arrow_info: ArrowInfo, + auth: Optional[ArrowAuthentication] = None, + encrypted: bool = False, + disable_server_verification: bool = False, + tls_root_certs: Optional[bytes] = None, + connection_string_override: Optional[str] = None, + retry_config: Optional[RetryConfig] = None, + ) -> AuthenticatedArrowClient: + connection_string: str + if connection_string_override is not None: + connection_string = connection_string_override + else: + connection_string = arrow_info.listenAddress + + host, port = connection_string.split(":") + + if retry_config is None: + retry_config = RetryConfig( + retry=retry_any( + retry_if_exception_type(FlightTimedOutError), + retry_if_exception_type(FlightUnavailableError), + retry_if_exception_type(FlightInternalError), + ), + stop=(stop_after_delay(10) | stop_after_attempt(5)), + wait=wait_exponential(multiplier=1, min=1, max=10), + ) + + return AuthenticatedArrowClient( + host, + retry_config, + int(port), + auth, + encrypted, + disable_server_verification, + tls_root_certs, + ) + + def __init__( + self, + host: str, + retry_config: RetryConfig, + port: int = 8491, + auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] = None, + encrypted: bool = False, + disable_server_verification: bool = False, + tls_root_certs: Optional[bytes] = None, + user_agent: Optional[str] = None, + ): + """Creates a new GdsArrowClient instance. + + Parameters + ---------- + host: str + The host address of the GDS Arrow server + port: int + The host port of the GDS Arrow server (default is 8491) + auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] + Either an implementation of ArrowAuthentication providing a pair to be used for basic authentication, or a username, password tuple + encrypted: bool + A flag that indicates whether the connection should be encrypted (default is False) + disable_server_verification: bool + A flag that disables server verification for TLS connections (default is False) + tls_root_certs: Optional[bytes] + PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server + arrow_endpoint_version: + The version of the Arrow endpoint to use (default is ArrowEndpointVersion.V1) + user_agent: Optional[str] + The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION]) + retry_config: Optional[RetryConfig] + The retry configuration to use for the Arrow requests send by the client. 
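For illustration, a caller-supplied retry policy with the same shape as the default built in create() above; the tighter bounds are arbitrary, and arrow_info is assumed to be an ArrowInfo instance obtained elsewhere:

from pyarrow._flight import FlightTimedOutError, FlightUnavailableError
from tenacity import retry_any, retry_if_exception_type, stop_after_attempt, wait_exponential

from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient
from graphdatascience.retry_utils.retry_config import RetryConfig

custom_retry = RetryConfig(
    retry=retry_any(
        retry_if_exception_type(FlightTimedOutError),
        retry_if_exception_type(FlightUnavailableError),
    ),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=0.5, min=0.5, max=5),
)
client = AuthenticatedArrowClient.create(arrow_info, retry_config=custom_retry)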
+ """ + self._host = host + self._port = port + self._auth = None + self._encrypted = encrypted + self._disable_server_verification = disable_server_verification + self._tls_root_certs = tls_root_certs + self._user_agent = user_agent + self._retry_config = retry_config + self._logger = logging.getLogger("gds_arrow_client") + self._retry_config = RetryConfig( + retry=retry_any( + retry_if_exception_type(FlightTimedOutError), + retry_if_exception_type(FlightUnavailableError), + retry_if_exception_type(FlightInternalError), + ), + stop=(stop_after_delay(10) | stop_after_attempt(5)), + wait=wait_exponential(multiplier=1, min=1, max=10), + ) + + if auth: + self._auth = auth + self._auth_middleware = AuthMiddleware(auth) + + self._flight_client = self._instantiate_flight_client() + + + def do_action(self, endpoint: str, payload: bytes): + return self._flight_client.do_action(Action(endpoint, payload)) + + def do_action_with_retry(self, endpoint: str, payload: bytes): + @retry( + reraise=True, + before=before_log("Send action", self._logger, logging.DEBUG), + retry=self._retry_config.retry, + stop=self._retry_config.stop, + wait=self._retry_config.wait, + ) + def run_with_retry(): + return self.do_action(endpoint, payload) + + return run_with_retry() + + def _instantiate_flight_client(self) -> flight.FlightClient: + location = ( + flight.Location.for_grpc_tls(self._host, self._port) + if self._encrypted + else flight.Location.for_grpc_tcp(self._host, self._port) + ) + client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification} + if self._auth: + user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}" + if self._user_agent: + user_agent = self._user_agent + + client_options["middleware"] = [ + AuthFactory(self._auth_middleware), + UserAgentFactory(useragent=user_agent), + ] + if self._tls_root_certs: + client_options["tls_root_certs"] = self._tls_root_certs + return flight.FlightClient(location, **client_options) + + diff --git a/graphdatascience/tests/unit/arrow_client/__init__.py b/graphdatascience/tests/unit/arrow_client/__init__.py new file mode 100644 index 000000000..e69de29bb From 877427a2eb439ab7776963babc4f80d7204e560e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Mon, 23 Jun 2025 17:18:20 +0200 Subject: [PATCH 06/12] Implement Arrow based wcc endpoints --- .../authenticated_arrow_client.py | 101 ++++++-- graphdatascience/arrow_client/data_mapper.py | 55 ++++ graphdatascience/arrow_client/v2/api_types.py | 17 ++ .../arrow_client/v2/job_client.py | 64 +++++ .../arrow_client/v2/mutation_client.py | 16 ++ .../arrow_client/v2/write_back_client.py | 49 ++++ .../procedure_surface/api/wcc_endpoints.py | 37 ++- .../procedure_surface/arrow/__init__.py | 0 .../arrow/wcc_arrow_endpoints.py | 236 ++++++++++++++++++ .../cypher/wcc_proc_runner.py | 197 ++++++++------- .../unit/arrow_client/test_data_mapper.py | 53 ++++ 11 files changed, 705 insertions(+), 120 deletions(-) create mode 100644 graphdatascience/arrow_client/data_mapper.py create mode 100644 graphdatascience/arrow_client/v2/api_types.py create mode 100644 graphdatascience/arrow_client/v2/job_client.py create mode 100644 graphdatascience/arrow_client/v2/mutation_client.py create mode 100644 graphdatascience/arrow_client/v2/write_back_client.py create mode 100644 graphdatascience/procedure_surface/arrow/__init__.py create mode 100644 graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py create mode 100644 
graphdatascience/tests/unit/arrow_client/test_data_mapper.py diff --git a/graphdatascience/arrow_client/authenticated_arrow_client.py b/graphdatascience/arrow_client/authenticated_arrow_client.py index f7792f14e..61216de5f 100644 --- a/graphdatascience/arrow_client/authenticated_arrow_client.py +++ b/graphdatascience/arrow_client/authenticated_arrow_client.py @@ -1,33 +1,41 @@ from __future__ import annotations import logging -from typing import Optional, Union, Any +from dataclasses import dataclass +from typing import Any, Optional, Union from pyarrow import __version__ as arrow_version from pyarrow import flight -from pyarrow._flight import FlightTimedOutError, FlightUnavailableError, FlightInternalError, Action -from tenacity import retry_any, retry_if_exception_type, stop_after_delay, stop_after_attempt, wait_exponential, retry +from pyarrow._flight import ( + Action, + FlightInternalError, + FlightStreamReader, + FlightTimedOutError, + FlightUnavailableError, + Ticket, +) +from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.retry_utils.retry_config import RetryConfig -from .middleware.AuthMiddleware import AuthMiddleware, AuthFactory -from .middleware.UserAgentMiddleware import UserAgentFactory + from ..retry_utils.retry_utils import before_log from ..version import __version__ +from .middleware.AuthMiddleware import AuthFactory, AuthMiddleware +from .middleware.UserAgentMiddleware import UserAgentFactory class AuthenticatedArrowClient: - @staticmethod def create( - arrow_info: ArrowInfo, - auth: Optional[ArrowAuthentication] = None, - encrypted: bool = False, - disable_server_verification: bool = False, - tls_root_certs: Optional[bytes] = None, - connection_string_override: Optional[str] = None, - retry_config: Optional[RetryConfig] = None, + arrow_info: ArrowInfo, + auth: Optional[ArrowAuthentication] = None, + encrypted: bool = False, + disable_server_verification: bool = False, + tls_root_certs: Optional[bytes] = None, + connection_string_override: Optional[str] = None, + retry_config: Optional[RetryConfig] = None, ) -> AuthenticatedArrowClient: connection_string: str if connection_string_override is not None: @@ -59,15 +67,15 @@ def create( ) def __init__( - self, - host: str, - retry_config: RetryConfig, - port: int = 8491, - auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] = None, - encrypted: bool = False, - disable_server_verification: bool = False, - tls_root_certs: Optional[bytes] = None, - user_agent: Optional[str] = None, + self, + host: str, + retry_config: RetryConfig, + port: int = 8491, + auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] = None, + encrypted: bool = False, + disable_server_verification: bool = False, + tls_root_certs: Optional[bytes] = None, + user_agent: Optional[str] = None, ): """Creates a new GdsArrowClient instance. @@ -85,8 +93,6 @@ def __init__( A flag that disables server verification for TLS connections (default is False) tls_root_certs: Optional[bytes] PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server - arrow_endpoint_version: - The version of the Arrow endpoint to use (default is ArrowEndpointVersion.V1) user_agent: Optional[str] The user agent string to use for the connection. 
(default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION]) retry_config: Optional[RetryConfig] @@ -117,6 +123,48 @@ def __init__( self._flight_client = self._instantiate_flight_client() + def connection_info(self) -> ConnectionInfo: + """ + Returns the host and port of the GDS Arrow server. + + Returns + ------- + tuple[str, int] + the host and port of the GDS Arrow server + """ + return ConnectionInfo(self._host, self._port, self._encrypted) + + def request_token(self) -> Optional[str]: + """ + Requests a token from the server and returns it. + + Returns + ------- + Optional[str] + a token from the server and returns it. + """ + + @retry( + reraise=True, + before=before_log("Request token", self._logger, logging.DEBUG), + retry=self._retry_config.retry, + stop=self._retry_config.stop, + wait=self._retry_config.wait, + ) + def auth_with_retry() -> None: + client = self._flight_client + if self._auth: + auth_pair = self._auth.auth_pair() + client.authenticate_basic_token(auth_pair[0], auth_pair[1]) + + if self._auth: + auth_with_retry() + return self._auth_middleware.token() + else: + return "IGNORED" + + def get_stream(self, ticket: Ticket) -> FlightStreamReader: + return self._flight_client.do_get(ticket) def do_action(self, endpoint: str, payload: bytes): return self._flight_client.do_action(Action(endpoint, payload)) @@ -155,3 +203,8 @@ def _instantiate_flight_client(self) -> flight.FlightClient: return flight.FlightClient(location, **client_options) +@dataclass +class ConnectionInfo: + host: str + port: int + encrypted: bool diff --git a/graphdatascience/arrow_client/data_mapper.py b/graphdatascience/arrow_client/data_mapper.py new file mode 100644 index 000000000..f09cf6ca6 --- /dev/null +++ b/graphdatascience/arrow_client/data_mapper.py @@ -0,0 +1,55 @@ +import dataclasses +import json +from dataclasses import fields +from typing import Any, Dict, Iterator, Type, TypeVar + +from pyarrow._flight import Result + + +class DataMapper: + T = TypeVar("T") + + @staticmethod + def deserialize_single(input_stream: Iterator[Result], cls: Type[T]) -> T: + rows = DataMapper.deserialize(input_stream, cls) + + if len(rows) != 1: + raise ValueError(f"Expected exactly one row, got {len(rows)}") + + return rows[0] + + @staticmethod + def deserialize(input_stream, cls: Type[T]) -> list[T]: + def deserialize_row(row: Any): + result_dicts = json.loads(row.body.to_pybytes().decode()) + if cls == Dict: + return result_dicts + return DataMapper.dict_to_dataclass(result_dicts, cls) + + return [deserialize_row(row) for row in list(input_stream)] + + @staticmethod + def dict_to_dataclass(data: Dict[str, Any], cls: Type[T], strict: bool = False) -> T: + """ + Convert a dictionary to a dataclass instance with nested dataclass support. 
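A self-contained sketch of the conversion this helper performs, including the nested case; the example types are made up:

from dataclasses import dataclass

from graphdatascience.arrow_client.data_mapper import DataMapper

@dataclass
class Inner:
    x: int

@dataclass
class Outer:
    name: str
    inner: Inner

data = {"name": "a", "inner": {"x": 1}, "ignored": True}
outer = DataMapper.dict_to_dataclass(data, Outer)
# -> Outer(name='a', inner=Inner(x=1)); the extra "ignored" key is dropped
#    unless strict=True, in which case a ValueError is raised.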
+ """ + if not dataclasses.is_dataclass(cls): + raise ValueError(f"{cls} is not a dataclass") + + field_dict = {f.name: f for f in fields(cls)} + filtered_data = {} + + for key, value in data.items(): + if key in field_dict: + field = field_dict[key] + field_type = field.type + + # Handle nested dataclasses + if dataclasses.is_dataclass(field_type) and isinstance(value, dict): + filtered_data[key] = DataMapper.dict_to_dataclass(value, field_type, strict) + else: + filtered_data[key] = value + elif strict: + raise ValueError(f"Extra field '{key}' not allowed in {cls.__name__}") + + return cls(**filtered_data) diff --git a/graphdatascience/arrow_client/v2/api_types.py b/graphdatascience/arrow_client/v2/api_types.py new file mode 100644 index 000000000..de021bd00 --- /dev/null +++ b/graphdatascience/arrow_client/v2/api_types.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass + +@dataclass(frozen=True, repr=True) +class JobIdConfig: + jobId: str + +@dataclass(frozen=True, repr=True) +class JobStatus: + jobId: str + status: str + progress: float + + +@dataclass(frozen=True, repr=True) +class MutateResult: + nodePropertiesWritten: int + relationshipsWritten: int diff --git a/graphdatascience/arrow_client/v2/job_client.py b/graphdatascience/arrow_client/v2/job_client.py new file mode 100644 index 000000000..e6131bca1 --- /dev/null +++ b/graphdatascience/arrow_client/v2/job_client.py @@ -0,0 +1,64 @@ +import json +from typing import Any, Dict + +from pandas import ArrowDtype, DataFrame +from pyarrow._flight import Ticket + +from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.data_mapper import DataMapper +from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus + +JOB_STATUS_ENDPOINT = "v2/jobs.status" +RESULTS_SUMMARY_ENDPOINT = "v2/results.summary" + + +class JobClient: + @staticmethod + def run_job_and_wait(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str: + job_id = JobClient.run_job(client, endpoint, config) + JobClient.wait_for_job(client, job_id) + return job_id + + @staticmethod + def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str: + encoded_config = json.dumps(config).encode("utf-8") + res = client.do_action_with_retry(endpoint, encoded_config) + return DataMapper.deserialize_single(res, JobIdConfig).jobId + + @staticmethod + def wait_for_job(client: AuthenticatedArrowClient, job_id: str): + while True: + job_id_config = {"jobId": job_id} + encoded_config = json.dumps(job_id_config).encode("utf-8") + + arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, encoded_config) + job_status = DataMapper.deserialize_single(arrow_res, JobStatus) + if job_status.status == "Done": + break + + @staticmethod + def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]: + job_id_config = {"jobId": job_id} + encoded_config = json.dumps(job_id_config).encode("utf-8") + + res = client.do_action_with_retry(RESULTS_SUMMARY_ENDPOINT, encoded_config) + return DataMapper.deserialize_single(res, Dict) + + @staticmethod + def stream_results(client: AuthenticatedArrowClient, job_id: str) -> DataFrame: + job_id_config = {"jobId": job_id} + encoded_config = json.dumps(job_id_config).encode("utf-8") + + res = client.do_action_with_retry("v2/results.stream", encoded_config) + export_job_id = DataMapper.deserialize_single(res, JobIdConfig).jobId + + payload = { + "name": export_job_id, + "version": 1, + } + + ticket = 
Ticket(json.dumps(payload).encode("utf-8")) + with client.get_stream(ticket) as get: + arrow_table = get.read_all() + + return arrow_table.to_pandas(types_mapper=ArrowDtype) # type: ignore diff --git a/graphdatascience/arrow_client/v2/mutation_client.py b/graphdatascience/arrow_client/v2/mutation_client.py new file mode 100644 index 000000000..92cdf2f54 --- /dev/null +++ b/graphdatascience/arrow_client/v2/mutation_client.py @@ -0,0 +1,16 @@ +import json + +from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.data_mapper import DataMapper +from graphdatascience.arrow_client.v2.api_types import MutateResult + + +class MutationClient: + MUTATE_ENDPOINT = "v2/results.mutate" + + @staticmethod + def mutate_node_property(client: AuthenticatedArrowClient, job_id: str, mutate_property: str) -> MutateResult: + mutate_config = {"jobId": job_id, "mutateProperty": mutate_property} + encoded_config = json.dumps(mutate_config).encode("utf-8") + mutate_arrow_res = client.do_action_with_retry(MutationClient.MUTATE_ENDPOINT, encoded_config) + return DataMapper.deserialize_single(mutate_arrow_res, MutateResult) diff --git a/graphdatascience/arrow_client/v2/write_back_client.py b/graphdatascience/arrow_client/v2/write_back_client.py new file mode 100644 index 000000000..bbfec3e27 --- /dev/null +++ b/graphdatascience/arrow_client/v2/write_back_client.py @@ -0,0 +1,49 @@ +import time +from typing import Any, Optional + +from graphdatascience import QueryRunner +from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient +from graphdatascience.call_parameters import CallParameters +from graphdatascience.query_runner.protocol.write_protocols import WriteProtocol +from graphdatascience.query_runner.termination_flag import TerminationFlagNoop +from graphdatascience.session.dbms.protocol_resolver import ProtocolVersionResolver + + +class WriteBackClient: + def __init__(self, arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner): + self._arrow_client = arrow_client + self._query_runner = query_runner + + protocol_version = ProtocolVersionResolver(query_runner).resolve() + self._write_protocol = WriteProtocol.select(protocol_version) + + # TODO: Add progress logging + # TODO: Support setting custom writeProperties and relationshipTypes + def write(self, graph_name: str, job_id: str, concurrency: Optional[int]) -> int: + arrow_config = self._arrow_configuration() + + configuration = {} + if concurrency is not None: + configuration["concurrency"] = concurrency + + write_back_params = CallParameters( + graphName=graph_name, + jobId=job_id, + arrowConfiguration=arrow_config, + configuration=configuration, + ) + + start_time = time.time() + + self._write_protocol.run_write_back(self._query_runner, write_back_params, None, TerminationFlagNoop()) + + return int((time.time() - start_time) * 1000) + + def _arrow_configuration(self) -> dict[str, Any]: + host, port, encrypted = self._arrow_client.connection_info() + token = self._arrow_client.request_token() + if token is None: + token = "IGNORED" + arrow_config = {"host": host, "port": port, "token": token, "encrypted": encrypted()} + + return arrow_config diff --git a/graphdatascience/procedure_surface/api/wcc_endpoints.py b/graphdatascience/procedure_surface/api/wcc_endpoints.py index a1cb7a51d..5c32b822b 100644 --- a/graphdatascience/procedure_surface/api/wcc_endpoints.py +++ b/graphdatascience/procedure_surface/api/wcc_endpoints.py @@ -4,7 +4,7 @@ 
from dataclasses import dataclass from typing import Any, List, Optional -from pandas import DataFrame, Series +from pandas import DataFrame from ...graph.graph_object import Graph @@ -65,7 +65,7 @@ def mutate( Returns ------- - Series + WccMutateResult Algorithm metrics and statistics """ pass @@ -85,7 +85,7 @@ def stats( seed_property: Optional[str] = None, consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, - ) -> Series: + ) -> WccStatsResult: """ Executes the WCC algorithm and returns statistics. @@ -118,7 +118,7 @@ def stats( Returns ------- - Series + WccStatsResult Algorithm metrics and statistics """ pass @@ -197,8 +197,7 @@ def write( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, write_concurrency: Optional[Any] = None, - write_to_result_store: Optional[bool] = None, - ) -> Series: + ) -> WccWriteResult: """ Executes the WCC algorithm and writes the results to the Neo4j database. @@ -234,12 +233,10 @@ def write( The property name that contains weight write_concurrency : Optional[Any], default=None The number of concurrent threads during the write phase - write_to_result_store : Optional[bool], default=None - Whether to write the results to the result store Returns ------- - Series + WccWriteResult Algorithm metrics and statistics """ pass @@ -255,3 +252,25 @@ class WccMutateResult: mutate_millis: int node_properties_written: int configuration: dict[str, Any] + + +@dataclass(frozen=True, repr=True) +class WccStatsResult: + component_count: int + component_distribution: dict[str, Any] + pre_processing_millis: int + compute_millis: int + post_processing_millis: int + configuration: dict[str, Any] + + +@dataclass(frozen=True, repr=True) +class WccWriteResult: + component_count: int + component_distribution: dict[str, Any] + pre_processing_millis: int + compute_millis: int + write_millis: int + post_processing_millis: int + node_properties_written: int + configuration: dict[str, Any] diff --git a/graphdatascience/procedure_surface/arrow/__init__.py b/graphdatascience/procedure_surface/arrow/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py new file mode 100644 index 000000000..63ac1df55 --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py @@ -0,0 +1,236 @@ +from typing import Any, List, Optional + +from pandas import DataFrame + +from ...arrow_client.authenticated_arrow_client import AuthenticatedArrowClient +from ...arrow_client.v2.job_client import JobClient +from ...arrow_client.v2.mutation_client import MutationClient +from ...arrow_client.v2.write_back_client import WriteBackClient +from ...graph.graph_object import Graph +from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult + +WCC_ENDPOINT = "v2/community.wcc" + + +class WccArrowEndpoints(WccEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[WriteBackClient]): + self._arrow_client = arrow_client + self._write_back_client = write_back_client + + def mutate( + self, + G: Graph, + mutate_property: str, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: 
Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> WccMutateResult: + config = self._build_configuration( + G, + concurrency, + consecutive_ids, + job_id, + log_progress, + None, + node_labels, + relationship_types, + relationship_weight_property, + seed_property, + sudo, + threshold, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) + + mutate_result = MutationClient.mutate_node_property(self._arrow_client, job_id, mutate_property) + computation_result = JobClient.get_summary(self._arrow_client, job_id) + + return WccMutateResult( + computation_result["componentCount"], + computation_result["componentDistribution"], + computation_result["preProcessingMillis"], + computation_result["computeMillis"], + computation_result["postProcessingMillis"], + 0, + mutate_result.nodePropertiesWritten, + computation_result["configuration"], + ) + + def stats( + self, + G: Graph, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> WccStatsResult: + config = self._build_configuration( + G, + concurrency, + consecutive_ids, + job_id, + log_progress, + None, + node_labels, + relationship_types, + relationship_weight_property, + seed_property, + sudo, + threshold, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) + computation_result = JobClient.get_summary(self._arrow_client, job_id) + + return WccStatsResult( + computation_result["componentCount"], + computation_result["componentDistribution"], + computation_result["preProcessingMillis"], + computation_result["computeMillis"], + computation_result["postProcessingMillis"], + computation_result["configuration"], + ) + + def stream( + self, + G: Graph, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + relationship_weight_property: Optional[str] = None, + ) -> DataFrame: + config = self._build_configuration( + G, + concurrency, + consecutive_ids, + job_id, + log_progress, + min_component_size, + node_labels, + relationship_types, + relationship_weight_property, + seed_property, + sudo, + threshold, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) + return JobClient.stream_results(self._arrow_client, job_id) + + def write( + self, + G: Graph, + write_property: str, + min_component_size: Optional[int] = None, + threshold: Optional[float] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + sudo: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + concurrency: Optional[int] = None, + job_id: Optional[str] = None, + seed_property: Optional[str] = None, + consecutive_ids: Optional[bool] = None, + 
relationship_weight_property: Optional[str] = None, + write_concurrency: Optional[int] = None, + ) -> WccWriteResult: + config = self._build_configuration( + G, + concurrency, + consecutive_ids, + job_id, + log_progress, + min_component_size, + node_labels, + relationship_types, + relationship_weight_property, + seed_property, + sudo, + threshold, + ) + + job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) + computation_result = JobClient.get_summary(self._arrow_client, job_id) + + write_millis = self._write_back_client.write( + G.name(), job_id, write_concurrency if write_concurrency is not None else concurrency + ) + + return WccWriteResult( + computation_result["componentCount"], + computation_result["componentDistribution"], + computation_result["preProcessingMillis"], + computation_result["computeMillis"], + write_millis, + computation_result["postProcessingMillis"], + computation_result["nodePropertiesWritten"], + computation_result["configuration"], + ) + + @staticmethod + def _build_configuration( + G: Graph, + concurrency: Optional[int], + consecutive_ids: Optional[bool], + job_id: Optional[str], + log_progress: Optional[bool], + min_component_size: Optional[int], + node_labels: Optional[List[str]], + relationship_types: Optional[List[str]], + relationship_weight_property: Optional[str], + seed_property: Optional[str], + sudo: Optional[bool], + threshold: Optional[float], + ): + config: dict[str, Any] = { + "graphName": G.name(), + } + + if min_component_size is not None: + config["minComponentSize"] = min_component_size + if threshold is not None: + config["threshold"] = threshold + if relationship_types is not None: + config["relationshipTypes"] = relationship_types + if node_labels is not None: + config["nodeLabels"] = node_labels + if sudo is not None: + config["sudo"] = sudo + if log_progress is not None: + config["logProgress"] = log_progress + if concurrency is not None: + config["concurrency"] = concurrency + if job_id is not None: + config["jobId"] = job_id + if seed_property is not None: + config["seedProperty"] = seed_property + if consecutive_ids is not None: + config["consecutiveIds"] = consecutive_ids + if relationship_weight_property is not None: + config["relationshipWeightProperty"] = relationship_weight_property + + return config diff --git a/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py b/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py index 767d8f436..7449b9936 100644 --- a/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py +++ b/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py @@ -1,11 +1,11 @@ from typing import Any, List, Optional -from pandas import DataFrame, Series +from pandas import DataFrame from ...call_parameters import CallParameters from ...graph.graph_object import Graph from ...query_runner.query_runner import QueryRunner -from ..api.wcc_endpoints import WccEndpoints, WccMutateResult +from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult class WccCypherEndpoints(WccEndpoints): @@ -38,29 +38,21 @@ def mutate( "mutateProperty": mutate_property, } - # Add optional parameters - if threshold is not None: - config["threshold"] = threshold - if relationship_types is not None: - config["relationshipTypes"] = relationship_types - if node_labels is not None: - config["nodeLabels"] = node_labels - if sudo is not None: - config["sudo"] = sudo - if log_progress is not None: - config["logProgress"] = log_progress - if username is not None: - 
config["username"] = username - if concurrency is not None: - config["concurrency"] = concurrency - if job_id is not None: - config["jobId"] = job_id - if seed_property is not None: - config["seedProperty"] = seed_property - if consecutive_ids is not None: - config["consecutiveIds"] = consecutive_ids - if relationship_weight_property is not None: - config["relationshipWeightProperty"] = relationship_weight_property + self._create_procedure_config( + config, + concurrency, + consecutive_ids, + job_id, + log_progress, + None, + node_labels, + relationship_types, + relationship_weight_property, + seed_property, + sudo, + threshold, + username, + ) # Run procedure and return results params = CallParameters(graph_name=G.name(), config=config) @@ -76,6 +68,7 @@ def mutate( cypher_result["postProcessingMillis"], cypher_result["mutateMillis"], cypher_result["nodePropertiesWritten"], + cypher_result["configuration"], ) def stats( @@ -92,39 +85,40 @@ def stats( seed_property: Optional[str] = None, consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, - ) -> Series: + ) -> WccStatsResult: # Build configuration dictionary from parameters config: dict[str, Any] = {} - # Add optional parameters - if threshold is not None: - config["threshold"] = threshold - if relationship_types is not None: - config["relationshipTypes"] = relationship_types - if node_labels is not None: - config["nodeLabels"] = node_labels - if sudo is not None: - config["sudo"] = sudo - if log_progress is not None: - config["logProgress"] = log_progress - if username is not None: - config["username"] = username - if concurrency is not None: - config["concurrency"] = concurrency - if job_id is not None: - config["jobId"] = job_id - if seed_property is not None: - config["seedProperty"] = seed_property - if consecutive_ids is not None: - config["consecutiveIds"] = consecutive_ids - if relationship_weight_property is not None: - config["relationshipWeightProperty"] = relationship_weight_property + self._create_procedure_config( + config, + concurrency, + consecutive_ids, + job_id, + log_progress, + None, + node_labels, + relationship_types, + relationship_weight_property, + seed_property, + sudo, + threshold, + username, + ) # Run procedure and return results params = CallParameters(graph_name=G.name(), config=config) params.ensure_job_id_in_config() - return self._query_runner.call_procedure(endpoint="gds.wcc.stats", params=params).squeeze() # type: ignore + cypher_result = self._query_runner.call_procedure(endpoint="gds.wcc.stats", params=params).squeeze() # type: ignore + + return WccStatsResult( + cypher_result["componentCount"], + cypher_result["componentDistribution"], + cypher_result["preProcessingMillis"], + cypher_result["computeMillis"], + cypher_result["postProcessingMillis"], + cypher_result["configuration"], + ) def stream( self, @@ -145,31 +139,21 @@ def stream( # Build configuration dictionary from parameters config: dict[str, Any] = {} - # Add optional parameters - if min_component_size is not None: - config["minComponentSize"] = min_component_size - if threshold is not None: - config["threshold"] = threshold - if relationship_types is not None: - config["relationshipTypes"] = relationship_types - if node_labels is not None: - config["nodeLabels"] = node_labels - if sudo is not None: - config["sudo"] = sudo - if log_progress is not None: - config["logProgress"] = log_progress - if username is not None: - config["username"] = username - if concurrency is not None: - 
config["concurrency"] = concurrency - if job_id is not None: - config["jobId"] = job_id - if seed_property is not None: - config["seedProperty"] = seed_property - if consecutive_ids is not None: - config["consecutiveIds"] = consecutive_ids - if relationship_weight_property is not None: - config["relationshipWeightProperty"] = relationship_weight_property + self._create_procedure_config( + config, + concurrency, + consecutive_ids, + job_id, + log_progress, + min_component_size, + node_labels, + relationship_types, + relationship_weight_property, + seed_property, + sudo, + threshold, + username, + ) # Run procedure and return results params = CallParameters(graph_name=G.name(), config=config) @@ -194,13 +178,62 @@ def write( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, write_concurrency: Optional[int] = None, - write_to_result_store: Optional[bool] = None, - ) -> Series: + ) -> WccWriteResult: # Build configuration dictionary from parameters config: dict[str, Any] = { "writeProperty": write_property, } + self._create_procedure_config( + config, + concurrency, + consecutive_ids, + job_id, + log_progress, + min_component_size, + node_labels, + relationship_types, + relationship_weight_property, + seed_property, + sudo, + threshold, + username, + ) + + if write_concurrency is not None: + config["writeConcurrency"] = write_concurrency + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure(endpoint="gds.wcc.write", params=params).squeeze() # type: ignore + return WccWriteResult( + result["componentCount"], + result["componentDistribution"], + result["preProcessingMillis"], + result["computeMillis"], + result["writeMillis"], + result["postProcessingMillis"], + result["nodePropertiesWritten"], + result["configuration"], + ) + + @staticmethod + def _create_procedure_config( + config: dict[str, Any], + concurrency: Optional[int], + consecutive_ids: Optional[bool], + job_id: Optional[str], + log_progress: Optional[bool], + min_component_size: Optional[int], + node_labels: Optional[List[str]], + relationship_types: Optional[List[str]], + relationship_weight_property: Optional[str], + seed_property: Optional[str], + sudo: Optional[bool], + threshold: Optional[float], + username: Optional[str], + ): # Add optional parameters if min_component_size is not None: config["minComponentSize"] = min_component_size @@ -226,13 +259,3 @@ def write( config["consecutiveIds"] = consecutive_ids if relationship_weight_property is not None: config["relationshipWeightProperty"] = relationship_weight_property - if write_concurrency is not None: - config["writeConcurrency"] = write_concurrency - if write_to_result_store is not None: - config["writeToResultStore"] = write_to_result_store - - # Run procedure and return results - params = CallParameters(graph_name=G.name(), config=config) - params.ensure_job_id_in_config() - - return self._query_runner.call_procedure(endpoint="gds.wcc.write", params=params).squeeze() # type: ignore diff --git a/graphdatascience/tests/unit/arrow_client/test_data_mapper.py b/graphdatascience/tests/unit/arrow_client/test_data_mapper.py new file mode 100644 index 000000000..cc502f516 --- /dev/null +++ b/graphdatascience/tests/unit/arrow_client/test_data_mapper.py @@ -0,0 +1,53 @@ +from dataclasses import dataclass +from typing import Dict, Any + +import pytest + +from graphdatascience.arrow_client.data_mapper import DataMapper + + +@dataclass +class 
NestedDataclass: + nested_field: int + + +@dataclass +class ExampleDataclass: + field_one: str + field_two: int + nested: NestedDataclass + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + ( + { + "field_one": "test", + "field_two": 123, + "nested": {"nested_field": 456} + }, + ExampleDataclass("test", 123, NestedDataclass(456)), + ), + ], +) +def test_dict_to_dataclass(input_data: Dict[str, Any], expected_output: ExampleDataclass): + result = DataMapper.dict_to_dataclass(input_data, ExampleDataclass) + assert result == expected_output + + +def test_dict_to_dataclass_strict_mode_rejects_extra_fields(): + input_data = { + "field_one": "test", + "field_two": 123, + "nested": {"nested_field": 456}, + "extra_field": "not_allowed" + } + + with pytest.raises(ValueError, match="Extra field 'extra_field' not allowed in ExampleDataclass"): + DataMapper.dict_to_dataclass(input_data, ExampleDataclass, strict=True) + + +def test_dict_to_dataclass_non_dataclass_error(): + with pytest.raises(ValueError, match="is not a dataclass"): + DataMapper.dict_to_dataclass({"key": "value"}, int) From 8554f63659698ce82dfd146f3bd0d1bd55a213f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Tue, 24 Jun 2025 22:14:09 +0200 Subject: [PATCH 07/12] Fix formatting --- .../arrow_client/middleware/AuthMiddleware.py | 4 ++-- .../middleware/UserAgentMiddleware.py | 4 +++- graphdatascience/arrow_client/v2/api_types.py | 2 ++ graphdatascience/graph_data_science.py | 2 +- .../query_runner/arrow_query_runner.py | 2 +- .../query_runner/gds_arrow_client.py | 2 +- .../query_runner/protocol/project_protocols.py | 2 +- .../session/aura_graph_data_science.py | 4 ++-- graphdatascience/tests/integration/conftest.py | 2 +- .../tests/unit/arrow_client/test_data_mapper.py | 17 ++++------------- graphdatascience/tests/unit/conftest.py | 4 ++-- graphdatascience/tests/unit/test_init.py | 2 +- 12 files changed, 21 insertions(+), 26 deletions(-) diff --git a/graphdatascience/arrow_client/middleware/AuthMiddleware.py b/graphdatascience/arrow_client/middleware/AuthMiddleware.py index e350f7d38..0acb6e269 100644 --- a/graphdatascience/arrow_client/middleware/AuthMiddleware.py +++ b/graphdatascience/arrow_client/middleware/AuthMiddleware.py @@ -2,7 +2,7 @@ import base64 import time -from typing import Optional, Any +from typing import Any, Optional from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory @@ -60,4 +60,4 @@ def sending_headers(self) -> dict[str, str]: auth_token = f"{auth_pair[0]}:{auth_pair[1]}" auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII") # There seems to be a bug, `authorization` must be lower key - return {"authorization": auth_token} \ No newline at end of file + return {"authorization": auth_token} diff --git a/graphdatascience/arrow_client/middleware/UserAgentMiddleware.py b/graphdatascience/arrow_client/middleware/UserAgentMiddleware.py index b6313502f..704713bb4 100644 --- a/graphdatascience/arrow_client/middleware/UserAgentMiddleware.py +++ b/graphdatascience/arrow_client/middleware/UserAgentMiddleware.py @@ -4,6 +4,7 @@ from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory + class UserAgentFactory(ClientMiddlewareFactory): # type: ignore def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -12,6 +13,7 @@ def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: def start_call(self, info: Any) -> ClientMiddleware: return self._middleware + class 
UserAgentMiddleware(ClientMiddleware): # type: ignore def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -21,4 +23,4 @@ def sending_headers(self) -> dict[str, str]: return {"x-gds-user-agent": self._useragent} def received_headers(self, headers: dict[str, Any]) -> None: - pass \ No newline at end of file + pass diff --git a/graphdatascience/arrow_client/v2/api_types.py b/graphdatascience/arrow_client/v2/api_types.py index de021bd00..4390ec57d 100644 --- a/graphdatascience/arrow_client/v2/api_types.py +++ b/graphdatascience/arrow_client/v2/api_types.py @@ -1,9 +1,11 @@ from dataclasses import dataclass + @dataclass(frozen=True, repr=True) class JobIdConfig: jobId: str + @dataclass(frozen=True, repr=True) class JobStatus: jobId: str diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index f8e2e0eec..008008b1a 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -9,6 +9,7 @@ from pandas import DataFrame from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.procedure_surface.api.wcc_endpoints import WccEndpoints from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints @@ -16,7 +17,6 @@ from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints from .error.uncallable_namespace import UncallableNamespace from .graph.graph_proc_runner import GraphProcRunner -from graphdatascience.arrow_client.arrow_info import ArrowInfo from .query_runner.arrow_query_runner import ArrowQueryRunner from .query_runner.neo4j_query_runner import Neo4jQueryRunner from .query_runner.query_runner import QueryRunner diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index 975eaed25..c3ed819a2 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -6,11 +6,11 @@ from pandas import DataFrame from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.query_runner.query_mode import QueryMode from graphdatascience.retry_utils.retry_config import RetryConfig from ..call_parameters import CallParameters -from graphdatascience.arrow_client.arrow_info import ArrowInfo from ..server_version.server_version import ServerVersion from .arrow_graph_constructor import ArrowGraphConstructor from .gds_arrow_client import GdsArrowClient diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py index c25efdf2e..cd60e2a20 100644 --- a/graphdatascience/query_runner/gds_arrow_client.py +++ b/graphdatascience/query_runner/gds_arrow_client.py @@ -36,13 +36,13 @@ ) from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication, UsernamePasswordAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.retry_utils.retry_config import RetryConfig from graphdatascience.retry_utils.retry_utils import before_log from ..semantic_version.semantic_version import SemanticVersion from ..version import __version__ from .arrow_endpoint_version import ArrowEndpointVersion -from graphdatascience.arrow_client.arrow_info import ArrowInfo class GdsArrowClient: diff --git 
a/graphdatascience/query_runner/protocol/project_protocols.py b/graphdatascience/query_runner/protocol/project_protocols.py index 46eb802b8..6f84f1ea3 100644 --- a/graphdatascience/query_runner/protocol/project_protocols.py +++ b/graphdatascience/query_runner/protocol/project_protocols.py @@ -25,7 +25,7 @@ def project_params( def run_projection( self, query_runner: QueryRunner, - endpoint: str, + query: str, params: CallParameters, terminationFlag: TerminationFlag, yields: Optional[list[str]] = None, diff --git a/graphdatascience/session/aura_graph_data_science.py b/graphdatascience/session/aura_graph_data_science.py index 51a4ffde8..f35690d87 100644 --- a/graphdatascience/session/aura_graph_data_science.py +++ b/graphdatascience/session/aura_graph_data_science.py @@ -5,6 +5,8 @@ from pandas import DataFrame from graphdatascience import QueryRunner, ServerVersion +from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication +from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.call_builder import IndirectCallBuilder from graphdatascience.endpoints import ( AlphaRemoteEndpoints, @@ -13,8 +15,6 @@ ) from graphdatascience.error.uncallable_namespace import UncallableNamespace from graphdatascience.graph.graph_remote_proc_runner import GraphRemoteProcRunner -from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication -from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner diff --git a/graphdatascience/tests/integration/conftest.py b/graphdatascience/tests/integration/conftest.py index e48d5450b..439e1eea0 100644 --- a/graphdatascience/tests/integration/conftest.py +++ b/graphdatascience/tests/integration/conftest.py @@ -6,8 +6,8 @@ import pytest from neo4j import Driver, GraphDatabase -from graphdatascience.graph_data_science import GraphDataScience from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication +from graphdatascience.graph_data_science import GraphDataScience from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner from graphdatascience.server_version.server_version import ServerVersion from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience diff --git a/graphdatascience/tests/unit/arrow_client/test_data_mapper.py b/graphdatascience/tests/unit/arrow_client/test_data_mapper.py index cc502f516..0cfb574d7 100644 --- a/graphdatascience/tests/unit/arrow_client/test_data_mapper.py +++ b/graphdatascience/tests/unit/arrow_client/test_data_mapper.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, Any +from typing import Any, Dict import pytest @@ -22,12 +22,8 @@ class ExampleDataclass: "input_data, expected_output", [ ( - { - "field_one": "test", - "field_two": 123, - "nested": {"nested_field": 456} - }, - ExampleDataclass("test", 123, NestedDataclass(456)), + {"field_one": "test", "field_two": 123, "nested": {"nested_field": 456}}, + ExampleDataclass("test", 123, NestedDataclass(456)), ), ], ) @@ -37,12 +33,7 @@ def test_dict_to_dataclass(input_data: Dict[str, Any], expected_output: ExampleD def test_dict_to_dataclass_strict_mode_rejects_extra_fields(): - input_data = { - "field_one": "test", - "field_two": 123, - "nested": {"nested_field": 456}, - "extra_field": 
"not_allowed" - } + input_data = {"field_one": "test", "field_two": 123, "nested": {"nested_field": 456}, "extra_field": "not_allowed"} with pytest.raises(ValueError, match="Extra field 'extra_field' not allowed in ExampleDataclass"): DataMapper.dict_to_dataclass(input_data, ExampleDataclass, strict=True) diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py index 9ce82b4e2..ca88ff228 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -8,10 +8,10 @@ from pytest_mock import MockerFixture from graphdatascience import QueryRunner -from graphdatascience.call_parameters import CallParameters -from graphdatascience.graph_data_science import GraphDataScience from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.arrow_client.arrow_info import ArrowInfo +from graphdatascience.call_parameters import CallParameters +from graphdatascience.graph_data_science import GraphDataScience from graphdatascience.query_runner.cypher_graph_constructor import ( CypherGraphConstructor, ) diff --git a/graphdatascience/tests/unit/test_init.py b/graphdatascience/tests/unit/test_init.py index a51beb03a..e95ac67fa 100644 --- a/graphdatascience/tests/unit/test_init.py +++ b/graphdatascience/tests/unit/test_init.py @@ -3,8 +3,8 @@ import pytest from pandas import DataFrame -from graphdatascience.graph_data_science import GraphDataScience from graphdatascience.arrow_client.arrow_info import ArrowInfo +from graphdatascience.graph_data_science import GraphDataScience from graphdatascience.server_version.server_version import ServerVersion from graphdatascience.tests.unit.conftest import CollectingQueryRunner From 235fd9e560ad3d935f47c559408e919df075138c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Fri, 27 Jun 2025 11:13:51 +0200 Subject: [PATCH 08/12] Fix Cypher wcc endpoint tests --- .../cypher/test_wcc_cypher_endpoints.py | 153 ++++++++++++++++-- 1 file changed, 136 insertions(+), 17 deletions(-) diff --git a/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py b/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py index a15204629..3f93f3eab 100644 --- a/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py +++ b/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py @@ -1,14 +1,15 @@ +import pandas as pd import pytest from graphdatascience.graph.graph_object import Graph +from graphdatascience.procedure_surface.api.wcc_endpoints import WccMutateResult, WccStatsResult, WccWriteResult from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints -from graphdatascience.server_version.server_version import ServerVersion -from graphdatascience.tests.unit.conftest import CollectingQueryRunner +from graphdatascience.tests.unit.conftest import CollectingQueryRunner, DEFAULT_SERVER_VERSION @pytest.fixture def query_runner() -> CollectingQueryRunner: - return CollectingQueryRunner(ServerVersion(2, 16, 0)) + return CollectingQueryRunner(DEFAULT_SERVER_VERSION) @pytest.fixture @@ -21,8 +22,24 @@ def graph(query_runner: CollectingQueryRunner) -> Graph: return Graph("test_graph", query_runner) -def test_mutate_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner) -> None: - wcc_endpoints.mutate(graph, "componentId") +def test_mutate_basic(graph: Graph) -> None: + result = { + 
"nodePropertiesWritten": 5, + "mutateMillis": 42, + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "postProcessingMillis": 12, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337}, + } + + query_runner = CollectingQueryRunner( + DEFAULT_SERVER_VERSION, + {"wcc.mutate" : pd.DataFrame([result])} + ) + + result_obj = WccCypherEndpoints(query_runner).mutate(graph, "componentId") assert len(query_runner.queries) == 1 assert "gds.wcc.mutate" in query_runner.queries[0] @@ -32,11 +49,37 @@ def test_mutate_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_run assert config["mutateProperty"] == "componentId" assert "jobId" in config + assert isinstance(result_obj, WccMutateResult) + assert result_obj.node_properties_written == 5 + assert result_obj.mutate_millis == 42 + assert result_obj.component_count == 3 + assert result_obj.pre_processing_millis == 10 + assert result_obj.compute_millis == 20 + assert result_obj.post_processing_millis == 12 + assert result_obj.component_distribution == {"foo": 42} + assert result_obj.configuration == {"bar": 1337} + def test_mutate_with_optional_params( - wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner + graph: Graph ) -> None: - wcc_endpoints.mutate( + result = { + "nodePropertiesWritten": 5, + "mutateMillis": 42, + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "postProcessingMillis": 12, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337}, + } + + query_runner = CollectingQueryRunner( + DEFAULT_SERVER_VERSION, + {"wcc.mutate" : pd.DataFrame([result])} + ) + + WccCypherEndpoints(query_runner).mutate( graph, "componentId", threshold=0.5, @@ -72,8 +115,22 @@ def test_mutate_with_optional_params( } -def test_stats_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner) -> None: - wcc_endpoints.stats(graph) +def test_stats_basic(graph: Graph) -> None: + result = { + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "postProcessingMillis": 12, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337} + } + + query_runner = CollectingQueryRunner( + DEFAULT_SERVER_VERSION, + {"wcc.stats": pd.DataFrame([result])} + ) + + result_obj = WccCypherEndpoints(query_runner).stats(graph) assert len(query_runner.queries) == 1 assert "gds.wcc.stats" in query_runner.queries[0] @@ -82,11 +139,33 @@ def test_stats_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runn config = params["config"] assert "jobId" in config + assert isinstance(result_obj, WccStatsResult) + assert result_obj.component_count == 3 + assert result_obj.pre_processing_millis == 10 + assert result_obj.compute_millis == 20 + assert result_obj.post_processing_millis == 12 + assert result_obj.component_distribution == {"foo": 42} + assert result_obj.configuration == {"bar": 1337} + def test_stats_with_optional_params( - wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner + graph: Graph ) -> None: - wcc_endpoints.stats( + result = { + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "postProcessingMillis": 12, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337} + } + + query_runner = CollectingQueryRunner( + DEFAULT_SERVER_VERSION, + {"wcc.stats": pd.DataFrame([result])} + ) + + WccCypherEndpoints(query_runner).stats( graph, threshold=0.5, relationship_types=["REL"], @@ -170,8 +249,24 @@ def 
test_stream_with_optional_params( } -def test_write_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner) -> None: - wcc_endpoints.write(graph, "componentId") +def test_write_basic(graph: Graph) -> None: + result = { + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "writeMillis": 15, + "postProcessingMillis": 12, + "nodePropertiesWritten": 5, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337} + } + + query_runner = CollectingQueryRunner( + DEFAULT_SERVER_VERSION, + {"wcc.write": pd.DataFrame([result])} + ) + + result_obj = WccCypherEndpoints(query_runner).write(graph, "componentId") assert len(query_runner.queries) == 1 assert "gds.wcc.write" in query_runner.queries[0] @@ -181,11 +276,37 @@ def test_write_basic(wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runn assert config["writeProperty"] == "componentId" assert "jobId" in config + assert isinstance(result_obj, WccWriteResult) + assert result_obj.component_count == 3 + assert result_obj.pre_processing_millis == 10 + assert result_obj.compute_millis == 20 + assert result_obj.write_millis == 15 + assert result_obj.post_processing_millis == 12 + assert result_obj.node_properties_written == 5 + assert result_obj.component_distribution == {"foo": 42} + assert result_obj.configuration == {"bar": 1337} + def test_write_with_optional_params( - wcc_endpoints: WccCypherEndpoints, graph: Graph, query_runner: CollectingQueryRunner + graph: Graph ) -> None: - wcc_endpoints.write( + result = { + "componentCount": 3, + "preProcessingMillis": 10, + "computeMillis": 20, + "writeMillis": 15, + "postProcessingMillis": 12, + "nodePropertiesWritten": 5, + "componentDistribution": {"foo": 42}, + "configuration": {"bar": 1337} + } + + query_runner = CollectingQueryRunner( + DEFAULT_SERVER_VERSION, + {"wcc.write": pd.DataFrame([result])} + ) + + WccCypherEndpoints(query_runner).write( graph, "componentId", min_component_size=2, @@ -201,7 +322,6 @@ def test_write_with_optional_params( consecutive_ids=True, relationship_weight_property="weight", write_concurrency=4, - write_to_result_store=True, ) assert len(query_runner.queries) == 1 @@ -223,5 +343,4 @@ def test_write_with_optional_params( "consecutiveIds": True, "relationshipWeightProperty": "weight", "writeConcurrency": 4, - "writeToResultStore": True, } From d8f221ca6fe8dddfd79b293bf38a5ac030495491 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Fri, 27 Jun 2025 12:29:15 +0200 Subject: [PATCH 09/12] Fix type issues --- .../authenticated_arrow_client.py | 15 +++--- graphdatascience/arrow_client/data_mapper.py | 8 +-- .../arrow_client/v2/job_client.py | 4 +- .../arrow_client/v2/write_back_client.py | 9 +++- graphdatascience/graph_data_science.py | 7 ++- .../arrow/wcc_arrow_endpoints.py | 5 +- .../cypher/wcc_proc_runner.py | 2 +- .../unit/arrow_client/test_data_mapper.py | 6 +-- .../cypher/test_wcc_cypher_endpoints.py | 52 +++++-------------- 9 files changed, 46 insertions(+), 62 deletions(-) diff --git a/graphdatascience/arrow_client/authenticated_arrow_client.py b/graphdatascience/arrow_client/authenticated_arrow_client.py index 61216de5f..8fea1e3c9 100644 --- a/graphdatascience/arrow_client/authenticated_arrow_client.py +++ b/graphdatascience/arrow_client/authenticated_arrow_client.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Iterator, Optional from pyarrow import __version__ as arrow_version from pyarrow 
import flight @@ -12,6 +12,7 @@ FlightStreamReader, FlightTimedOutError, FlightUnavailableError, + Result, Ticket, ) from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential @@ -71,7 +72,7 @@ def __init__( host: str, retry_config: RetryConfig, port: int = 8491, - auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] = None, + auth: Optional[ArrowAuthentication] = None, encrypted: bool = False, disable_server_verification: bool = False, tls_root_certs: Optional[bytes] = None, @@ -85,7 +86,7 @@ def __init__( The host address of the GDS Arrow server port: int The host port of the GDS Arrow server (default is 8491) - auth: Optional[Union[ArrowAuthentication, tuple[str, str]]] + auth: Optional[ArrowAuthentication] Either an implementation of ArrowAuthentication providing a pair to be used for basic authentication, or a username, password tuple encrypted: bool A flag that indicates whether the connection should be encrypted (default is False) @@ -166,10 +167,10 @@ def auth_with_retry() -> None: def get_stream(self, ticket: Ticket) -> FlightStreamReader: return self._flight_client.do_get(ticket) - def do_action(self, endpoint: str, payload: bytes): - return self._flight_client.do_action(Action(endpoint, payload)) + def do_action(self, endpoint: str, payload: bytes) -> Iterator[Result]: + return self._flight_client.do_action(Action(endpoint, payload)) # type: ignore - def do_action_with_retry(self, endpoint: str, payload: bytes): + def do_action_with_retry(self, endpoint: str, payload: bytes) -> Iterator[Result]: @retry( reraise=True, before=before_log("Send action", self._logger, logging.DEBUG), @@ -177,7 +178,7 @@ def do_action_with_retry(self, endpoint: str, payload: bytes): stop=self._retry_config.stop, wait=self._retry_config.wait, ) - def run_with_retry(): + def run_with_retry() -> Iterator[Result]: return self.do_action(endpoint, payload) return run_with_retry() diff --git a/graphdatascience/arrow_client/data_mapper.py b/graphdatascience/arrow_client/data_mapper.py index f09cf6ca6..6e326af14 100644 --- a/graphdatascience/arrow_client/data_mapper.py +++ b/graphdatascience/arrow_client/data_mapper.py @@ -19,8 +19,8 @@ def deserialize_single(input_stream: Iterator[Result], cls: Type[T]) -> T: return rows[0] @staticmethod - def deserialize(input_stream, cls: Type[T]) -> list[T]: - def deserialize_row(row: Any): + def deserialize(input_stream: Iterator[Result], cls: Type[T]) -> list[T]: + def deserialize_row(row: Result): # type:ignore result_dicts = json.loads(row.body.to_pybytes().decode()) if cls == Dict: return result_dicts @@ -46,10 +46,10 @@ def dict_to_dataclass(data: Dict[str, Any], cls: Type[T], strict: bool = False) # Handle nested dataclasses if dataclasses.is_dataclass(field_type) and isinstance(value, dict): - filtered_data[key] = DataMapper.dict_to_dataclass(value, field_type, strict) + filtered_data[key] = DataMapper.dict_to_dataclass(value, field_type, strict) # type:ignore else: filtered_data[key] = value elif strict: raise ValueError(f"Extra field '{key}' not allowed in {cls.__name__}") - return cls(**filtered_data) + return cls(**filtered_data) # type: ignore diff --git a/graphdatascience/arrow_client/v2/job_client.py b/graphdatascience/arrow_client/v2/job_client.py index e6131bca1..afa66a090 100644 --- a/graphdatascience/arrow_client/v2/job_client.py +++ b/graphdatascience/arrow_client/v2/job_client.py @@ -26,7 +26,7 @@ def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, A return 
DataMapper.deserialize_single(res, JobIdConfig).jobId @staticmethod - def wait_for_job(client: AuthenticatedArrowClient, job_id: str): + def wait_for_job(client: AuthenticatedArrowClient, job_id: str) -> None: while True: job_id_config = {"jobId": job_id} encoded_config = json.dumps(job_id_config).encode("utf-8") @@ -42,7 +42,7 @@ def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any] encoded_config = json.dumps(job_id_config).encode("utf-8") res = client.do_action_with_retry(RESULTS_SUMMARY_ENDPOINT, encoded_config) - return DataMapper.deserialize_single(res, Dict) + return DataMapper.deserialize_single(res, Dict[str, Any]) @staticmethod def stream_results(client: AuthenticatedArrowClient, job_id: str) -> DataFrame: diff --git a/graphdatascience/arrow_client/v2/write_back_client.py b/graphdatascience/arrow_client/v2/write_back_client.py index bbfec3e27..fb5441db4 100644 --- a/graphdatascience/arrow_client/v2/write_back_client.py +++ b/graphdatascience/arrow_client/v2/write_back_client.py @@ -40,10 +40,15 @@ def write(self, graph_name: str, job_id: str, concurrency: Optional[int]) -> int return int((time.time() - start_time) * 1000) def _arrow_configuration(self) -> dict[str, Any]: - host, port, encrypted = self._arrow_client.connection_info() + connection_info = self._arrow_client.connection_info() token = self._arrow_client.request_token() if token is None: token = "IGNORED" - arrow_config = {"host": host, "port": port, "token": token, "encrypted": encrypted()} + arrow_config = { + "host": connection_info.host, + "port": connection_info.port, + "token": token, + "encrypted": connection_info.encrypted, + } return arrow_config diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index 008008b1a..6055f0d97 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -10,7 +10,6 @@ from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.arrow_client.arrow_info import ArrowInfo -from graphdatascience.procedure_surface.api.wcc_endpoints import WccEndpoints from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints from .call_builder import IndirectCallBuilder @@ -125,9 +124,9 @@ def __init__( def graph(self) -> GraphProcRunner: return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version) - @property - def wcc(self) -> WccEndpoints: - return self._wcc_endpoints + # @property + # def wcc(self) -> WccEndpoints: + # return self._wcc_endpoints @property def util(self) -> UtilProcRunner: diff --git a/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py index 63ac1df55..36aeaf4a5 100644 --- a/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py @@ -176,6 +176,9 @@ def write( job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) computation_result = JobClient.get_summary(self._arrow_client, job_id) + if self._write_back_client is None: + raise Exception("Write back client is not initialized") + write_millis = self._write_back_client.write( G.name(), job_id, write_concurrency if write_concurrency is not None else concurrency ) @@ -205,7 +208,7 @@ def _build_configuration( seed_property: Optional[str], sudo: Optional[bool], threshold: Optional[float], - ): + ) -> dict[str, Any]: config: dict[str, Any] 
= { "graphName": G.name(), } diff --git a/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py b/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py index 7449b9936..79e41bcf5 100644 --- a/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py +++ b/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py @@ -233,7 +233,7 @@ def _create_procedure_config( sudo: Optional[bool], threshold: Optional[float], username: Optional[str], - ): + ) -> None: # Add optional parameters if min_component_size is not None: config["minComponentSize"] = min_component_size diff --git a/graphdatascience/tests/unit/arrow_client/test_data_mapper.py b/graphdatascience/tests/unit/arrow_client/test_data_mapper.py index 0cfb574d7..e4e172597 100644 --- a/graphdatascience/tests/unit/arrow_client/test_data_mapper.py +++ b/graphdatascience/tests/unit/arrow_client/test_data_mapper.py @@ -27,18 +27,18 @@ class ExampleDataclass: ), ], ) -def test_dict_to_dataclass(input_data: Dict[str, Any], expected_output: ExampleDataclass): +def test_dict_to_dataclass(input_data: Dict[str, Any], expected_output: ExampleDataclass) -> None: result = DataMapper.dict_to_dataclass(input_data, ExampleDataclass) assert result == expected_output -def test_dict_to_dataclass_strict_mode_rejects_extra_fields(): +def test_dict_to_dataclass_strict_mode_rejects_extra_fields() -> None: input_data = {"field_one": "test", "field_two": 123, "nested": {"nested_field": 456}, "extra_field": "not_allowed"} with pytest.raises(ValueError, match="Extra field 'extra_field' not allowed in ExampleDataclass"): DataMapper.dict_to_dataclass(input_data, ExampleDataclass, strict=True) -def test_dict_to_dataclass_non_dataclass_error(): +def test_dict_to_dataclass_non_dataclass_error() -> None: with pytest.raises(ValueError, match="is not a dataclass"): DataMapper.dict_to_dataclass({"key": "value"}, int) diff --git a/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py b/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py index 3f93f3eab..bb46e7be5 100644 --- a/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py +++ b/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py @@ -4,7 +4,7 @@ from graphdatascience.graph.graph_object import Graph from graphdatascience.procedure_surface.api.wcc_endpoints import WccMutateResult, WccStatsResult, WccWriteResult from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints -from graphdatascience.tests.unit.conftest import CollectingQueryRunner, DEFAULT_SERVER_VERSION +from graphdatascience.tests.unit.conftest import DEFAULT_SERVER_VERSION, CollectingQueryRunner @pytest.fixture @@ -34,10 +34,7 @@ def test_mutate_basic(graph: Graph) -> None: "configuration": {"bar": 1337}, } - query_runner = CollectingQueryRunner( - DEFAULT_SERVER_VERSION, - {"wcc.mutate" : pd.DataFrame([result])} - ) + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.mutate": pd.DataFrame([result])}) result_obj = WccCypherEndpoints(query_runner).mutate(graph, "componentId") @@ -60,9 +57,7 @@ def test_mutate_basic(graph: Graph) -> None: assert result_obj.configuration == {"bar": 1337} -def test_mutate_with_optional_params( - graph: Graph -) -> None: +def test_mutate_with_optional_params(graph: Graph) -> None: result = { "nodePropertiesWritten": 5, "mutateMillis": 42, @@ -74,10 +69,7 @@ def test_mutate_with_optional_params( "configuration": {"bar": 1337}, } - query_runner = 
CollectingQueryRunner( - DEFAULT_SERVER_VERSION, - {"wcc.mutate" : pd.DataFrame([result])} - ) + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.mutate": pd.DataFrame([result])}) WccCypherEndpoints(query_runner).mutate( graph, @@ -122,13 +114,10 @@ def test_stats_basic(graph: Graph) -> None: "computeMillis": 20, "postProcessingMillis": 12, "componentDistribution": {"foo": 42}, - "configuration": {"bar": 1337} + "configuration": {"bar": 1337}, } - query_runner = CollectingQueryRunner( - DEFAULT_SERVER_VERSION, - {"wcc.stats": pd.DataFrame([result])} - ) + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.stats": pd.DataFrame([result])}) result_obj = WccCypherEndpoints(query_runner).stats(graph) @@ -148,22 +137,17 @@ def test_stats_basic(graph: Graph) -> None: assert result_obj.configuration == {"bar": 1337} -def test_stats_with_optional_params( - graph: Graph -) -> None: +def test_stats_with_optional_params(graph: Graph) -> None: result = { "componentCount": 3, "preProcessingMillis": 10, "computeMillis": 20, "postProcessingMillis": 12, "componentDistribution": {"foo": 42}, - "configuration": {"bar": 1337} + "configuration": {"bar": 1337}, } - query_runner = CollectingQueryRunner( - DEFAULT_SERVER_VERSION, - {"wcc.stats": pd.DataFrame([result])} - ) + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.stats": pd.DataFrame([result])}) WccCypherEndpoints(query_runner).stats( graph, @@ -258,13 +242,10 @@ def test_write_basic(graph: Graph) -> None: "postProcessingMillis": 12, "nodePropertiesWritten": 5, "componentDistribution": {"foo": 42}, - "configuration": {"bar": 1337} + "configuration": {"bar": 1337}, } - query_runner = CollectingQueryRunner( - DEFAULT_SERVER_VERSION, - {"wcc.write": pd.DataFrame([result])} - ) + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.write": pd.DataFrame([result])}) result_obj = WccCypherEndpoints(query_runner).write(graph, "componentId") @@ -287,9 +268,7 @@ def test_write_basic(graph: Graph) -> None: assert result_obj.configuration == {"bar": 1337} -def test_write_with_optional_params( - graph: Graph -) -> None: +def test_write_with_optional_params(graph: Graph) -> None: result = { "componentCount": 3, "preProcessingMillis": 10, @@ -298,13 +277,10 @@ def test_write_with_optional_params( "postProcessingMillis": 12, "nodePropertiesWritten": 5, "componentDistribution": {"foo": 42}, - "configuration": {"bar": 1337} + "configuration": {"bar": 1337}, } - query_runner = CollectingQueryRunner( - DEFAULT_SERVER_VERSION, - {"wcc.write": pd.DataFrame([result])} - ) + query_runner = CollectingQueryRunner(DEFAULT_SERVER_VERSION, {"wcc.write": pd.DataFrame([result])}) WccCypherEndpoints(query_runner).write( graph, From 03eb011eb8db710e32b6dc60205875bd14f5b243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Tue, 1 Jul 2025 11:56:26 +0200 Subject: [PATCH 10/12] Generalize config extraction for arrow endpoints --- .../arrow/arrow_config_converter.py | 38 +++++ .../arrow/wcc_arrow_endpoints.py | 142 ++++++------------ .../tests/unit/procedure_surface/__init__.py | 0 .../unit/procedure_surface/arrow/__init__.py | 0 4 files changed, 87 insertions(+), 93 deletions(-) create mode 100644 graphdatascience/procedure_surface/arrow/arrow_config_converter.py create mode 100644 graphdatascience/tests/unit/procedure_surface/__init__.py create mode 100644 graphdatascience/tests/unit/procedure_surface/arrow/__init__.py diff --git a/graphdatascience/procedure_surface/arrow/arrow_config_converter.py 
b/graphdatascience/procedure_surface/arrow/arrow_config_converter.py new file mode 100644 index 000000000..51dd90f9b --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/arrow_config_converter.py @@ -0,0 +1,38 @@ +from typing import Optional, Any, Dict + +from graphdatascience import Graph + + +class ArrowConfigConverter: + + @staticmethod + def build_configuration(G: Graph, **kwargs: Optional[Any]) -> dict[str, Any]: + config: dict[str, Any] = { + "graphName": G.name(), + } + + # Process kwargs + processed_kwargs = ArrowConfigConverter._process_dict_values(kwargs) + config.update(processed_kwargs) + + return config + + @staticmethod + def _convert_to_camel_case(name: str) -> str: + """Convert a snake_case string to camelCase.""" + parts = name.split('_') + return ''.join([word.capitalize() if i > 0 else word.lower() for i, word in enumerate(parts)]) + + @staticmethod + def _process_dict_values(input_dict: Dict[str, Any]) -> Dict[str, Any]: + """Process dictionary values, converting keys to camelCase and handling nested dictionaries.""" + result = {} + for key, value in input_dict.items(): + if value is not None: + camel_key = ArrowConfigConverter._convert_to_camel_case(key) + # Recursively process nested dictionaries + if isinstance(value, dict): + result[camel_key] = ArrowConfigConverter._process_dict_values(value) + else: + result[camel_key] = value + return result \ No newline at end of file diff --git a/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py index 36aeaf4a5..99e7aba35 100644 --- a/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py @@ -2,6 +2,7 @@ from pandas import DataFrame +from .arrow_config_converter import ArrowConfigConverter from ...arrow_client.authenticated_arrow_client import AuthenticatedArrowClient from ...arrow_client.v2.job_client import JobClient from ...arrow_client.v2.mutation_client import MutationClient @@ -33,19 +34,18 @@ def mutate( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, ) -> WccMutateResult: - config = self._build_configuration( + config = ArrowConfigConverter.build_configuration( G, - concurrency, - consecutive_ids, - job_id, - log_progress, - None, - node_labels, - relationship_types, - relationship_weight_property, - seed_property, - sudo, - threshold, + concurrency = concurrency, + consecutive_ids = consecutive_ids, + job_id = job_id, + log_progress = log_progress, + node_labels = node_labels, + relationship_types = relationship_types, + relationship_weight_property = relationship_weight_property, + seed_property = seed_property, + sudo = sudo, + threshold = threshold, ) job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) @@ -79,19 +79,18 @@ def stats( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, ) -> WccStatsResult: - config = self._build_configuration( + config = ArrowConfigConverter.build_configuration( G, - concurrency, - consecutive_ids, - job_id, - log_progress, - None, - node_labels, - relationship_types, - relationship_weight_property, - seed_property, - sudo, - threshold, + concurrency = concurrency, + consecutive_ids = consecutive_ids, + job_id = job_id, + log_progress = log_progress, + node_labels = node_labels, + relationship_types = relationship_types, + relationship_weight_property = relationship_weight_property, + seed_property = 
seed_property, + sudo = sudo, + threshold = threshold, ) job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) @@ -122,19 +121,19 @@ def stream( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, ) -> DataFrame: - config = self._build_configuration( + config = ArrowConfigConverter.build_configuration( G, - concurrency, - consecutive_ids, - job_id, - log_progress, - min_component_size, - node_labels, - relationship_types, - relationship_weight_property, - seed_property, - sudo, - threshold, + concurrency = concurrency, + consecutive_ids = consecutive_ids, + job_id = job_id, + log_progress = log_progress, + min_component_size = min_component_size, + node_labels = node_labels, + relationship_types = relationship_types, + relationship_weight_property = relationship_weight_property, + seed_property = seed_property, + sudo = sudo, + threshold = threshold, ) job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) @@ -158,19 +157,20 @@ def write( relationship_weight_property: Optional[str] = None, write_concurrency: Optional[int] = None, ) -> WccWriteResult: - config = self._build_configuration( + + config = ArrowConfigConverter.build_configuration( G, - concurrency, - consecutive_ids, - job_id, - log_progress, - min_component_size, - node_labels, - relationship_types, - relationship_weight_property, - seed_property, - sudo, - threshold, + concurrency = concurrency, + consecutive_ids = consecutive_ids, + job_id = job_id, + log_progress = log_progress, + min_component_size = min_component_size, + node_labels = node_labels, + relationship_types = relationship_types, + relationship_weight_property = relationship_weight_property, + seed_property = seed_property, + sudo = sudo, + threshold = threshold, ) job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) @@ -192,48 +192,4 @@ def write( computation_result["postProcessingMillis"], computation_result["nodePropertiesWritten"], computation_result["configuration"], - ) - - @staticmethod - def _build_configuration( - G: Graph, - concurrency: Optional[int], - consecutive_ids: Optional[bool], - job_id: Optional[str], - log_progress: Optional[bool], - min_component_size: Optional[int], - node_labels: Optional[List[str]], - relationship_types: Optional[List[str]], - relationship_weight_property: Optional[str], - seed_property: Optional[str], - sudo: Optional[bool], - threshold: Optional[float], - ) -> dict[str, Any]: - config: dict[str, Any] = { - "graphName": G.name(), - } - - if min_component_size is not None: - config["minComponentSize"] = min_component_size - if threshold is not None: - config["threshold"] = threshold - if relationship_types is not None: - config["relationshipTypes"] = relationship_types - if node_labels is not None: - config["nodeLabels"] = node_labels - if sudo is not None: - config["sudo"] = sudo - if log_progress is not None: - config["logProgress"] = log_progress - if concurrency is not None: - config["concurrency"] = concurrency - if job_id is not None: - config["jobId"] = job_id - if seed_property is not None: - config["seedProperty"] = seed_property - if consecutive_ids is not None: - config["consecutiveIds"] = consecutive_ids - if relationship_weight_property is not None: - config["relationshipWeightProperty"] = relationship_weight_property - - return config + ) \ No newline at end of file diff --git a/graphdatascience/tests/unit/procedure_surface/__init__.py b/graphdatascience/tests/unit/procedure_surface/__init__.py new 
file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/tests/unit/procedure_surface/arrow/__init__.py b/graphdatascience/tests/unit/procedure_surface/arrow/__init__.py new file mode 100644 index 000000000..e69de29bb From ed37a4d1d6220c05acabe0daf9970f5a4df5b254 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Tue, 1 Jul 2025 12:07:18 +0200 Subject: [PATCH 11/12] Use config converter also for Cypher endpoints --- .../arrow/wcc_arrow_endpoints.py | 108 +++++------ ...onfig_converter.py => config_converter.py} | 25 ++- ...proc_runner.py => wcc_cypher_endpoints.py} | 168 ++++++------------ .../cypher/test_wcc_cypher_endpoints.py | 2 +- .../test_config_converter.py | 39 ++++ 5 files changed, 158 insertions(+), 184 deletions(-) rename graphdatascience/procedure_surface/{arrow/arrow_config_converter.py => config_converter.py} (54%) rename graphdatascience/procedure_surface/cypher/{wcc_proc_runner.py => wcc_cypher_endpoints.py} (59%) create mode 100644 graphdatascience/tests/unit/procedure_surface/test_config_converter.py diff --git a/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py index 99e7aba35..8261f56ba 100644 --- a/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py @@ -1,8 +1,9 @@ -from typing import Any, List, Optional +from typing import List, Optional from pandas import DataFrame -from .arrow_config_converter import ArrowConfigConverter +from graphdatascience.procedure_surface.config_converter import ConfigConverter + from ...arrow_client.authenticated_arrow_client import AuthenticatedArrowClient from ...arrow_client.v2.job_client import JobClient from ...arrow_client.v2.mutation_client import MutationClient @@ -34,18 +35,18 @@ def mutate( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, ) -> WccMutateResult: - config = ArrowConfigConverter.build_configuration( - G, - concurrency = concurrency, - consecutive_ids = consecutive_ids, - job_id = job_id, - log_progress = log_progress, - node_labels = node_labels, - relationship_types = relationship_types, - relationship_weight_property = relationship_weight_property, - seed_property = seed_property, - sudo = sudo, - threshold = threshold, + config = ConfigConverter.convert_to_gds_config( + graph_name=G.name(), + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, ) job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) @@ -79,18 +80,18 @@ def stats( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, ) -> WccStatsResult: - config = ArrowConfigConverter.build_configuration( - G, - concurrency = concurrency, - consecutive_ids = consecutive_ids, - job_id = job_id, - log_progress = log_progress, - node_labels = node_labels, - relationship_types = relationship_types, - relationship_weight_property = relationship_weight_property, - seed_property = seed_property, - sudo = sudo, - threshold = threshold, + config = ConfigConverter.convert_to_gds_config( + graph_name=G.name(), + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + 
node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, ) job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) @@ -121,19 +122,19 @@ def stream( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, ) -> DataFrame: - config = ArrowConfigConverter.build_configuration( - G, - concurrency = concurrency, - consecutive_ids = consecutive_ids, - job_id = job_id, - log_progress = log_progress, - min_component_size = min_component_size, - node_labels = node_labels, - relationship_types = relationship_types, - relationship_weight_property = relationship_weight_property, - seed_property = seed_property, - sudo = sudo, - threshold = threshold, + config = ConfigConverter.convert_to_gds_config( + graph_name=G.name(), + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + min_component_size=min_component_size, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, ) job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) @@ -157,20 +158,19 @@ def write( relationship_weight_property: Optional[str] = None, write_concurrency: Optional[int] = None, ) -> WccWriteResult: - - config = ArrowConfigConverter.build_configuration( - G, - concurrency = concurrency, - consecutive_ids = consecutive_ids, - job_id = job_id, - log_progress = log_progress, - min_component_size = min_component_size, - node_labels = node_labels, - relationship_types = relationship_types, - relationship_weight_property = relationship_weight_property, - seed_property = seed_property, - sudo = sudo, - threshold = threshold, + config = ConfigConverter.convert_to_gds_config( + graph_name=G.name(), + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + min_component_size=min_component_size, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, ) job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config) @@ -192,4 +192,4 @@ def write( computation_result["postProcessingMillis"], computation_result["nodePropertiesWritten"], computation_result["configuration"], - ) \ No newline at end of file + ) diff --git a/graphdatascience/procedure_surface/arrow/arrow_config_converter.py b/graphdatascience/procedure_surface/config_converter.py similarity index 54% rename from graphdatascience/procedure_surface/arrow/arrow_config_converter.py rename to graphdatascience/procedure_surface/config_converter.py index 51dd90f9b..c6da60355 100644 --- a/graphdatascience/procedure_surface/arrow/arrow_config_converter.py +++ b/graphdatascience/procedure_surface/config_converter.py @@ -1,18 +1,13 @@ -from typing import Optional, Any, Dict +from typing import Any, Dict, Optional -from graphdatascience import Graph - - -class ArrowConfigConverter: +class ConfigConverter: @staticmethod - def build_configuration(G: Graph, **kwargs: Optional[Any]) -> dict[str, Any]: - config: dict[str, Any] = { - "graphName": G.name(), - } + def convert_to_gds_config(**kwargs: Optional[Any]) -> dict[str, Any]: + config: dict[str, Any] = {} # Process kwargs - processed_kwargs 
= ArrowConfigConverter._process_dict_values(kwargs) + processed_kwargs = ConfigConverter._process_dict_values(kwargs) config.update(processed_kwargs) return config @@ -20,8 +15,8 @@ def build_configuration(G: Graph, **kwargs: Optional[Any]) -> dict[str, Any]: @staticmethod def _convert_to_camel_case(name: str) -> str: """Convert a snake_case string to camelCase.""" - parts = name.split('_') - return ''.join([word.capitalize() if i > 0 else word.lower() for i, word in enumerate(parts)]) + parts = name.split("_") + return "".join([word.capitalize() if i > 0 else word.lower() for i, word in enumerate(parts)]) @staticmethod def _process_dict_values(input_dict: Dict[str, Any]) -> Dict[str, Any]: @@ -29,10 +24,10 @@ def _process_dict_values(input_dict: Dict[str, Any]) -> Dict[str, Any]: result = {} for key, value in input_dict.items(): if value is not None: - camel_key = ArrowConfigConverter._convert_to_camel_case(key) + camel_key = ConfigConverter._convert_to_camel_case(key) # Recursively process nested dictionaries if isinstance(value, dict): - result[camel_key] = ArrowConfigConverter._process_dict_values(value) + result[camel_key] = ConfigConverter._process_dict_values(value) else: result[camel_key] = value - return result \ No newline at end of file + return result diff --git a/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py b/graphdatascience/procedure_surface/cypher/wcc_cypher_endpoints.py similarity index 59% rename from graphdatascience/procedure_surface/cypher/wcc_proc_runner.py rename to graphdatascience/procedure_surface/cypher/wcc_cypher_endpoints.py index 79e41bcf5..7bebd3788 100644 --- a/graphdatascience/procedure_surface/cypher/wcc_proc_runner.py +++ b/graphdatascience/procedure_surface/cypher/wcc_cypher_endpoints.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import List, Optional from pandas import DataFrame @@ -6,6 +6,7 @@ from ...graph.graph_object import Graph from ...query_runner.query_runner import QueryRunner from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult +from ..config_converter import ConfigConverter class WccCypherEndpoints(WccEndpoints): @@ -33,25 +34,19 @@ def mutate( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, ) -> WccMutateResult: - # Build configuration dictionary from parameters - config: dict[str, Any] = { - "mutateProperty": mutate_property, - } - - self._create_procedure_config( - config, - concurrency, - consecutive_ids, - job_id, - log_progress, - None, - node_labels, - relationship_types, - relationship_weight_property, - seed_property, - sudo, - threshold, - username, + config = ConfigConverter.convert_to_gds_config( + mutate_property=mutate_property, + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + username=username, ) # Run procedure and return results @@ -86,23 +81,18 @@ def stats( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, ) -> WccStatsResult: - # Build configuration dictionary from parameters - config: dict[str, Any] = {} - - self._create_procedure_config( - config, - concurrency, - consecutive_ids, - job_id, - log_progress, - None, - node_labels, - relationship_types, - relationship_weight_property, - seed_property, - 
sudo, - threshold, - username, + config = ConfigConverter.convert_to_gds_config( + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + username=username, ) # Run procedure and return results @@ -136,23 +126,19 @@ def stream( consecutive_ids: Optional[bool] = None, relationship_weight_property: Optional[str] = None, ) -> DataFrame: - # Build configuration dictionary from parameters - config: dict[str, Any] = {} - - self._create_procedure_config( - config, - concurrency, - consecutive_ids, - job_id, - log_progress, - min_component_size, - node_labels, - relationship_types, - relationship_weight_property, - seed_property, - sudo, - threshold, - username, + config = ConfigConverter.convert_to_gds_config( + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + min_component_size=min_component_size, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + username=username, ) # Run procedure and return results @@ -179,24 +165,20 @@ def write( relationship_weight_property: Optional[str] = None, write_concurrency: Optional[int] = None, ) -> WccWriteResult: - # Build configuration dictionary from parameters - config: dict[str, Any] = { - "writeProperty": write_property, - } - self._create_procedure_config( - config, - concurrency, - consecutive_ids, - job_id, - log_progress, - min_component_size, - node_labels, - relationship_types, - relationship_weight_property, - seed_property, - sudo, - threshold, - username, + config = ConfigConverter.convert_to_gds_config( + write_property=write_property, + concurrency=concurrency, + consecutive_ids=consecutive_ids, + job_id=job_id, + log_progress=log_progress, + min_component_size=min_component_size, + node_labels=node_labels, + relationship_types=relationship_types, + relationship_weight_property=relationship_weight_property, + seed_property=seed_property, + sudo=sudo, + threshold=threshold, + username=username, ) if write_concurrency is not None: @@ -217,45 +199,3 @@ def write( result["nodePropertiesWritten"], result["configuration"], ) - - @staticmethod - def _create_procedure_config( - config: dict[str, Any], - concurrency: Optional[int], - consecutive_ids: Optional[bool], - job_id: Optional[str], - log_progress: Optional[bool], - min_component_size: Optional[int], - node_labels: Optional[List[str]], - relationship_types: Optional[List[str]], - relationship_weight_property: Optional[str], - seed_property: Optional[str], - sudo: Optional[bool], - threshold: Optional[float], - username: Optional[str], - ) -> None: - # Add optional parameters - if min_component_size is not None: - config["minComponentSize"] = min_component_size - if threshold is not None: - config["threshold"] = threshold - if relationship_types is not None: - config["relationshipTypes"] = relationship_types - if node_labels is not None: - config["nodeLabels"] = node_labels - if sudo is not None: - config["sudo"] = sudo - if log_progress is not None: - config["logProgress"] = log_progress - if username is not None: - config["username"] = username - if concurrency is not None: - config["concurrency"] = concurrency - if job_id is not None: - 
config["jobId"] = job_id - if seed_property is not None: - config["seedProperty"] = seed_property - if consecutive_ids is not None: - config["consecutiveIds"] = consecutive_ids - if relationship_weight_property is not None: - config["relationshipWeightProperty"] = relationship_weight_property diff --git a/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py b/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py index bb46e7be5..519825390 100644 --- a/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py +++ b/graphdatascience/tests/unit/procedure_surface/cypher/test_wcc_cypher_endpoints.py @@ -3,7 +3,7 @@ from graphdatascience.graph.graph_object import Graph from graphdatascience.procedure_surface.api.wcc_endpoints import WccMutateResult, WccStatsResult, WccWriteResult -from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints +from graphdatascience.procedure_surface.cypher.wcc_cypher_endpoints import WccCypherEndpoints from graphdatascience.tests.unit.conftest import DEFAULT_SERVER_VERSION, CollectingQueryRunner diff --git a/graphdatascience/tests/unit/procedure_surface/test_config_converter.py b/graphdatascience/tests/unit/procedure_surface/test_config_converter.py new file mode 100644 index 000000000..a01dd86e6 --- /dev/null +++ b/graphdatascience/tests/unit/procedure_surface/test_config_converter.py @@ -0,0 +1,39 @@ +from graphdatascience.procedure_surface.config_converter import ConfigConverter + + +def test_build_configuration_with_no_additional_args() -> None: + config = ConfigConverter.convert_to_gds_config() + assert config == {} + + +def test_build_configuration_with_additional_args() -> None: + config = ConfigConverter.convert_to_gds_config(some_property="value", another_property=42) + assert config["someProperty"] == "value" + assert config["anotherProperty"] == 42 + + +def test_build_configuration_ignores_none_values() -> None: + config = ConfigConverter.convert_to_gds_config(included_property="present", excluded_property=None) + assert "includedProperty" in config + assert "excludedProperty" not in config + assert config["includedProperty"] == "present" + + +def test_build_configuration_with_nested_dict() -> None: + config = ConfigConverter.convert_to_gds_config(foo_bar={"bar_baz": 42, "another_key": "value"}) + assert "fooBar" in config + assert isinstance(config["fooBar"], dict) + assert "barBaz" in config["fooBar"] + assert config["fooBar"]["barBaz"] == 42 + assert "anotherKey" in config["fooBar"] + assert config["fooBar"]["anotherKey"] == "value" + + +def test_build_configuration_with_deeply_nested_dict() -> None: + config = ConfigConverter.convert_to_gds_config(level_one={"level_two": {"level_three": 42}}) + assert "levelOne" in config + assert isinstance(config["levelOne"], dict) + assert "levelTwo" in config["levelOne"] + assert isinstance(config["levelOne"]["levelTwo"], dict) + assert "levelThree" in config["levelOne"]["levelTwo"] + assert config["levelOne"]["levelTwo"]["levelThree"] == 42 From ae2f8cbf73930468cd980eb7cb1268680518d89b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Kie=C3=9Fling?= Date: Tue, 1 Jul 2025 17:47:07 +0200 Subject: [PATCH 12/12] Add feature flag to enabled explicit APIs --- graphdatascience/graph_data_science.py | 17 +- .../procedure_surface/api/wcc_endpoints.py | 9 + .../session/aura_graph_data_science.py | 182 ++++++++++++------ .../session/dedicated_sessions.py | 6 +- .../tests/integration/conftest.py | 11 +- 
.../integration/test_remote_graph_ops.py | 4 +- .../tests/integration/test_simple_algo.py | 2 +- graphdatascience/tests/unit/conftest.py | 6 +- 8 files changed, 158 insertions(+), 79 deletions(-) diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py index 6055f0d97..523cf6def 100644 --- a/graphdatascience/graph_data_science.py +++ b/graphdatascience/graph_data_science.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import warnings from types import TracebackType from typing import Any, Optional, Type, Union @@ -10,12 +11,13 @@ from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication from graphdatascience.arrow_client.arrow_info import ArrowInfo -from graphdatascience.procedure_surface.cypher.wcc_proc_runner import WccCypherEndpoints +from graphdatascience.procedure_surface.cypher.wcc_cypher_endpoints import WccCypherEndpoints from .call_builder import IndirectCallBuilder from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints from .error.uncallable_namespace import UncallableNamespace from .graph.graph_proc_runner import GraphProcRunner +from .procedure_surface.api.wcc_endpoints import WccEndpoints from .query_runner.arrow_query_runner import ArrowQueryRunner from .query_runner.neo4j_query_runner import Neo4jQueryRunner from .query_runner.query_runner import QueryRunner @@ -118,15 +120,20 @@ def __init__( self._query_runner.set_show_progress(show_progress) super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) - self._wcc_endpoints = WccCypherEndpoints(self._query_runner) + self._wcc_endpoints: Optional[WccEndpoints] = None + if os.environ.get("ENABLE_EXPLICIT_ENDPOINTS") is not None: + self._wcc_endpoints = WccCypherEndpoints(self._query_runner) @property def graph(self) -> GraphProcRunner: return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version) - # @property - # def wcc(self) -> WccEndpoints: - # return self._wcc_endpoints + @property + def wcc(self) -> Union[WccEndpoints, IndirectCallBuilder]: + if self._wcc_endpoints is None: + return IndirectCallBuilder(self._query_runner, f"{self._namespace}.wcc", self._server_version) + + return self._wcc_endpoints @property def util(self) -> UtilProcRunner: diff --git a/graphdatascience/procedure_surface/api/wcc_endpoints.py b/graphdatascience/procedure_surface/api/wcc_endpoints.py index 5c32b822b..2096e653f 100644 --- a/graphdatascience/procedure_surface/api/wcc_endpoints.py +++ b/graphdatascience/procedure_surface/api/wcc_endpoints.py @@ -253,6 +253,9 @@ class WccMutateResult: node_properties_written: int configuration: dict[str, Any] + def __getitem__(self, item: str) -> Any: + return getattr(self, item) + @dataclass(frozen=True, repr=True) class WccStatsResult: @@ -263,6 +266,9 @@ class WccStatsResult: post_processing_millis: int configuration: dict[str, Any] + def __getitem__(self, item: str) -> Any: + return getattr(self, item) + @dataclass(frozen=True, repr=True) class WccWriteResult: @@ -274,3 +280,6 @@ class WccWriteResult: post_processing_millis: int node_properties_written: int configuration: dict[str, Any] + + def __getitem__(self, item: str) -> Any: + return getattr(self, item) diff --git a/graphdatascience/session/aura_graph_data_science.py b/graphdatascience/session/aura_graph_data_science.py index f35690d87..c30b7ffdd 100644 --- a/graphdatascience/session/aura_graph_data_science.py +++ b/graphdatascience/session/aura_graph_data_science.py @@ -1,5 +1,6 
@@ from __future__ import annotations +import os from typing import Any, Callable, Optional, Union from pandas import DataFrame @@ -7,6 +8,8 @@ from graphdatascience import QueryRunner, ServerVersion from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication from graphdatascience.arrow_client.arrow_info import ArrowInfo +from graphdatascience.arrow_client.authenticated_arrow_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.write_back_client import WriteBackClient from graphdatascience.call_builder import IndirectCallBuilder from graphdatascience.endpoints import ( AlphaRemoteEndpoints, @@ -15,6 +18,8 @@ ) from graphdatascience.error.uncallable_namespace import UncallableNamespace from graphdatascience.graph.graph_remote_proc_runner import GraphRemoteProcRunner +from graphdatascience.procedure_surface.api.wcc_endpoints import WccEndpoints +from graphdatascience.procedure_surface.arrow.wcc_arrow_endpoints import WccArrowEndpoints from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner @@ -24,15 +29,11 @@ from graphdatascience.utils.util_remote_proc_runner import UtilRemoteProcRunner -class AuraGraphDataScience(DirectEndpoints, UncallableNamespace): - """ - Primary API class for interacting with Neo4j database + Graph Data Science Session. - Always bind this object to a variable called `gds`. - """ +class AuraGraphDataScienceFactory: + """Factory class for creating AuraGraphDataScience instances with all required components.""" - @classmethod - def create( - cls, + def __init__( + self, session_bolt_connection_info: DbmsConnectionInfo, arrow_authentication: Optional[ArrowAuthentication], db_endpoint: Optional[Union[Neo4jQueryRunner, DbmsConnectionInfo]], @@ -41,74 +42,114 @@ def create( arrow_tls_root_certs: Optional[bytes] = None, bookmarks: Optional[Any] = None, show_progress: bool = True, - ) -> AuraGraphDataScience: - session_bolt_query_runner = Neo4jQueryRunner.create_for_session( - endpoint=session_bolt_connection_info.uri, - auth=session_bolt_connection_info.get_auth(), - show_progress=show_progress, + ): + self.session_bolt_connection_info = session_bolt_connection_info + self.arrow_authentication = arrow_authentication + self.db_endpoint = db_endpoint + self.delete_fn = delete_fn + self.arrow_disable_server_verification = arrow_disable_server_verification + self.arrow_tls_root_certs = arrow_tls_root_certs + self.bookmarks = bookmarks + self.show_progress = show_progress + + def create(self) -> AuraGraphDataScience: + """Create and configure an AuraGraphDataScience instance.""" + session_bolt_query_runner = self._create_session_bolt_query_runner() + arrow_info = ArrowInfo.create(session_bolt_query_runner) + session_arrow_query_runner = self._create_session_arrow_query_runner(session_bolt_query_runner, arrow_info) + session_arrow_client = self._create_session_arrow_client(arrow_info, session_bolt_query_runner) + gds_version = session_bolt_query_runner.server_version() + + session_query_runner: QueryRunner + + if self.db_endpoint is not None: + db_bolt_query_runner = self._create_db_bolt_query_runner() + session_query_runner = SessionQueryRunner.create( + session_arrow_query_runner, db_bolt_query_runner, session_arrow_client, self.show_progress + ) + wcc_endpoints = self._create_wcc_endpoints(arrow_info, session_bolt_query_runner, db_bolt_query_runner) + else: + 
session_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner) + wcc_endpoints = self._create_wcc_endpoints(arrow_info, session_bolt_query_runner, None) + + return AuraGraphDataScience( + query_runner=session_query_runner, + wcc_endpoints=wcc_endpoints, + delete_fn=self.delete_fn, + gds_version=gds_version, ) - arrow_info = ArrowInfo.create(session_bolt_query_runner) - session_arrow_query_runner = ArrowQueryRunner.create( + def _create_session_bolt_query_runner(self) -> Neo4jQueryRunner: + return Neo4jQueryRunner.create_for_session( + endpoint=self.session_bolt_connection_info.uri, + auth=self.session_bolt_connection_info.get_auth(), + show_progress=self.show_progress, + ) + + def _create_session_arrow_query_runner( + self, session_bolt_query_runner: Neo4jQueryRunner, arrow_info: ArrowInfo + ) -> ArrowQueryRunner: + return ArrowQueryRunner.create( fallback_query_runner=session_bolt_query_runner, arrow_info=arrow_info, - arrow_authentication=arrow_authentication, + arrow_authentication=self.arrow_authentication, encrypted=session_bolt_query_runner.encrypted(), - disable_server_verification=arrow_disable_server_verification, - tls_root_certs=arrow_tls_root_certs, + disable_server_verification=self.arrow_disable_server_verification, + tls_root_certs=self.arrow_tls_root_certs, ) - # TODO: merge with the gds_arrow_client created inside ArrowQueryRunner - session_arrow_client = GdsArrowClient.create( + def _create_session_arrow_client( + self, arrow_info: ArrowInfo, session_bolt_query_runner: Neo4jQueryRunner + ) -> GdsArrowClient: + return GdsArrowClient.create( arrow_info, - arrow_authentication, + self.arrow_authentication, session_bolt_query_runner.encrypted(), - arrow_disable_server_verification, - arrow_tls_root_certs, + self.arrow_disable_server_verification, + self.arrow_tls_root_certs, ) - gds_version = session_bolt_query_runner.server_version() - - if db_endpoint is not None: - if isinstance(db_endpoint, Neo4jQueryRunner): - db_bolt_query_runner = db_endpoint - else: - db_bolt_query_runner = Neo4jQueryRunner.create_for_db( - db_endpoint.uri, - db_endpoint.get_auth(), - aura_ds=True, - show_progress=False, - database=db_endpoint.database, - ) - db_bolt_query_runner.set_bookmarks(bookmarks) - - session_query_runner = SessionQueryRunner.create( - session_arrow_query_runner, db_bolt_query_runner, session_arrow_client, show_progress - ) - return cls( - query_runner=session_query_runner, - delete_fn=delete_fn, - gds_version=gds_version, + def _create_db_bolt_query_runner(self) -> Neo4jQueryRunner: + if isinstance(self.db_endpoint, Neo4jQueryRunner): + db_bolt_query_runner = self.db_endpoint + elif isinstance(self.db_endpoint, DbmsConnectionInfo): + db_bolt_query_runner = Neo4jQueryRunner.create_for_db( + self.db_endpoint.uri, + self.db_endpoint.get_auth(), + aura_ds=True, + show_progress=False, + database=self.db_endpoint.database, ) else: - standalone_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner) - return cls( - query_runner=standalone_query_runner, - delete_fn=delete_fn, - gds_version=gds_version, + raise ValueError("db_endpoint must be a Neo4jQueryRunner or a DbmsConnectionInfo") + + db_bolt_query_runner.set_bookmarks(self.bookmarks) + return db_bolt_query_runner + + def _create_wcc_endpoints( + self, arrow_info: ArrowInfo, session_bolt_query_runner: Neo4jQueryRunner, db_query_runner: Optional[QueryRunner] + ) -> Optional[WccEndpoints]: + wcc_endpoints: Optional[WccEndpoints] = None + if os.environ.get("ENABLE_EXPLICIT_ENDPOINTS") is not None: + 
arrow_client = AuthenticatedArrowClient.create( + arrow_info, + self.arrow_authentication, + session_bolt_query_runner.encrypted(), + self.arrow_disable_server_verification, + self.arrow_tls_root_certs, ) - def __init__( - self, - query_runner: QueryRunner, - delete_fn: Callable[[], bool], - gds_version: ServerVersion, - ): - self._query_runner = query_runner - self._delete_fn = delete_fn - self._server_version = gds_version + write_back_client = WriteBackClient(arrow_client, db_query_runner) if db_query_runner is not None else None - super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) + wcc_endpoints = WccArrowEndpoints(arrow_client, write_back_client) + return wcc_endpoints + + +class AuraGraphDataScience(DirectEndpoints, UncallableNamespace): + """ + Primary API class for interacting with Neo4j database + Graph Data Science Session. + Always bind this object to a variable called `gds`. + """ def run_cypher( self, @@ -133,6 +174,20 @@ def run_cypher( """ return self._query_runner.run_cypher(query, params, database, False) + def __init__( + self, + query_runner: QueryRunner, + delete_fn: Callable[[], bool], + gds_version: ServerVersion, + wcc_endpoints: Optional[WccEndpoints] = None, + ): + self._query_runner = query_runner + self._delete_fn = delete_fn + self._server_version = gds_version + self._wcc_endpoints = wcc_endpoints + + super().__init__(self._query_runner, namespace="gds", server_version=self._server_version) + @property def graph(self) -> GraphRemoteProcRunner: return GraphRemoteProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version) @@ -149,6 +204,13 @@ def alpha(self) -> AlphaRemoteEndpoints: def beta(self) -> BetaEndpoints: return BetaEndpoints(self._query_runner, "gds.beta", self._server_version) + @property + def wcc(self) -> Union[WccEndpoints, IndirectCallBuilder]: + if self._wcc_endpoints is None: + return IndirectCallBuilder(self._query_runner, f"{self._namespace}.wcc", self._server_version) + + return self._wcc_endpoints + def __getattr__(self, attr: str) -> IndirectCallBuilder: return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version) diff --git a/graphdatascience/session/dedicated_sessions.py b/graphdatascience/session/dedicated_sessions.py index f39a133c8..b95399c2c 100644 --- a/graphdatascience/session/dedicated_sessions.py +++ b/graphdatascience/session/dedicated_sessions.py @@ -11,7 +11,7 @@ from graphdatascience.session.aura_api import AuraApi from graphdatascience.session.aura_api_responses import SessionDetails from graphdatascience.session.aura_api_token_authentication import AuraApiTokenAuthentication -from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience +from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience, AuraGraphDataScienceFactory from graphdatascience.session.cloud_location import CloudLocation from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo from graphdatascience.session.session_info import SessionInfo @@ -210,9 +210,9 @@ def _construct_client( arrow_authentication: ArrowAuthentication, db_runner: Optional[Neo4jQueryRunner], ) -> AuraGraphDataScience: - return AuraGraphDataScience.create( + return AuraGraphDataScienceFactory( session_bolt_connection_info=session_bolt_connection_info, arrow_authentication=arrow_authentication, db_endpoint=db_runner, delete_fn=lambda: self._aura_api.delete_session(session_id=session_id), - ) + ).create() diff --git 
a/graphdatascience/tests/integration/conftest.py b/graphdatascience/tests/integration/conftest.py index 439e1eea0..ab4ad6822 100644 --- a/graphdatascience/tests/integration/conftest.py +++ b/graphdatascience/tests/integration/conftest.py @@ -10,7 +10,7 @@ from graphdatascience.graph_data_science import GraphDataScience from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner from graphdatascience.server_version.server_version import ServerVersion -from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience +from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience, AuraGraphDataScienceFactory from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687") @@ -92,12 +92,13 @@ def gds_without_arrow() -> Generator[GraphDataScience, None, None]: @pytest.fixture(scope="package", autouse=False) def gds_with_cloud_setup(request: pytest.FixtureRequest) -> Generator[AuraGraphDataScience, None, None]: - _gds = AuraGraphDataScience.create( + _gds = AuraGraphDataScienceFactory( session_bolt_connection_info=DbmsConnectionInfo(URI, AUTH[0], AUTH[1]), arrow_authentication=UsernamePasswordAuthentication(AUTH[0], AUTH[1]), db_endpoint=DbmsConnectionInfo(AURA_DB_URI, AURA_DB_AUTH[0], AURA_DB_AUTH[1]), delete_fn=lambda: True, - ) + ).create() + _gds.set_database(DB) yield _gds @@ -107,12 +108,12 @@ def gds_with_cloud_setup(request: pytest.FixtureRequest) -> Generator[AuraGraphD @pytest.fixture(scope="package", autouse=False) def standalone_aura_gds() -> Generator[AuraGraphDataScience, None, None]: - _gds = AuraGraphDataScience.create( + _gds = AuraGraphDataScienceFactory( session_bolt_connection_info=DbmsConnectionInfo(URI, AUTH[0], AUTH[1]), arrow_authentication=UsernamePasswordAuthentication(AUTH[0], AUTH[1]), db_endpoint=None, delete_fn=lambda: True, - ) + ).create() yield _gds diff --git a/graphdatascience/tests/integration/test_remote_graph_ops.py b/graphdatascience/tests/integration/test_remote_graph_ops.py index 7eba37c64..ad5ea078f 100644 --- a/graphdatascience/tests/integration/test_remote_graph_ops.py +++ b/graphdatascience/tests/integration/test_remote_graph_ops.py @@ -62,7 +62,7 @@ def test_remote_projection_and_writeback_custom_database_name(gds_with_cloud_set assert projection_result["nodeCount"] == 2 assert projection_result["relationshipCount"] == 1 - write_result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc") + write_result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc") # type: ignore assert write_result["nodePropertiesWritten"] == 2 count_wcc_nodes_query = "MATCH (n WHERE n.wcc IS NOT NULL) RETURN count(*) AS c" @@ -234,6 +234,6 @@ def test_empty_graph_write_back( assert G.node_count() == 0 - result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc") + result = gds_with_cloud_setup.wcc.write(G, writeProperty="wcc") # type: ignore assert result["nodePropertiesWritten"] == 0 diff --git a/graphdatascience/tests/integration/test_simple_algo.py b/graphdatascience/tests/integration/test_simple_algo.py index c7fbbba1d..b9fd0297d 100644 --- a/graphdatascience/tests/integration/test_simple_algo.py +++ b/graphdatascience/tests/integration/test_simple_algo.py @@ -57,7 +57,7 @@ def test_wcc_stats(gds: GraphDataScience) -> None: def test_wcc_stats_estimate(gds: GraphDataScience) -> None: G, _ = gds.graph.project(GRAPH_NAME, "*", "*") - result = gds.wcc.stats.estimate(G) + result = gds.wcc.stats.estimate(G) # type: ignore assert 
result["requiredMemory"] diff --git a/graphdatascience/tests/unit/conftest.py b/graphdatascience/tests/unit/conftest.py index ca88ff228..01bd31301 100644 --- a/graphdatascience/tests/unit/conftest.py +++ b/graphdatascience/tests/unit/conftest.py @@ -18,7 +18,7 @@ from graphdatascience.query_runner.graph_constructor import GraphConstructor from graphdatascience.query_runner.query_mode import QueryMode from graphdatascience.server_version.server_version import ServerVersion -from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience +from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience, AuraGraphDataScienceFactory from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo # Should mirror the latest GDS server version under development. @@ -194,12 +194,12 @@ def aura_gds(runner: CollectingQueryRunner, mocker: MockerFixture) -> Generator[ mocker.patch("graphdatascience.query_runner.arrow_query_runner.ArrowQueryRunner.create", return_value=runner) mocker.patch("graphdatascience.query_runner.gds_arrow_client.GdsArrowClient.create", return_value=None) - aura_gds = AuraGraphDataScience.create( + aura_gds = AuraGraphDataScienceFactory( session_bolt_connection_info=DbmsConnectionInfo("address", "some", "auth"), arrow_authentication=UsernamePasswordAuthentication("some", "auth"), db_endpoint=DbmsConnectionInfo("address", "some", "auth"), delete_fn=lambda: True, - ) + ).create() yield aura_gds aura_gds.close()