Skip to content

Add low-level optimized Neon, AVX2, and AVX 512 float32 vector operations #130635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.benchmark.vector;

import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.common.logging.NodeNamePatternConverter;
import org.elasticsearch.nativeaccess.NativeAccess;
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
public class JDKVectorFloat32Benchmark {

static {
NodeNamePatternConverter.setGlobalNodeName("foo");
LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}

static final ValueLayout.OfFloat LAYOUT_LE_FLOAT = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);

float[] floatsA;
float[] floatsB;
float[] scratch;
MemorySegment heapSegA, heapSegB;
MemorySegment nativeSegA, nativeSegB;

Arena arena;

@Param({ "1", "128", "207", "256", "300", "512", "702", "1024", "1536", "2048" })
public int size;

@Setup(Level.Iteration)
public void init() {
ThreadLocalRandom random = ThreadLocalRandom.current();

floatsA = new float[size];
floatsB = new float[size];
scratch = new float[size];
for (int i = 0; i < size; ++i) {
floatsA[i] = random.nextFloat();
floatsB[i] = random.nextFloat();
}
heapSegA = MemorySegment.ofArray(floatsA);
heapSegB = MemorySegment.ofArray(floatsB);

arena = Arena.ofConfined();
nativeSegA = arena.allocate((long) floatsA.length * Float.BYTES);
MemorySegment.copy(MemorySegment.ofArray(floatsA), LAYOUT_LE_FLOAT, 0L, nativeSegA, LAYOUT_LE_FLOAT, 0L, floatsA.length);
nativeSegB = arena.allocate((long) floatsB.length * Float.BYTES);
MemorySegment.copy(MemorySegment.ofArray(floatsB), LAYOUT_LE_FLOAT, 0L, nativeSegB, LAYOUT_LE_FLOAT, 0L, floatsB.length);
}

@TearDown
public void teardown() {
arena.close();
}

// -- cosine

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float cosineLucene() {
return VectorUtil.cosine(floatsA, floatsB);
}

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float cosineLuceneWithCopy() {
// add a copy to better reflect what Lucene has to do to get the target vector on-heap
MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length);
return VectorUtil.cosine(floatsA, scratch);
}

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float cosineNativeWithNativeSeg() {
return cosineFloat32(nativeSegA, nativeSegB, size);
}

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float cosineNativeWithHeapSeg() {
return cosineFloat32(heapSegA, heapSegB, size);
}

// -- dot product

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float dotProductLucene() {
return VectorUtil.dotProduct(floatsA, floatsB);
}

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float dotProductLuceneWithCopy() {
// add a copy to better reflect what Lucene has to do to get the target vector on-heap
MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length);
return VectorUtil.dotProduct(floatsA, scratch);
}

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float dotProductNativeWithNativeSeg() {
return dotProductFloat32(nativeSegA, nativeSegB, size);
}

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float dotProductNativeWithHeapSeg() {
return dotProductFloat32(heapSegA, heapSegB, size);
}

// -- square distance

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float squareDistanceLucene() {
return VectorUtil.squareDistance(floatsA, floatsB);
}

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float squareDistanceLuceneWithCopy() {
// add a copy to better reflect what Lucene has to do to get the target vector on-heap
MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length);
return VectorUtil.squareDistance(floatsA, scratch);
}

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float squareDistanceNativeWithNativeSeg() {
return squareDistanceFloat32(nativeSegA, nativeSegB, size);
}

@Benchmark
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public float squareDistanceNativeWithHeapSeg() {
return squareDistanceFloat32(heapSegA, heapSegB, size);
}

static final VectorSimilarityFunctions vectorSimilarityFunctions = vectorSimilarityFunctions();

static VectorSimilarityFunctions vectorSimilarityFunctions() {
return NativeAccess.instance().getVectorSimilarityFunctions().get();
}

