Skip to content

Commit fd96d72

Browse files
Merge pull request #10552 from IoannisPanagiotas/small-ref-pgr
small ref
2 parents 3dc4fcf + 80e6fd2 commit fd96d72

File tree

5 files changed

+64
-17
lines changed

5 files changed

+64
-17
lines changed

algo/src/main/java/org/neo4j/gds/pagerank/PageRankAlgorithm.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,11 @@
2424
import org.neo4j.gds.api.properties.nodes.NodePropertyValuesAdapter;
2525
import org.neo4j.gds.beta.pregel.Pregel;
2626
import org.neo4j.gds.beta.pregel.PregelComputation;
27-
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
28-
import org.neo4j.gds.termination.TerminationFlag;
2927
import org.neo4j.gds.collections.ha.HugeDoubleArray;
28+
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
3029
import org.neo4j.gds.core.utils.partition.PartitionUtils;
3130
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
32-
import org.neo4j.gds.scaling.L2Norm;
33-
import org.neo4j.gds.scaling.NoneScaler;
31+
import org.neo4j.gds.termination.TerminationFlag;
3432

3533
import java.util.Optional;
3634
import java.util.concurrent.ExecutorService;
@@ -110,8 +108,7 @@ private void scaleScores(HugeDoubleArray scores) {
110108
var scalerFactory = config.scaler();
111109
var concurrency = config.concurrency();
112110

113-
// Eigenvector produces L2NORM-scaled results by default.
114-
if (scalerFactory.type().equals(NoneScaler.TYPE) || (scalerFactory.type().equals(L2Norm.TYPE) && mode == PageRankVariant.EIGENVECTOR)) {
111+
if (!scalerFactory.workingScaler() || mode.ignoreScaling(scalerFactory)) {
115112
return;
116113
}
117114

algo/src/main/java/org/neo4j/gds/pagerank/PageRankVariant.java

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,16 @@
1919
*/
2020
package org.neo4j.gds.pagerank;
2121

22-
public enum PageRankVariant {
23-
PAGE_RANK("PageRank"),
24-
ARTICLE_RANK("ArticleRank"),
25-
EIGENVECTOR("EigenVector");
26-
27-
private final String taskName;
22+
import org.neo4j.gds.scaling.L2Norm;
23+
import org.neo4j.gds.scaling.ScalerFactory;
2824

29-
PageRankVariant(String taskName) {
30-
this.taskName = taskName;
31-
}
25+
public enum PageRankVariant {
26+
PAGE_RANK,
27+
ARTICLE_RANK,
28+
EIGENVECTOR;
3229

33-
String taskName() {
34-
return taskName;
30+
boolean ignoreScaling(ScalerFactory scalerFactory){
31+
// Eigenvector produces L2NORM-scaled results by default.
32+
return this == EIGENVECTOR && scalerFactory.type().equals(L2Norm.TYPE);
3533
}
3634
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.pagerank;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.scaling.ScalerFactory;
24+
25+
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
26+
27+
class PageRankVariantTest {
28+
29+
@Test
30+
void shouldNotIgnoreScaling(){
31+
assertThat(PageRankVariant.PAGE_RANK.ignoreScaling(ScalerFactory.parse("max"))).isFalse();
32+
assertThat(PageRankVariant.PAGE_RANK.ignoreScaling(ScalerFactory.parse("l2norm"))).isFalse();
33+
assertThat(PageRankVariant.ARTICLE_RANK.ignoreScaling(ScalerFactory.parse("max"))).isFalse();
34+
assertThat(PageRankVariant.ARTICLE_RANK.ignoreScaling(ScalerFactory.parse("l2norm"))).isFalse();
35+
assertThat(PageRankVariant.EIGENVECTOR.ignoreScaling(ScalerFactory.parse("max"))).isFalse();
36+
}
37+
@Test
38+
void shouldIgnoreScaling(){
39+
assertThat(PageRankVariant.EIGENVECTOR.ignoreScaling(ScalerFactory.parse("l2norm"))).isTrue();
40+
}
41+
42+
}

scaling-utils/src/main/java/org/neo4j/gds/scaling/ScalerFactory.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,8 @@ ScalarScaler create(
105105
ProgressTracker progressTracker,
106106
ExecutorService executor
107107
);
108+
109+
default boolean workingScaler(){
110+
return !type().equals(NoneScaler.TYPE);
111+
}
108112
}

scaling-utils/src/test/java/org/neo4j/gds/scaling/ScalerFactoryTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,10 @@ void badInput() {
7171
assertThatThrownBy(() -> ScalerFactory.parse(Map.of("type", "log", "offset", false))).hasMessageContaining("The value of `offset` must be of type `Number` but was `Boolean`.");
7272
assertThatThrownBy(() -> ScalerFactory.parse(Map.of("type", "log", "offsat", 0))).hasMessageContaining("Unexpected configuration key: offsat");
7373
}
74+
75+
@Test
76+
void shouldAcceptWorkingTypes(){
77+
assertThat(ScalerFactory.parse("log").workingScaler()).isTrue();
78+
assertThat(ScalerFactory.parse("none").workingScaler()).isFalse();
79+
}
7480
}

0 commit comments

Comments
 (0)