Skip to content

Commit d54e329

Browse files
committed
Decouple embedding initializer
1 parent b9edfe5 commit d54e329

File tree

5 files changed

+75
-51
lines changed

5 files changed

+75
-51
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.node2vec;
21+
22+
import org.neo4j.gds.utils.StringJoining;
23+
24+
import java.util.Arrays;
25+
import java.util.List;
26+
import java.util.stream.Collectors;
27+
28+
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
29+
import static org.neo4j.gds.utils.StringFormatting.toUpperCaseWithLocale;
30+
31+
public enum EmbeddingInitializer {
32+
UNIFORM,
33+
NORMALIZED;
34+
35+
private static final List<String> VALUES = Arrays
36+
.stream(EmbeddingInitializer.values())
37+
.map(EmbeddingInitializer::name)
38+
.collect(Collectors.toList());
39+
40+
public static EmbeddingInitializer parse(Object input) {
41+
if (input instanceof String) {
42+
var inputString = toUpperCaseWithLocale((String) input);
43+
44+
if (!VALUES.contains(inputString)) {
45+
throw new IllegalArgumentException(formatWithLocale(
46+
"EmbeddingInitializer `%s` is not supported. Must be one of: %s.",
47+
input,
48+
StringJoining.join(VALUES)
49+
));
50+
}
51+
52+
return valueOf(toUpperCaseWithLocale(inputString));
53+
} else if (input instanceof EmbeddingInitializer) {
54+
return (EmbeddingInitializer) input;
55+
}
56+
57+
throw new IllegalArgumentException(formatWithLocale(
58+
"Expected EmbeddingInitializer or String. Got %s.",
59+
input.getClass().getSimpleName()
60+
));
61+
}
62+
63+
public static String toString(EmbeddingInitializer embeddingInitializer) {
64+
return embeddingInitializer.toString();
65+
}
66+
}

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public class Node2Vec extends Algorithm<Node2VecModel.Result> {
4747
private final List<Long> sourceNodes;
4848
private final Optional<Long> maybeRandomSeed;
4949
private final TrainParameters trainParameters;
50-
private final Node2VecBaseConfig.EmbeddingInitializer embeddingInitializer;
50+
private final EmbeddingInitializer embeddingInitializer;
5151

5252

5353
public static MemoryEstimation memoryEstimation(int walksPerNode, int walkLength, int embeddingDimension) {
@@ -83,7 +83,7 @@ public Node2Vec(
8383
Optional<Long> maybeRandomSeed,
8484
ProgressTracker progressTracker,
8585
TrainParameters trainParameters,
86-
Node2VecBaseConfig.EmbeddingInitializer embeddingInitializer
86+
EmbeddingInitializer embeddingInitializer
8787
) {
8888
super(progressTracker);
8989
this.graph = graph;

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecBaseConfig.java

Lines changed: 2 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,53 +24,12 @@
2424
import org.neo4j.gds.config.AlgoBaseConfig;
2525
import org.neo4j.gds.config.EmbeddingDimensionConfig;
2626
import org.neo4j.gds.traversal.RandomWalkBaseConfig;
27-
import org.neo4j.gds.utils.StringJoining;
2827

29-
import java.util.Arrays;
3028
import java.util.List;
31-
import java.util.stream.Collectors;
3229

33-
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
34-
import static org.neo4j.gds.utils.StringFormatting.toUpperCaseWithLocale;
3530

3631
public interface Node2VecBaseConfig extends AlgoBaseConfig, EmbeddingDimensionConfig, RandomWalkBaseConfig {
3732

38-
enum EmbeddingInitializer {
39-
UNIFORM,
40-
NORMALIZED;
41-
42-
private static final List<String> VALUES = Arrays
43-
.stream(Node2VecBaseConfig.EmbeddingInitializer.values())
44-
.map(Node2VecBaseConfig.EmbeddingInitializer::name)
45-
.collect(Collectors.toList());
46-
public static Node2VecBaseConfig.EmbeddingInitializer parse(Object input) {
47-
if (input instanceof String) {
48-
var inputString = toUpperCaseWithLocale((String) input);
49-
50-
if (!VALUES.contains(inputString)) {
51-
throw new IllegalArgumentException(formatWithLocale(
52-
"EmbeddingInitializer `%s` is not supported. Must be one of: %s.",
53-
input,
54-
StringJoining.join(VALUES)
55-
));
56-
}
57-
58-
return valueOf(toUpperCaseWithLocale(inputString));
59-
} else if (input instanceof Node2VecBaseConfig.EmbeddingInitializer) {
60-
return (Node2VecBaseConfig.EmbeddingInitializer) input;
61-
}
62-
63-
throw new IllegalArgumentException(formatWithLocale(
64-
"Expected EmbeddingInitializer or String. Got %s.",
65-
input.getClass().getSimpleName()
66-
));
67-
}
68-
69-
public static String toString(Node2VecBaseConfig.EmbeddingInitializer embeddingInitializer) {
70-
return embeddingInitializer.toString();
71-
}
72-
}
73-
7433
@Value.Default
7534
@Configuration.IntegerRange(min = 2)
7635
default int windowSize() {
@@ -103,8 +62,8 @@ default int embeddingDimension() {
10362
}
10463

10564
@Value.Default
106-
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig.EmbeddingInitializer#parse", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
107-
@Configuration.ToMapValue("org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig.EmbeddingInitializer#toString")
65+
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.node2vec.EmbeddingInitializer#parse", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
66+
@Configuration.ToMapValue("org.neo4j.gds.embeddings.node2vec.EmbeddingInitializer#toString")
10867
default EmbeddingInitializer embeddingInitializer() {
10968
return EmbeddingInitializer.NORMALIZED;
11069
}

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public class Node2VecModel {
5454
private final int embeddingDimension;
5555
private final int windowSize;
5656
private final int negativeSamplingRate;
57-
private final Node2VecBaseConfig.EmbeddingInitializer embeddingInitializer;
57+
private final EmbeddingInitializer embeddingInitializer;
5858
private final int concurrency;
5959
private final CompressedRandomWalks walks;
6060
private final RandomWalkProbabilities randomWalkProbabilities;
@@ -80,7 +80,7 @@ public static MemoryEstimation memoryEstimation(int embeddingDimension) {
8080
LongUnaryOperator toOriginalId,
8181
long nodeCount,
8282
TrainParameters trainParameters,
83-
Node2VecBaseConfig.EmbeddingInitializer embeddingInitializer,
83+
EmbeddingInitializer embeddingInitializer,
8484
int concurrency,
8585
Optional<Long> maybeRandomSeed,
8686
CompressedRandomWalks walks,
@@ -114,7 +114,7 @@ public static MemoryEstimation memoryEstimation(int embeddingDimension) {
114114
int windowSize,
115115
int negativeSamplingRate,
116116
int embeddingDimension,
117-
Node2VecBaseConfig.EmbeddingInitializer embeddingInitializer,
117+
EmbeddingInitializer embeddingInitializer,
118118
int concurrency,
119119
Optional<Long> maybeRandomSeed,
120120
CompressedRandomWalks walks,

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
import org.neo4j.gds.StoreLoaderBuilder;
4141
import org.neo4j.gds.TestProgressTracker;
4242
import org.neo4j.gds.api.Graph;
43+
import org.neo4j.gds.collections.ha.HugeLongArray;
44+
import org.neo4j.gds.collections.ha.HugeObjectArray;
4345
import org.neo4j.gds.collections.hsa.HugeSparseLongArray;
4446
import org.neo4j.gds.compat.Neo4jProxy;
4547
import org.neo4j.gds.compat.TestLog;
@@ -50,12 +52,9 @@
5052
import org.neo4j.gds.core.loading.construction.GraphFactory;
5153
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
5254
import org.neo4j.gds.core.utils.Intersections;
53-
import org.neo4j.gds.collections.ha.HugeLongArray;
54-
import org.neo4j.gds.collections.ha.HugeObjectArray;
5555
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
5656
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
5757
import org.neo4j.gds.core.utils.shuffle.ShuffleUtil;
58-
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig.EmbeddingInitializer;
5958
import org.neo4j.gds.gdl.GdlFactory;
6059
import org.neo4j.gds.ml.core.tensor.FloatVector;
6160

0 commit comments

Comments
 (0)