diff --git a/changelog.md b/changelog.md
index a6815a7ed..4d5adb698 100644
--- a/changelog.md
+++ b/changelog.md
@@ -6,6 +6,8 @@
 
 ## New features
 
+* Added the ability to upload additional node properties via the GdsArrowClient
+
 ## Bug fixes
 
diff --git a/graphdatascience/query_runner/gds_arrow_client.py b/graphdatascience/query_runner/gds_arrow_client.py
index f65a87fcf..323fca37a 100644
--- a/graphdatascience/query_runner/gds_arrow_client.py
+++ b/graphdatascience/query_runner/gds_arrow_client.py
@@ -605,6 +605,87 @@ def upload_triplets(
         """
         self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback)
 
+    def put_node_properties(
+        self,
+        graph_name: str,
+        database: str,
+        node_labels: Optional[Union[str, list[str]]] = None,
+        consecutive_ids: bool = False,
+        concurrency: Optional[int] = None,
+    ) -> None:
+        """
+        Starts a new node properties upload process on the GDS server.
+
+        Parameters
+        ----------
+        graph_name : str
+            The name of the graph
+        database : str
+            The name of the database to which the graph belongs
+        node_labels : Optional[Union[str, list[str]]]
+            The names of the node labels for which to upload properties (default is None)
+        consecutive_ids : bool
+            Whether the node IDs in the input data are consecutive (default is False)
+        concurrency : Optional[int]
+            The number of threads used on the server side when uploading the properties
+        """
+        config: dict[str, Any] = {
+            "name": graph_name,
+            "database_name": database,
+            "consecutive_ids": consecutive_ids,
+        }
+
+        if concurrency:
+            config["concurrency"] = concurrency
+        if node_labels is not None:
+            if isinstance(node_labels, str):
+                config["node_labels"] = [node_labels]
+            else:
+                config["node_labels"] = node_labels
+
+        self._send_action("PUT_NODE_PROPERTIES", config)
+
+    def upload_node_properties(
+        self,
+        graph_name: str,
+        node_data: Union[pyarrow.Table, Iterable[pyarrow.RecordBatch], pandas.DataFrame],
+        batch_size: int = 10_000,
+        progress_callback: Callable[[int], None] = lambda x: None,
+    ) -> None:
+        """
+        Uploads node property data to the server.
+
+        Parameters
+        ----------
+        graph_name : str
+            The name of the graph
+        node_data : Union[pyarrow.Table, Iterable[pyarrow.RecordBatch], DataFrame]
+            The node property data to upload
+        batch_size : int
+            The number of rows per batch
+        progress_callback : Callable[[int], None]
+            A callback function that is called with the number of rows uploaded after each batch
+        """
+        self._upload_data(graph_name, "node_properties", node_data, batch_size, progress_callback)
+
+    def put_node_properties_done(self, graph_name: str) -> NodePropertiesLoadDoneResult:
+        """
+        Notifies the server that all node property data has been sent.
+
+        Parameters
+        ----------
+        graph_name : str
+            The name of the graph
+
+        Returns
+        -------
+        NodePropertiesLoadDoneResult
+            A result object containing the name of the graph and the number of nodes for which properties were loaded
+        """
+        return NodePropertiesLoadDoneResult.from_json(
+            self._send_action("PUT_NODE_PROPERTIES_DONE", {"name": graph_name})
+        )
+
     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
         # Remove the FlightClient as it isn't serializable
@@ -963,3 +1044,16 @@ class TripletLoadDoneResult:
     @classmethod
     def from_json(cls, json: dict[str, Any]) -> TripletLoadDoneResult:
         return cls(name=json["name"], node_count=json["node_count"], relationship_count=json["relationship_count"])
+
+
+@dataclass(repr=True, frozen=True)
+class NodePropertiesLoadDoneResult:
+    name: str
+    node_count: int
+
+    @classmethod
+    def from_json(cls, json: dict[str, Any]) -> NodePropertiesLoadDoneResult:
+        return cls(
+            name=json["name"],
+            node_count=json["node_count"],
+        )
diff --git a/graphdatascience/tests/unit/test_gds_arrow_client.py b/graphdatascience/tests/unit/test_gds_arrow_client.py
index bc476dc49..cb2dd2cb5 100644
--- a/graphdatascience/tests/unit/test_gds_arrow_client.py
+++ b/graphdatascience/tests/unit/test_gds_arrow_client.py
@@ -51,6 +51,10 @@ def do_action(self, context: Any, action: ActionParam) -> list[bytes]:
             response = {"name": "g", "relationship_count": 42}
         elif "TRIPLET_LOAD_DONE" in actionType:
             response = {"name": "g", "node_count": 42, "relationship_count": 1337}
+        elif "PUT_NODE_PROPERTIES_DONE" in actionType:
+            response = {"name": "g", "node_count": 42}
+        elif "PUT_NODE_PROPERTIES" in actionType:
+            response = {"name": "g"}
         else:
             response = {}
         return [json.dumps(response).encode("utf-8")]
@@ -102,6 +106,10 @@ def do_action(self, context: Any, action: ActionParam) -> list[bytes]:
             response = {"name": "g", "relationship_count": 42}
         elif "TRIPLET_LOAD_DONE" in actionType:
             response = {"name": "g", "node_count": 42, "relationship_count": 1337}
+        elif "PUT_NODE_PROPERTIES" == actionType:
+            response = {"name": "g"}
+        elif "PUT_NODE_PROPERTIES_DONE" == actionType:
+            response = {"name": "g", "node_count": 42}
         else:
             response = {}
         return [json.dumps(response).encode("utf-8")]
@@ -258,6 +266,64 @@ def test_triplet_load_done_action(flight_server: FlightServer, flight_client: Gd
     assert_action(actions[0], "v1/TRIPLET_LOAD_DONE", {"name": "g"})
 
 
+def test_put_node_properties_with_defaults(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
+    flight_client.put_node_properties("g", "DB")
+    actions = flight_server._actions
+    assert len(actions) == 1
+    assert_action(actions[0], "v1/PUT_NODE_PROPERTIES", {"name": "g", "database_name": "DB", "consecutive_ids": False})
+
+
+def test_put_node_properties_with_single_label(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
+    flight_client.put_node_properties("g", "DB", "Label1")
+    actions = flight_server._actions
+    assert len(actions) == 1
+    assert_action(
+        actions[0],
+        "v1/PUT_NODE_PROPERTIES",
+        {"name": "g", "database_name": "DB", "consecutive_ids": False, "node_labels": ["Label1"]},
+    )
+
+
+def test_put_node_properties_with_options(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
+    flight_client.put_node_properties("g", "DB", ["Label1", "Label2"], consecutive_ids=True, concurrency=42)
+    actions = flight_server._actions
+    assert len(actions) == 1
+    assert_action(
+        actions[0],
+        "v1/PUT_NODE_PROPERTIES",
+        {
+            "name": "g",
+            "database_name": "DB",
+            "consecutive_ids": True,
+            "concurrency": 42,
+ "node_labels": ["Label1", "Label2"], + }, + ) + + +def test_put_node_properties_with_flaky_server( + flaky_flight_server: FlakyFlightServer, flaky_flight_client: GdsArrowClient +) -> None: + flaky_flight_client.put_node_properties("g", "DB", "Label1") + actions = flaky_flight_server._actions + assert len(actions) == flaky_flight_server.expected_retries() + assert_action( + actions[0], + "v1/PUT_NODE_PROPERTIES", + {"name": "g", "database_name": "DB", "consecutive_ids": False, "node_labels": ["Label1"]}, + ) + + +def test_put_node_properties_done(flight_server: FlightServer, flight_client: GdsArrowClient) -> None: + response = flight_client.put_node_properties_done("g") + assert response.name == "g" + assert response.node_count == 42 + + actions = flight_server._actions + assert len(actions) == 1 + assert_action(actions[0], "v1/PUT_NODE_PROPERTIES_DONE", {"name": "g"}) + + def test_abort_action(flight_server: FlightServer, flight_client: GdsArrowClient) -> None: flight_client.abort("g") actions = flight_server._actions