|
6 | 6 |
|
7 | 7 |
|
8 | 8 | class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
|
| 9 | + def make_graph_sage_config(self, graph_sage_config): |
| 10 | + GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5, |
| 11 | + "hidden_channels": 256} |
| 12 | + final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG |
| 13 | + if graph_sage_config: |
| 14 | + bad_keys = [] |
| 15 | + for key in graph_sage_config: |
| 16 | + if key not in GRAPH_SAGE_DEFAULT_CONFIG: |
| 17 | + bad_keys.append(key) |
| 18 | + if len(bad_keys) > 0: |
| 19 | + raise Exception(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.") |
| 20 | + |
| 21 | + final_sage_config.update(graph_sage_config) |
| 22 | + return final_sage_config |
| 23 | + |
9 | 24 | def train(
|
10 |
| - self, |
11 |
| - graph_name: str, |
12 |
| - model_name: str, |
13 |
| - feature_properties: List[str], |
14 |
| - target_property: str, |
15 |
| - relationship_types: List[str], |
16 |
| - target_node_label: str = None, |
17 |
| - node_labels: List[str] = None, |
| 25 | + self, |
| 26 | + graph_name: str, |
| 27 | + model_name: str, |
| 28 | + feature_properties: List[str], |
| 29 | + target_property: str, |
| 30 | + relationship_types: List[str], |
| 31 | + target_node_label: str = None, |
| 32 | + node_labels: List[str] = None, |
| 33 | + graph_sage_config = None |
18 | 34 | ) -> "Series[Any]": # noqa: F821
|
| 35 | + |
19 | 36 | mlConfigMap = {
|
20 | 37 | "featureProperties": feature_properties,
|
21 | 38 | "targetProperty": target_property,
|
22 | 39 | "job_type": "train",
|
23 | 40 | "nodeProperties": feature_properties + [target_property],
|
24 |
| - "relationshipTypes": relationship_types |
| 41 | + "relationshipTypes": relationship_types, |
| 42 | + "graph_sage_config": self.make_graph_sage_config(graph_sage_config) |
25 | 43 | }
|
26 | 44 |
|
27 | 45 | if target_node_label:
|
|
0 commit comments