float cosineFloat32(MemorySegment a, MemorySegment b, int length) {
try {
return (float) vectorSimilarityFunctions.cosineHandleFloat32().invokeExact(a, b, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}

float dotProductFloat32(MemorySegment a, MemorySegment b, int length) {
try {
return (float) vectorSimilarityFunctions.dotProductHandleFloat32().invokeExact(a, b, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}

float squareDistanceFloat32(MemorySegment a, MemorySegment b, int length) {
try {
return (float) vectorSimilarityFunctions.squareDistanceHandleFloat32().invokeExact(a, b, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.benchmark.vector;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.util.Constants;
import org.elasticsearch.test.ESTestCase;
import org.junit.BeforeClass;
import org.openjdk.jmh.annotations.Param;

import java.util.Arrays;

public class JDKVectorFloat32BenchmarkTests extends ESTestCase {

final double delta;
final int size;

public JDKVectorFloat32BenchmarkTests(int size) {
this.size = size;
delta = 1e-3 * size;
}

@BeforeClass
public static void skipWindows() {
assumeFalse("doesn't work on windows yet", Constants.WINDOWS);
}

static boolean supportsHeapSegments() {
return Runtime.version().feature() >= 22;
}

public void testCosine() {
for (int i = 0; i < 100; i++) {
var bench = new JDKVectorFloat32Benchmark();
bench.size = size;
bench.init();
try {
float expected = cosineFloat32Scalar(bench.floatsA, bench.floatsB);
assertEquals(expected, bench.cosineLucene(), delta);
assertEquals(expected, bench.cosineLuceneWithCopy(), delta);
assertEquals(expected, bench.cosineNativeWithNativeSeg(), delta);
if (supportsHeapSegments()) {
assertEquals(expected, bench.cosineNativeWithHeapSeg(), delta);
}
} finally {
bench.teardown();
}
}
}

public void testDotProduct() {
for (int i = 0; i < 100; i++) {
var bench = new JDKVectorFloat32Benchmark();
bench.size = size;
bench.init();
try {
float expected = dotProductFloat32Scalar(bench.floatsA, bench.floatsB);
assertEquals(expected, bench.dotProductLucene(), delta);
assertEquals(expected, bench.dotProductLuceneWithCopy(), delta);
assertEquals(expected, bench.dotProductNativeWithNativeSeg(), delta);
if (supportsHeapSegments()) {
assertEquals(expected, bench.dotProductNativeWithHeapSeg(), delta);
}
} finally {
bench.teardown();
}
}
}

public void testSquareDistance() {
for (int i = 0; i < 100; i++) {
var bench = new JDKVectorFloat32Benchmark();
bench.size = size;
bench.init();
try {
float expected = squareDistanceFloat32Scalar(bench.floatsA, bench.floatsB);
assertEquals(expected, bench.squareDistanceLucene(), delta);
assertEquals(expected, bench.squareDistanceLuceneWithCopy(), delta);
assertEquals(expected, bench.squareDistanceNativeWithNativeSeg(), delta);
if (supportsHeapSegments()) {
assertEquals(expected, bench.squareDistanceNativeWithHeapSeg(), delta);
}
} finally {
bench.teardown();
}
}
}

@ParametersFactory
public static Iterable<Object[]> parametersFactory() {
try {
var params = JDKVectorFloat32Benchmark.class.getField("size").getAnnotationsByType(Param.class)[0].value();
return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator();
} catch (NoSuchFieldException e) {
throw new AssertionError(e);
}
}

/** Computes the cosine of the given vectors a and b. */
static float cosineFloat32Scalar(float[] a, float[] b) {
float dot = 0, normA = 0, normB = 0;
for (int i = 0; i < a.length; i++) {
dot += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
double normAA = Math.sqrt(normA);
double normBB = Math.sqrt(normB);
if (normAA == 0.0f || normBB == 0.0f) return 0.0f;
return (float) (dot / (normAA * normBB));
}

/** Computes the dot product of the given vectors a and b. */
static float dotProductFloat32Scalar(float[] a, float[] b) {
float res = 0;
for (int i = 0; i < a.length; i++) {
res += a[i] * b[i];
}
return res;
}

/** Computes the dot product of the given vectors a and b. */
static float squareDistanceFloat32Scalar(float[] a, float[] b) {
float squareSum = 0;
for (int i = 0; i < a.length; i++) {
float diff = a[i] - b[i];
squareSum += diff * diff;
}
return squareSum;
}
}
5 changes: 5 additions & 0 deletions docs/changelog/130635.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 130635
summary: "Add low-level optimized Neon, AVX2, and AVX 512 float32 vector operations"
area: Vector Search
type: enhancement
issues: []
2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ configurations {
}

var zstdVersion = "1.5.5"
var vecVersion = "1.0.11"
var vecVersion = "1.0.13"

repositories {
exclusiveContent {
Expand Down
Loading
Loading