Skip to content

Commit 84b8cef

Browse files
committed
Parse remote ML configs
1 parent 8a681b4 commit 84b8cef

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

graphdatascience/gnn/gnn_nc_runner.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def train(
1212
model_name: str,
1313
feature_properties: List[str],
1414
target_property: str,
15+
relationship_types: List[str],
1516
target_node_label: str = None,
1617
node_labels: List[str] = None,
1718
) -> "Series[Any]": # noqa: F821
@@ -20,6 +21,7 @@ def train(
2021
"targetProperty": target_property,
2122
"job_type": "train",
2223
"nodeProperties": feature_properties + [target_property],
24+
"relationshipTypes": relationship_types
2325
}
2426

2527
if target_node_label:
@@ -31,10 +33,9 @@ def train(
3133

3234
# token and uri will be injected by arrow_query_runner
3335
self._query_runner.run_query(
34-
"CALL gds.upload.graph($graph_name, $config)",
36+
"CALL gds.upload.graph($config)",
3537
params={
36-
"graph_name": graph_name,
37-
"config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name},
38+
"config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
3839
},
3940
)
4041

@@ -43,13 +44,15 @@ def predict(
4344
graph_name: str,
4445
model_name: str,
4546
feature_properties: List[str],
47+
relationship_types: List[str],
4648
target_node_label: str = None,
4749
node_labels: List[str] = None,
4850
) -> "Series[Any]": # noqa: F821
4951
mlConfigMap = {
5052
"featureProperties": feature_properties,
5153
"job_type": "predict",
5254
"nodeProperties": feature_properties,
55+
"relationshipTypes": relationship_types
5356
}
5457
if target_node_label:
5558
mlConfigMap["targetNodeLabel"] = target_node_label
@@ -58,9 +61,8 @@ def predict(
5861

5962
mlTrainingConfig = json.dumps(mlConfigMap)
6063
self._query_runner.run_query(
61-
"CALL gds.upload.graph($graph_name, $config)",
64+
"CALL gds.upload.graph($config)",
6265
params={
63-
"graph_name": graph_name,
64-
"config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name},
66+
"config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
6567
},
6668
) # type: ignore

0 commit comments

Comments
 (0)