Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ subprojects {
compileJava {
options.annotationProcessorPath = configurations.apt
options.compilerArgs += ["-AIgnoreContextWarnings"]
options.compilerArgs += ["--add-modules=jdk.incubator.vector"]
options.encoding = "UTF-8"
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package apoc.neo4j.docker;


import apoc.util.Neo4jContainerExtension;
import apoc.util.TestContainerUtil;
import apoc.util.Util;
import org.junit.*;
import org.neo4j.driver.Session;
import org.neo4j.graphdb.Result;

import java.util.List;
import java.util.Map;

import static apoc.util.TestContainerUtil.createEnterpriseDB;
import static apoc.util.ExtendedTestContainerUtil.singleResultFirstColumn;

public class SimilarityEntepriseTest {

// private static List nodes = null;

private static Neo4jContainerExtension neo4jContainer;
private static Session session;

@BeforeClass
public static void beforeAll() throws InterruptedException {
neo4jContainer = createEnterpriseDB(List.of(TestContainerUtil.ApocPackage.EXTENDED), true)
.withNeo4jConfig("apoc.import.file.enabled", "true")
.withNeo4jConfig("metrics.enabled", "true")
.withNeo4jConfig("metrics.csv.interval", "1s")
.withNeo4jConfig("dbms.memory.transaction.total.max", "1G")
.withNeo4jConfig("server.memory.heap.initial_size", "1G")
.withNeo4jConfig("server.memory.heap.max_size", "1G")
.withNeo4jConfig("server.memory.heap.max_size", "1G")
.withNeo4jConfig("internal.cypher.enable_vector_type", "true")
.withNeo4jConfig("metrics.namespaces.enabled", "true");
neo4jContainer.start();
session = neo4jContainer.getSession();

session.executeWrite(tx -> tx.run(
"CYPHER 25 UNWIND range(0, 50000) as id " +
"CREATE (:Similar {vect: VECTOR([1, 2, id], 3, INTEGER32), id: 1, test: 1}), (:Similar {vect: VECTOR([1, id, 3], 3, INTEGER32), id: 2}), (:Similar {id: 3}), (:Similar {vect: VECTOR([3, 2, id], 3, INTEGER32), ajeje: 1, id: 4}), (:Similar {vect: VECTOR([4, 2, id], 3, INTEGER32), brazorf: 1, id: 5})"
).consume()
);

// todo - i can't use it: Struct tag: 0x56 representing type VECTOR is not supported for this protocol version
// nodes = singleResultFirstColumn(session, "MATCH (n:Similar) RETURN collect(n) AS nodes");

// try (Transaction tx = db.beginTx()) {
// tx.findNodes(Label.label("Similar")).forEachRemaining(i -> {
// i.setProperty("embedding", new float[]{1, 2, 4});
// });
// tx.commit();
// }
}

@AfterClass
public static void afterAll() {
neo4jContainer.close();
}

@Test
public void testSimilarityCompare() {
long before = System.currentTimeMillis();
String s = session.executeRead(tx -> tx.run(
"CYPHER 25 MATCH (node:Similar) WITH COLLECT(node) AS nodes " +
"CALL custom.search.batchedSimilarity(nodes, 'vect', VECTOR([1, 2, 3], 3, INTEGER32), 5, 0.8) YIELD node, score " +
"RETURN node.id, score",
// "CALL custom.search.batchedSimilarity($nodes, 'null', null, 5, 0.8, {stopWhenFound: true}) YIELD node, score RETURN node, score",
Map.of(/*"nodes", nodes*/)).list().toString());
long after = System.currentTimeMillis();
System.out.println("after - before apoc proc= " + (after - before));

System.out.println("s = " + s);
}

// TODO - maybe this part: https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/embeddings/compute-similarity/
// runs faster..

// TODO - https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/embeddings/compute-similarity/
// it seems the Public APIs still need to be written.
// Maybe once they are written, it will be possible to operate with the Java Vector API and SIMD??

// todo - also cypher with float array
// todo - I have this warning: WARNING: Java vector incubator module is not readable. For optimal vector performance, pass '--add-modules jdk.incubator.vector' to enable Vector API.
// @Ignore
@Test
public void testSimilarityWithPureCypherInBatch() {

long before = System.currentTimeMillis();
String cypherRes = session.executeRead(tx -> tx.run("""
CYPHER 25
MATCH (node:Similar)
WITH COLLECT(node) AS nodes

UNWIND nodes AS node
WITH node, vector.similarity.cosine(node.vect, VECTOR([1, 2, 3], 3, INTEGER32)) AS score
WHERE score >= $threshold
RETURN node.id, score
ORDER BY score DESC
LIMIT $topK
""", Map.of(/*"nodes", nodes, */"threshold", 0.8, "queryVector", new float[]{1, 2, 3}, "topK", 5)).list().toString());
long after = System.currentTimeMillis();
System.out.println("after - before pure cypher = " + (after - before));
System.out.println("cypherRes = " + cypherRes);
}

// todo - remove
@Ignore
@Test
public void testSimilarity() {
long before = System.currentTimeMillis();
String s = session.executeRead(tx -> tx.run(
"MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 2, 0.8) YIELD node, score RETURN node, score",
Map.of()).list().toString());
long after = System.currentTimeMillis();
System.out.println("after - before = " + (after - before));

System.out.println("s = " + s);


String s1 = session.executeRead(tx -> tx.run(
"MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 2, 0.95) YIELD node, score RETURN node, score",
Map.of()).list().toString());

System.out.println("s = " + s1);
String s2 = session.executeRead(tx -> tx.run(
"MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 5, 0.8) YIELD node, score RETURN node, score",
Map.of()).list().toString());

System.out.println("s = " + s2);


String s12 = session.executeRead(tx -> tx.run(
"MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 5, 0.95) YIELD node, score RETURN node, score",
Map.of()).list().toString());

System.out.println("s = " + s12);
}

// todo - remove
@Ignore
@Test
public void testSimilarityWithStopWhenFound() {
String s = session.executeRead(tx -> tx.run(
"MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 2, 0.8, {stopWhenFound: true}) YIELD node, score RETURN node, score",
Map.of()).list().toString());

System.out.println("stopWhenFound = " + s);
}

// todo - remove
//@Ignore
@Test
@Ignore
public void testSimilarityWithPureCypher() {
// try (Transaction tx = db.beginTx()) {
// tx.findNodes(Label.label("Similar")).forEachRemaining(i -> {
// i.setProperty("embedding", new float[]{1, 2, 4});
// });
// tx.commit();
// }

long before = System.currentTimeMillis();
String cypherRes = session.executeRead(tx -> tx.run("""
MATCH (node:Similar)
// UNWIND $nodes AS node
// 2. Calcola la similarità per ogni nodo
WITH node, vector.similarity.cosine(node.embedding, $queryVector) AS score
// 3. Filtra i risultati che superano la soglia
WHERE score >= $threshold
// 4. Restituisce il nodo e il punteggio, ordinando per trovare i migliori K
RETURN node, score
ORDER BY score DESC
LIMIT $topK
""", Map.of("threshold", 0.8, "queryVector", new float[]{1, 2, 3}, "topK", 5)).list().toString());
System.out.println("cypherRes = " + cypherRes);
long after = System.currentTimeMillis();
System.out.println("after - before cypher match = " + (after - before));
}


// todo - pure cypher with float vector
// todo - pure cypher with float vector
// todo - pure cypher with float vector

}
72 changes: 72 additions & 0 deletions extended/src/main/java/apoc/algo/Neo4jVectorSimilaritySIMD.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package apoc.algo;

import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

// TODO - compare..
// servirebbe una cosa di questa, ma non posso farlo..
public class Neo4jVectorSimilaritySIMD {

// Seleziona la "forma" del vettore SIMD più grande disponibile sulla CPU, fino a 256 bit.
// Questo caricherà 8 float (8 * 32bit = 256bit) alla volta.
private static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_256;

/**
* Calcola la similarità per vettori FLOAT32 usando la Java Vector API per l'accelerazione SIMD.
*/
private double calculateFloat32_SIMD(float[] v1, float[] v2) {
// Inizializza i vettori accumulatori a zero. Questi conterranno somme parziali.
FloatVector dotProductVec = FloatVector.zero(SPECIES);
FloatVector normAVec = FloatVector.zero(SPECIES);
FloatVector normBVec = FloatVector.zero(SPECIES);

// Calcola il limite superiore per il ciclo vettoriale.
// Assicura che processiamo solo blocchi completi.
int loopBound = SPECIES.loopBound(v1.length);

// --- CICLO VETTORIALE (SIMD) ---
// Processa i dati in blocchi della dimensione di SPECIES (es. 8 elementi alla volta).
for (int i = 0; i < loopBound; i += SPECIES.length()) {
// Carica un blocco di dati dagli array Java nei vettori SIMD
FloatVector va = FloatVector.fromArray(SPECIES, v1, i);
FloatVector vb = FloatVector.fromArray(SPECIES, v2, i);

// Calcola il prodotto scalare parziale usando FMA (Fused Multiply-Add: a * b + c)
// È più efficiente di una moltiplicazione seguita da un'addizione.
dotProductVec = va.fma(vb, dotProductVec);

// Calcola le norme parziali
normAVec = va.fma(va, normAVec); // va * va + normAVec
normBVec = vb.fma(vb, normBVec); // vb * vb + normBVec
}

// "Riduci" i vettori accumulatori a un singolo valore scalare (double)
// Sommando tutte le "lane" (corsie) del vettore SIMD.
double dotProduct = dotProductVec.reduceLanes(VectorOperators.ADD);
double normA = normAVec.reduceLanes(VectorOperators.ADD);
double normB = normBVec.reduceLanes(VectorOperators.ADD);

// --- CICLO SCALARE (per la "coda") ---
// Processa gli elementi rimanenti che non rientravano in un blocco completo.
for (int i = loopBound; i < v1.length; i++) {
dotProduct += v1[i] * v2[i];
normA += v1[i] * v1[i];
normB += v2[i] * v2[i];
}

if (normA == 0.0 || normB == 0.0) {
return 0.0;
}

return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}

public double cosineSimilarity(float[] v1, float[] v2) {
if (v1.length != v2.length) {
throw new IllegalArgumentException("I vettori devono avere la stessa dimensione");
}
double rawSimilarity = calculateFloat32_SIMD(v1, v2);
return (rawSimilarity + 1) / 2.0;
}
}
Loading
Loading