@@ -12,6 +12,7 @@ def train(
12
12
model_name : str ,
13
13
feature_properties : List [str ],
14
14
target_property : str ,
15
+ relationship_types : List [str ],
15
16
target_node_label : str = None ,
16
17
node_labels : List [str ] = None ,
17
18
) -> "Series[Any]" : # noqa: F821
@@ -20,6 +21,7 @@ def train(
20
21
"targetProperty" : target_property ,
21
22
"job_type" : "train" ,
22
23
"nodeProperties" : feature_properties + [target_property ],
24
+ "relationshipTypes" : relationship_types
23
25
}
24
26
25
27
if target_node_label :
@@ -31,10 +33,9 @@ def train(
31
33
32
34
# token and uri will be injected by arrow_query_runner
33
35
self ._query_runner .run_query (
34
- "CALL gds.upload.graph($graph_name, $ config)" ,
36
+ "CALL gds.upload.graph($config)" ,
35
37
params = {
36
- "graph_name" : graph_name ,
37
- "config" : {"mlTrainingConfig" : mlTrainingConfig , "modelName" : model_name },
38
+ "config" : {"mlTrainingConfig" : mlTrainingConfig , "graphName" : graph_name , "modelName" : model_name },
38
39
},
39
40
)
40
41
@@ -43,13 +44,15 @@ def predict(
43
44
graph_name : str ,
44
45
model_name : str ,
45
46
feature_properties : List [str ],
47
+ relationship_types : List [str ],
46
48
target_node_label : str = None ,
47
49
node_labels : List [str ] = None ,
48
50
) -> "Series[Any]" : # noqa: F821
49
51
mlConfigMap = {
50
52
"featureProperties" : feature_properties ,
51
53
"job_type" : "predict" ,
52
54
"nodeProperties" : feature_properties ,
55
+ "relationshipTypes" : relationship_types
53
56
}
54
57
if target_node_label :
55
58
mlConfigMap ["targetNodeLabel" ] = target_node_label
@@ -58,9 +61,8 @@ def predict(
58
61
59
62
mlTrainingConfig = json .dumps (mlConfigMap )
60
63
self ._query_runner .run_query (
61
- "CALL gds.upload.graph($graph_name, $ config)" ,
64
+ "CALL gds.upload.graph($config)" ,
62
65
params = {
63
- "graph_name" : graph_name ,
64
- "config" : {"mlTrainingConfig" : mlTrainingConfig , "modelName" : model_name },
66
+ "config" : {"mlTrainingConfig" : mlTrainingConfig , "graphName" : graph_name , "modelName" : model_name },
65
67
},
66
68
) # type: ignore
0 commit comments