From 62a8782bcdcca9a69515588bf33faecace7bc1df Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Fri, 4 Jul 2025 14:11:53 +0100 Subject: [PATCH 1/2] Add low-level optimized Neon, AVX2, and AVX 512 float32 vector operations. --- .../vector/JDKVectorFloat32Benchmark.java | 220 +++++++++++++++++ .../JDKVectorFloat32BenchmarkTests.java | 139 +++++++++++ libs/native/libraries/build.gradle | 2 +- .../VectorSimilarityFunctions.java | 30 +++ .../nativeaccess/jdk/JdkVectorLibrary.java | 127 +++++++++- .../jdk/JDKVectorLibraryFloat32Tests.java | 225 ++++++++++++++++++ libs/simdvec/native/publish_vec_binaries.sh | 2 +- libs/simdvec/native/src/vec/c/aarch64/vec.c | 167 +++++++++++++ libs/simdvec/native/src/vec/c/amd64/vec.c | 159 +++++++++++++ libs/simdvec/native/src/vec/c/amd64/vec_2.cpp | 146 ++++++++++++ libs/simdvec/native/src/vec/headers/vec.h | 7 + 11 files changed, 1220 insertions(+), 4 deletions(-) create mode 100644 benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32Benchmark.java create mode 100644 benchmarks/src/test/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32BenchmarkTests.java create mode 100644 libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryFloat32Tests.java diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32Benchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32Benchmark.java new file mode 100644 index 0000000000000..c8517c6a6c7fc --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32Benchmark.java @@ -0,0 +1,220 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.benchmark.vector; + +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.common.logging.NodeNamePatternConverter; +import org.elasticsearch.nativeaccess.NativeAccess; +import org.elasticsearch.nativeaccess.VectorSimilarityFunctions; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Benchmark) +@Warmup(iterations = 3, time = 1) +@Measurement(iterations = 5, time = 1) +public class JDKVectorFloat32Benchmark { + + static { + NodeNamePatternConverter.setGlobalNodeName("foo"); + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + static final ValueLayout.OfFloat LAYOUT_LE_FLOAT = 
ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + + float[] floatsA; + float[] floatsB; + float[] scratch; + MemorySegment heapSegA, heapSegB; + MemorySegment nativeSegA, nativeSegB; + + Arena arena; + + @Param({ "1", "128", "207", "256", "300", "512", "702", "1024", "1536", "2048" }) + public int size; + + @Setup(Level.Iteration) + public void init() { + ThreadLocalRandom random = ThreadLocalRandom.current(); + + floatsA = new float[size]; + floatsB = new float[size]; + scratch = new float[size]; + for (int i = 0; i < size; ++i) { + floatsA[i] = random.nextFloat(); + floatsB[i] = random.nextFloat(); + } + heapSegA = MemorySegment.ofArray(floatsA); + heapSegB = MemorySegment.ofArray(floatsB); + + arena = Arena.ofConfined(); + nativeSegA = arena.allocate((long) floatsA.length * Float.BYTES); + MemorySegment.copy(MemorySegment.ofArray(floatsA), LAYOUT_LE_FLOAT, 0L, nativeSegA, LAYOUT_LE_FLOAT, 0L, floatsA.length); + nativeSegB = arena.allocate((long) floatsB.length * Float.BYTES); + MemorySegment.copy(MemorySegment.ofArray(floatsB), LAYOUT_LE_FLOAT, 0L, nativeSegB, LAYOUT_LE_FLOAT, 0L, floatsB.length); + } + + @TearDown + public void teardown() { + arena.close(); + } + + // -- cosine + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float cosineLucene() { + return VectorUtil.cosine(floatsA, floatsB); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float cosineLuceneWithCopy() { + // add a copy to better reflect what Lucene has to do to get the target vector on-heap + MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length); + return VectorUtil.cosine(floatsA, scratch); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float cosineNativeWithNativeSeg() { + return cosineFloat32(nativeSegA, nativeSegB, size); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = 
{ "--add-modules=jdk.incubator.vector" }) + public float cosineNativeWithHeapSeg() { + return cosineFloat32(heapSegA, heapSegB, size); + } + + // -- dot product + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float dotProductLucene() { + return VectorUtil.dotProduct(floatsA, floatsB); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float dotProductLuceneWithCopy() { + // add a copy to better reflect what Lucene has to do to get the target vector on-heap + MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length); + return VectorUtil.dotProduct(floatsA, scratch); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float dotProductNativeWithNativeSeg() { + return dotProductFloat32(nativeSegA, nativeSegB, size); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float dotProductNativeWithHeapSeg() { + return dotProductFloat32(heapSegA, heapSegB, size); + } + + // -- square distance + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float squareDistanceLucene() { + return VectorUtil.squareDistance(floatsA, floatsB); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float squareDistanceLuceneWithCopy() { + // add a copy to better reflect what Lucene has to do to get the target vector on-heap + MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length); + return VectorUtil.squareDistance(floatsA, scratch); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float squareDistanceNativeWithNativeSeg() { + return squareDistanceFloat32(nativeSegA, nativeSegB, size); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" 
}) + public float squareDistanceNativeWithHeapSeg() { + return squareDistanceFloat32(heapSegA, heapSegB, size); + } + + static final VectorSimilarityFunctions vectorSimilarityFunctions = vectorSimilarityFunctions(); + + static VectorSimilarityFunctions vectorSimilarityFunctions() { + return NativeAccess.instance().getVectorSimilarityFunctions().get(); + } + + float cosineFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) vectorSimilarityFunctions.cosineHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + float dotProductFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) vectorSimilarityFunctions.dotProductHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + float squareDistanceFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) vectorSimilarityFunctions.squareDistanceHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } +} diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32BenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32BenchmarkTests.java new file mode 100644 index 0000000000000..578f9701c9ccc --- /dev/null +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32BenchmarkTests.java @@ -0,0 +1,139 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.Constants; +import org.elasticsearch.test.ESTestCase; +import org.junit.BeforeClass; +import org.openjdk.jmh.annotations.Param; + +import java.util.Arrays; + +public class JDKVectorFloat32BenchmarkTests extends ESTestCase { + + final double delta; + final int size; + + public JDKVectorFloat32BenchmarkTests(int size) { + this.size = size; + delta = 1e-3 * size; + } + + @BeforeClass + public static void skipWindows() { + assumeFalse("doesn't work on windows yet", Constants.WINDOWS); + } + + static boolean supportsHeapSegments() { + return Runtime.version().feature() >= 22; + } + + public void testCosine() { + for (int i = 0; i < 100; i++) { + var bench = new JDKVectorFloat32Benchmark(); + bench.size = size; + bench.init(); + try { + float expected = cosineFloat32Scalar(bench.floatsA, bench.floatsB); + assertEquals(expected, bench.cosineLucene(), delta); + assertEquals(expected, bench.cosineLuceneWithCopy(), delta); + assertEquals(expected, bench.cosineNativeWithNativeSeg(), delta); + if (supportsHeapSegments()) { + assertEquals(expected, bench.cosineNativeWithHeapSeg(), delta); + } + } finally { + bench.teardown(); + } + } + } + + public void testDotProduct() { + for (int i = 0; i < 100; i++) { + var bench = new JDKVectorFloat32Benchmark(); + bench.size = size; + bench.init(); + try { + float expected = dotProductFloat32Scalar(bench.floatsA, bench.floatsB); + assertEquals(expected, bench.dotProductLucene(), delta); + assertEquals(expected, bench.dotProductLuceneWithCopy(), 
delta); + assertEquals(expected, bench.dotProductNativeWithNativeSeg(), delta); + if (supportsHeapSegments()) { + assertEquals(expected, bench.dotProductNativeWithHeapSeg(), delta); + } + } finally { + bench.teardown(); + } + } + } + + public void testSquareDistance() { + for (int i = 0; i < 100; i++) { + var bench = new JDKVectorFloat32Benchmark(); + bench.size = size; + bench.init(); + try { + float expected = squareDistanceFloat32Scalar(bench.floatsA, bench.floatsB); + assertEquals(expected, bench.squareDistanceLucene(), delta); + assertEquals(expected, bench.squareDistanceLuceneWithCopy(), delta); + assertEquals(expected, bench.squareDistanceNativeWithNativeSeg(), delta); + if (supportsHeapSegments()) { + assertEquals(expected, bench.squareDistanceNativeWithHeapSeg(), delta); + } + } finally { + bench.teardown(); + } + } + } + + @ParametersFactory + public static Iterable parametersFactory() { + try { + var params = JDKVectorFloat32Benchmark.class.getField("size").getAnnotationsByType(Param.class)[0].value(); + return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator(); + } catch (NoSuchFieldException e) { + throw new AssertionError(e); + } + } + + /** Computes the cosine of the given vectors a and b. */ + static float cosineFloat32Scalar(float[] a, float[] b) { + float dot = 0, normA = 0, normB = 0; + for (int i = 0; i < a.length; i++) { + dot += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + double normAA = Math.sqrt(normA); + double normBB = Math.sqrt(normB); + if (normAA == 0.0f || normBB == 0.0f) return 0.0f; + return (float) (dot / (normAA * normBB)); + } + + /** Computes the dot product of the given vectors a and b. */ + static float dotProductFloat32Scalar(float[] a, float[] b) { + float res = 0; + for (int i = 0; i < a.length; i++) { + res += a[i] * b[i]; + } + return res; + } + + /** Computes the dot product of the given vectors a and b. 
*/ + static float squareDistanceFloat32Scalar(float[] a, float[] b) { + float squareSum = 0; + for (int i = 0; i < a.length; i++) { + float diff = a[i] - b[i]; + squareSum += diff * diff; + } + return squareSum; + } +} diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index 58562ddcd6882..4d94ad6e20c73 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -19,7 +19,7 @@ configurations { } var zstdVersion = "1.5.5" -var vecVersion = "1.0.11" +var vecVersion = "1.0.13" repositories { exclusiveContent { diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java index 29a298b714fdd..4d3f6bc5b2c79 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java @@ -40,4 +40,34 @@ public interface VectorSimilarityFunctions { * vector data bytes. The third argument is the length of the vector data. */ MethodHandle squareDistanceHandle7u(); + + /** + * Produces a method handle returning the cosine of float32 vectors. + * + *

The type of the method handle will have {@code float} as return type. The type of + * its first and second arguments will be {@code MemorySegment}, whose contents are the + * vector data floats. The third argument is the length of the vector data - number of + * 4-byte float32 elements. + */ + MethodHandle cosineHandleFloat32(); + + /** + * Produces a method handle returning the dot product of float32 vectors. + * + *

The type of the method handle will have {@code float} as return type. The type of + * its first and second arguments will be {@code MemorySegment}, whose contents are the + * vector data floats. The third argument is the length of the vector data - number of + * 4-byte float32 elements. + */ + MethodHandle dotProductHandleFloat32(); + + /** + * Produces a method handle returning the square distance of float32 vectors. + * + *

The type of the method handle will have {@code float} as return type, The type of + * its first and second arguments will be {@code MemorySegment}, whose contents is the + * vector data floats. The third argument is the length of the vector data - number of + * 4-byte float32 elements. + */ + MethodHandle squareDistanceHandleFloat32(); } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index 2b56e65f39aae..2c429283d64ef 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -23,6 +23,7 @@ import java.util.Objects; import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; import static java.lang.foreign.ValueLayout.JAVA_INT; import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle; @@ -32,8 +33,11 @@ public final class JdkVectorLibrary implements VectorLibrary { static final MethodHandle dot7u$mh; static final MethodHandle sqr7u$mh; + static final MethodHandle cosf32$mh; + static final MethodHandle dotf32$mh; + static final MethodHandle sqrf32$mh; - static final VectorSimilarityFunctions INSTANCE; + public static final JdkVectorSimilarityFunctions INSTANCE; static { LoaderHelper.loadLibrary("vec"); @@ -54,6 +58,21 @@ public final class JdkVectorLibrary implements VectorLibrary { FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), LinkerHelperUtil.critical() ); + cosf32$mh = downcallHandle( + "cosf32_2", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + dotf32$mh = downcallHandle( + "dotf32_2", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + sqrf32$mh = downcallHandle( + "sqrf32_2", + FunctionDescriptor.of(JAVA_FLOAT, 
ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); } else { dot7u$mh = downcallHandle( "dot7u", @@ -65,6 +84,21 @@ public final class JdkVectorLibrary implements VectorLibrary { FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), LinkerHelperUtil.critical() ); + cosf32$mh = downcallHandle( + "cosf32", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + dotf32$mh = downcallHandle( + "dotf32", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + sqrf32$mh = downcallHandle( + "sqrf32", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); } INSTANCE = new JdkVectorSimilarityFunctions(); } else { @@ -75,6 +109,9 @@ public final class JdkVectorLibrary implements VectorLibrary { } dot7u$mh = null; sqr7u$mh = null; + cosf32$mh = null; + dotf32$mh = null; + sqrf32$mh = null; INSTANCE = null; } } catch (Throwable t) { @@ -120,7 +157,46 @@ static int squareDistance7u(MemorySegment a, MemorySegment b, int length) { return sqr7u(a, b, length); } - static void checkByteSize(MemorySegment a, MemorySegment b) { + /** + * Computes the cosine of given float32 vectors. + * + * @param a address of the first vector + * @param b address of the second vector + * @param elementCount the vector dimensions, number of float32 elements in the segment + */ + static float cosineF32(MemorySegment a, MemorySegment b, int elementCount) { + checkByteSize(a, b); + Objects.checkFromIndexSize(0, elementCount, (int) a.byteSize() / Float.BYTES); + return cosf32(a, b, elementCount); + } + + /** + * Computes the dot product of given float32 vectors. 
+ * + * @param a address of the first vector + * @param b address of the second vector + * @param elementCount the vector dimensions, number of float32 elements in the segment + */ + static float dotProductF32(MemorySegment a, MemorySegment b, int elementCount) { + checkByteSize(a, b); + Objects.checkFromIndexSize(0, elementCount, (int) a.byteSize() / Float.BYTES); + return dotf32(a, b, elementCount); + } + + /** + * Computes the square distance of given float32 vectors. + * + * @param a address of the first vector + * @param b address of the second vector + * @param elementCount the vector dimensions, number of float32 elements in the segment + */ + static float squareDistanceF32(MemorySegment a, MemorySegment b, int elementCount) { + checkByteSize(a, b); + Objects.checkFromIndexSize(0, elementCount, (int) a.byteSize() / Float.BYTES); + return sqrf32(a, b, elementCount); + } + + private static void checkByteSize(MemorySegment a, MemorySegment b) { if (a.byteSize() != b.byteSize()) { throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); } @@ -142,8 +218,35 @@ private static int sqr7u(MemorySegment a, MemorySegment b, int length) { } } + private static float cosf32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) JdkVectorLibrary.cosf32$mh.invokeExact(a, b, length); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private static float dotf32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) JdkVectorLibrary.dotf32$mh.invokeExact(a, b, length); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private static float sqrf32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) JdkVectorLibrary.sqrf32$mh.invokeExact(a, b, length); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + static final MethodHandle DOT_HANDLE_7U; static final MethodHandle SQR_HANDLE_7U; + static final MethodHandle COS_HANDLE_FLOAT32; + 
static final MethodHandle DOT_HANDLE_FLOAT32; + static final MethodHandle SQR_HANDLE_FLOAT32; static { try { @@ -151,6 +254,11 @@ private static int sqr7u(MemorySegment a, MemorySegment b, int length) { var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class); DOT_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7u", mt); SQR_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance7u", mt); + + mt = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class); + COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", mt); + DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", mt); + SQR_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistanceF32", mt); } catch (NoSuchMethodException | IllegalAccessException e) { throw new RuntimeException(e); } @@ -165,5 +273,20 @@ public MethodHandle dotProductHandle7u() { public MethodHandle squareDistanceHandle7u() { return SQR_HANDLE_7U; } + + @Override + public MethodHandle cosineHandleFloat32() { + return COS_HANDLE_FLOAT32; + } + + @Override + public MethodHandle dotProductHandleFloat32() { + return DOT_HANDLE_FLOAT32; + } + + @Override + public MethodHandle squareDistanceHandleFloat32() { + return SQR_HANDLE_FLOAT32; + } } } diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryFloat32Tests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryFloat32Tests.java new file mode 100644 index 0000000000000..37473a85a6b96 --- /dev/null +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryFloat32Tests.java @@ -0,0 +1,225 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.nativeaccess.jdk; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.nativeaccess.VectorSimilarityFunctionsTests; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; +import java.util.function.IntFunction; + +import static java.lang.foreign.ValueLayout.JAVA_FLOAT_UNALIGNED; +import static org.hamcrest.Matchers.containsString; + +public class JDKVectorLibraryFloat32Tests extends VectorSimilarityFunctionsTests { + + static final ValueLayout.OfFloat LAYOUT_LE_FLOAT = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + + final double delta; + + public JDKVectorLibraryFloat32Tests(int size) { + super(size); + this.delta = 1e-5 * size; // scale the delta with the size + } + + @BeforeClass + public static void beforeClass() { + VectorSimilarityFunctionsTests.setup(); + } + + @AfterClass + public static void afterClass() { + VectorSimilarityFunctionsTests.cleanup(); + } + + @ParametersFactory + public static Iterable parametersFactory() { + return VectorSimilarityFunctionsTests.parametersFactory(); + } + + public void testAllZeroValues() { + testFloat32Impl(float[]::new); + } + + public void testRandomFloats() { + testFloat32Impl(JDKVectorLibraryFloat32Tests::randomFloatArray); + } + + public void testFloat32Impl(IntFunction vectorGeneratorFunc) { + assumeTrue(notSupportedMsg(), supported()); + final int dims = size; + final int numVecs = randomIntBetween(2, 101); + var values = new float[numVecs][dims]; + var 
segment = arena.allocate((long) dims * numVecs * Float.BYTES); + for (int i = 0; i < numVecs; i++) { + values[i] = vectorGeneratorFunc.apply(dims); + long dstOffset = (long) i * dims * Float.BYTES; + MemorySegment.copy(MemorySegment.ofArray(values[i]), JAVA_FLOAT_UNALIGNED, 0L, segment, LAYOUT_LE_FLOAT, dstOffset, dims); + } + + final int loopTimes = 1000; + for (int i = 0; i < loopTimes; i++) { + int first = randomInt(numVecs - 1); + int second = randomInt(numVecs - 1); + var nativeSeg1 = segment.asSlice((long) first * dims * Float.BYTES, (long) dims * Float.BYTES); + var nativeSeg2 = segment.asSlice((long) second * dims * Float.BYTES, (long) dims * Float.BYTES); + + // cosine + float expected = cosineFloat32Scalar(values[first], values[second]); + assertEquals(expected, cosineFloat32(nativeSeg1, nativeSeg2, dims), delta); + if (supportsHeapSegments()) { + var heapSeg1 = MemorySegment.ofArray(values[first]); + var heapSeg2 = MemorySegment.ofArray(values[second]); + assertEquals(expected, cosineFloat32(heapSeg1, heapSeg2, dims), delta); + assertEquals(expected, cosineFloat32(nativeSeg1, heapSeg2, dims), delta); + assertEquals(expected, cosineFloat32(heapSeg1, nativeSeg2, dims), delta); + } + + // dot product + expected = dotProductFloat32Scalar(values[first], values[second]); + assertEquals(expected, dotProductFloat32(nativeSeg1, nativeSeg2, dims), delta); + if (supportsHeapSegments()) { + var heapSeg1 = MemorySegment.ofArray(values[first]); + var heapSeg2 = MemorySegment.ofArray(values[second]); + assertEquals(expected, dotProductFloat32(heapSeg1, heapSeg2, dims), delta); + assertEquals(expected, dotProductFloat32(nativeSeg1, heapSeg2, dims), delta); + assertEquals(expected, dotProductFloat32(heapSeg1, nativeSeg2, dims), delta); + } + + // square distance + expected = squareDistanceFloat32Scalar(values[first], values[second]); + assertEquals(expected, squareDistanceFloat32(nativeSeg1, nativeSeg2, dims), delta); + if (supportsHeapSegments()) { + var heapSeg1 = 
MemorySegment.ofArray(values[first]); + var heapSeg2 = MemorySegment.ofArray(values[second]); + assertEquals(expected, squareDistanceFloat32(heapSeg1, heapSeg2, dims), delta); + assertEquals(expected, squareDistanceFloat32(nativeSeg1, heapSeg2, dims), delta); + assertEquals(expected, squareDistanceFloat32(heapSeg1, nativeSeg2, dims), delta); + } + + } + } + + public void testIllegalDims() { + assumeTrue(notSupportedMsg(), supported()); + var segment = arena.allocate((long) size * 3 * Float.BYTES); + + var e1 = expectThrows(IAE, () -> cosineFloat32(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size)); + assertThat(e1.getMessage(), containsString("dimensions differ")); + e1 = expectThrows(IAE, () -> dotProductFloat32(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size)); + assertThat(e1.getMessage(), containsString("dimensions differ")); + e1 = expectThrows(IAE, () -> squareDistanceFloat32(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size)); + assertThat(e1.getMessage(), containsString("dimensions differ")); + + var e2 = expectThrows(IOOBE, () -> cosineFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + e2 = expectThrows(IOOBE, () -> dotProductFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + e2 = expectThrows(IOOBE, () -> squareDistanceFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + + e2 = expectThrows(IOOBE, () -> cosineFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), -1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + e2 = expectThrows(IOOBE, () -> dotProductFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), -1)); + assertThat(e2.getMessage(), 
containsString("out of bounds for length")); + e2 = expectThrows(IOOBE, () -> squareDistanceFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), -1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + } + + float cosineFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) getVectorDistance().cosineHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + float dotProductFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) getVectorDistance().dotProductHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + float squareDistanceFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) getVectorDistance().squareDistanceHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + static float[] randomFloatArray(int length) { + float[] fa = new float[length]; + for (int i = 0; i < length; i++) { + fa[i] = randomFloat(); + } + return fa; + } + + /** Computes the cosine of the given vectors a and b. 
*/ + static float cosineFloat32Scalar(float[] a, float[] b) { + float dot = 0, normA = 0, normB = 0; + for (int i = 0; i < a.length; i++) { + dot += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + double normAA = Math.sqrt(normA); + double normBB = Math.sqrt(normB); + if (normAA == 0.0f || normBB == 0.0f) { + return 0.0f; + } + return (float) (dot / (normAA * normBB)); + } + + /** Computes the dot product of the given vectors a and b. */ + static float dotProductFloat32Scalar(float[] a, float[] b) { + float res = 0; + for (int i = 0; i < a.length; i++) { + res += a[i] * b[i]; + } + return res; + } + + /** Computes the dot product of the given vectors a and b. */ + static float squareDistanceFloat32Scalar(float[] a, float[] b) { + float squareSum = 0; + for (int i = 0; i < a.length; i++) { + float diff = a[i] - b[i]; + squareSum += diff * diff; + } + return squareSum; + } +} diff --git a/libs/simdvec/native/publish_vec_binaries.sh b/libs/simdvec/native/publish_vec_binaries.sh index e3d7e4858ecfc..0258ed5760b6b 100755 --- a/libs/simdvec/native/publish_vec_binaries.sh +++ b/libs/simdvec/native/publish_vec_binaries.sh @@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then exit 1; fi -VERSION="1.0.11" +VERSION="1.0.13" ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}" TEMP=$(mktemp -d) diff --git a/libs/simdvec/native/src/vec/c/aarch64/vec.c b/libs/simdvec/native/src/vec/c/aarch64/vec.c index 59c0cdb2ff8ff..f3eb7f51ee5d1 100644 --- a/libs/simdvec/native/src/vec/c/aarch64/vec.c +++ b/libs/simdvec/native/src/vec/c/aarch64/vec.c @@ -9,6 +9,7 @@ #include #include +#include #include "vec.h" #ifndef DOT7U_STRIDE_BYTES_LEN @@ -132,3 +133,169 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { } return res; } + +// --- single precision floats + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the 
number of floating point elements +EXPORT float dotf32(const float *a, const float *b, size_t elementCount) { + float32x4_t sum0 = vdupq_n_f32(0.0f); + float32x4_t sum1 = vdupq_n_f32(0.0f); + float32x4_t sum2 = vdupq_n_f32(0.0f); + float32x4_t sum3 = vdupq_n_f32(0.0f); + float32x4_t sum4 = vdupq_n_f32(0.0f); + float32x4_t sum5 = vdupq_n_f32(0.0f); + float32x4_t sum6 = vdupq_n_f32(0.0f); + float32x4_t sum7 = vdupq_n_f32(0.0f); + + size_t i = 0; + // Each float32x4_t holds 4 floats, so unroll 8x = 32 floats per loop + size_t unrolled_limit = elementCount & ~31UL; + for (; i < unrolled_limit; i += 32) { + sum0 = vfmaq_f32(sum0, vld1q_f32(a + i), vld1q_f32(b + i)); + sum1 = vfmaq_f32(sum1, vld1q_f32(a + i + 4), vld1q_f32(b + i + 4)); + sum2 = vfmaq_f32(sum2, vld1q_f32(a + i + 8), vld1q_f32(b + i + 8)); + sum3 = vfmaq_f32(sum3, vld1q_f32(a + i + 12), vld1q_f32(b + i + 12)); + sum4 = vfmaq_f32(sum4, vld1q_f32(a + i + 16), vld1q_f32(b + i + 16)); + sum5 = vfmaq_f32(sum5, vld1q_f32(a + i + 20), vld1q_f32(b + i + 20)); + sum6 = vfmaq_f32(sum6, vld1q_f32(a + i + 24), vld1q_f32(b + i + 24)); + sum7 = vfmaq_f32(sum7, vld1q_f32(a + i + 28), vld1q_f32(b + i + 28)); + } + + float32x4_t total = vaddq_f32( + vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)), + vaddq_f32(vaddq_f32(sum4, sum5), vaddq_f32(sum6, sum7)) + ); + float result = vaddvq_f32(total); + + // Handle remaining elements + for (; i < elementCount; ++i) { + result += a[i] * b[i]; + } + + return result; +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +EXPORT float cosf32(const float *a, const float *b, size_t elementCount) { + float32x4_t sum0 = vdupq_n_f32(0.0f); + float32x4_t sum1 = vdupq_n_f32(0.0f); + float32x4_t sum2 = vdupq_n_f32(0.0f); + float32x4_t sum3 = vdupq_n_f32(0.0f); + + float32x4_t norm_a0 = vdupq_n_f32(0.0f); + float32x4_t norm_a1 = vdupq_n_f32(0.0f); + float32x4_t 
norm_a2 = vdupq_n_f32(0.0f); + float32x4_t norm_a3 = vdupq_n_f32(0.0f); + + float32x4_t norm_b0 = vdupq_n_f32(0.0f); + float32x4_t norm_b1 = vdupq_n_f32(0.0f); + float32x4_t norm_b2 = vdupq_n_f32(0.0f); + float32x4_t norm_b3 = vdupq_n_f32(0.0f); + + size_t i = 0; + // Each float32x4_t holds 4 floats, so unroll 4x = 16 floats per loop + size_t unrolled_limit = elementCount & ~15UL; + for (; i < unrolled_limit; i += 16) { + float32x4_t va0 = vld1q_f32(a + i); + float32x4_t vb0 = vld1q_f32(b + i); + float32x4_t va1 = vld1q_f32(a + i + 4); + float32x4_t vb1 = vld1q_f32(b + i + 4); + float32x4_t va2 = vld1q_f32(a + i + 8); + float32x4_t vb2 = vld1q_f32(b + i + 8); + float32x4_t va3 = vld1q_f32(a + i + 12); + float32x4_t vb3 = vld1q_f32(b + i + 12); + + // Dot products + sum0 = vfmaq_f32(sum0, va0, vb0); + sum1 = vfmaq_f32(sum1, va1, vb1); + sum2 = vfmaq_f32(sum2, va2, vb2); + sum3 = vfmaq_f32(sum3, va3, vb3); + + // Norms + norm_a0 = vfmaq_f32(norm_a0, va0, va0); + norm_a1 = vfmaq_f32(norm_a1, va1, va1); + norm_a2 = vfmaq_f32(norm_a2, va2, va2); + norm_a3 = vfmaq_f32(norm_a3, va3, va3); + + norm_b0 = vfmaq_f32(norm_b0, vb0, vb0); + norm_b1 = vfmaq_f32(norm_b1, vb1, vb1); + norm_b2 = vfmaq_f32(norm_b2, vb2, vb2); + norm_b3 = vfmaq_f32(norm_b3, vb3, vb3); + } + + // Combine accumulators + float32x4_t sums = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)); + float32x4_t norms_a = vaddq_f32(vaddq_f32(norm_a0, norm_a1), vaddq_f32(norm_a2, norm_a3)); + float32x4_t norms_b = vaddq_f32(vaddq_f32(norm_b0, norm_b1), vaddq_f32(norm_b2, norm_b3)); + + float dot = vaddvq_f32(sums); + float norm_a = vaddvq_f32(norms_a); + float norm_b = vaddvq_f32(norms_b); + + // Handle remaining tail elements + for (; i < elementCount; ++i) { + float va = a[i]; + float vb = b[i]; + dot += va * vb; + norm_a += va * va; + norm_b += vb * vb; + } + + float denom = sqrtf(norm_a) * sqrtf(norm_b); + if (denom == 0.0f) { + return 0.0f; + } + return dot / denom; +} + +EXPORT float sqrf32(const float 
*a, const float *b, size_t elementCount) { + float32x4_t sum0 = vdupq_n_f32(0.0f); + float32x4_t sum1 = vdupq_n_f32(0.0f); + float32x4_t sum2 = vdupq_n_f32(0.0f); + float32x4_t sum3 = vdupq_n_f32(0.0f); + float32x4_t sum4 = vdupq_n_f32(0.0f); + float32x4_t sum5 = vdupq_n_f32(0.0f); + float32x4_t sum6 = vdupq_n_f32(0.0f); + float32x4_t sum7 = vdupq_n_f32(0.0f); + + size_t i = 0; + // Each float32x4_t holds 4 floats, so unroll 8x = 32 floats per loop + size_t unrolled_limit = elementCount & ~31UL; + for (; i < unrolled_limit; i += 32) { + float32x4_t d0 = vsubq_f32(vld1q_f32(a + i), vld1q_f32(b + i)); + float32x4_t d1 = vsubq_f32(vld1q_f32(a + i + 4), vld1q_f32(b + i + 4)); + float32x4_t d2 = vsubq_f32(vld1q_f32(a + i + 8), vld1q_f32(b + i + 8)); + float32x4_t d3 = vsubq_f32(vld1q_f32(a + i + 12), vld1q_f32(b + i + 12)); + float32x4_t d4 = vsubq_f32(vld1q_f32(a + i + 16), vld1q_f32(b + i + 16)); + float32x4_t d5 = vsubq_f32(vld1q_f32(a + i + 20), vld1q_f32(b + i + 20)); + float32x4_t d6 = vsubq_f32(vld1q_f32(a + i + 24), vld1q_f32(b + i + 24)); + float32x4_t d7 = vsubq_f32(vld1q_f32(a + i + 28), vld1q_f32(b + i + 28)); + + sum0 = vmlaq_f32(sum0, d0, d0); + sum1 = vmlaq_f32(sum1, d1, d1); + sum2 = vmlaq_f32(sum2, d2, d2); + sum3 = vmlaq_f32(sum3, d3, d3); + sum4 = vmlaq_f32(sum4, d4, d4); + sum5 = vmlaq_f32(sum5, d5, d5); + sum6 = vmlaq_f32(sum6, d6, d6); + sum7 = vmlaq_f32(sum7, d7, d7); + } + + float32x4_t total = vaddq_f32( + vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)), + vaddq_f32(vaddq_f32(sum4, sum5), vaddq_f32(sum6, sum7)) + ); + float result = vaddvq_f32(total); + + // Handle remaining tail elements + for (; i < elementCount; ++i) { + float diff = a[i] - b[i]; + result += diff * diff; + } + + return result; +} diff --git a/libs/simdvec/native/src/vec/c/amd64/vec.c b/libs/simdvec/native/src/vec/c/amd64/vec.c index f63a7649b1390..c6b9154b60660 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec.c +++ b/libs/simdvec/native/src/vec/c/amd64/vec.c @@ 
-9,6 +9,7 @@ #include <stddef.h> #include <stdint.h> +#include <math.h> #include "vec.h" #include <immintrin.h> @@ -187,3 +188,161 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { } return res; } + +// --- single precision floats + +// Horizontally add 8 float32 elements in a __m256 register +static inline float hsum_f32_8(const __m256 v) { + // First, add the low and high 128-bit lanes + __m128 low = _mm256_castps256_ps128(v); // lower 128 bits + __m128 high = _mm256_extractf128_ps(v, 1); // upper 128 bits + __m128 sum128 = _mm_add_ps(low, high); // sum 8 floats → 4 floats + + // Then do horizontal sum within 128-bit lane + __m128 shuf = _mm_movehdup_ps(sum128); // duplicate odd-index elements + __m128 sums = _mm_add_ps(sum128, shuf); // add pairs + + shuf = _mm_movehl_ps(shuf, sums); // move high pair to low + sums = _mm_add_ss(sums, shuf); // add final two elements + + return _mm_cvtss_f32(sums); +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +EXPORT float cosf32(const float *a, const float *b, size_t elementCount) { + __m256 dot0 = _mm256_setzero_ps(); + __m256 dot1 = _mm256_setzero_ps(); + __m256 dot2 = _mm256_setzero_ps(); + __m256 dot3 = _mm256_setzero_ps(); + + __m256 norm_a0 = _mm256_setzero_ps(); + __m256 norm_a1 = _mm256_setzero_ps(); + __m256 norm_a2 = _mm256_setzero_ps(); + __m256 norm_a3 = _mm256_setzero_ps(); + + __m256 norm_b0 = _mm256_setzero_ps(); + __m256 norm_b1 = _mm256_setzero_ps(); + __m256 norm_b2 = _mm256_setzero_ps(); + __m256 norm_b3 = _mm256_setzero_ps(); + + size_t i = 0; + // Each __m256 holds 8 floats, so unroll 4x = 32 floats per loop + size_t unrolled_limit = elementCount & ~31UL; + for (; i < unrolled_limit; i += 32) { + __m256 a0 = _mm256_loadu_ps(a + i); + __m256 b0 = _mm256_loadu_ps(b + i); + __m256 a1 = _mm256_loadu_ps(a + i + 8); + __m256 b1 = _mm256_loadu_ps(b + i + 8); + __m256 a2 = _mm256_loadu_ps(a + i + 16); + __m256 b2 =
_mm256_loadu_ps(b + i + 16); + __m256 a3 = _mm256_loadu_ps(a + i + 24); + __m256 b3 = _mm256_loadu_ps(b + i + 24); + + dot0 = _mm256_fmadd_ps(a0, b0, dot0); + dot1 = _mm256_fmadd_ps(a1, b1, dot1); + dot2 = _mm256_fmadd_ps(a2, b2, dot2); + dot3 = _mm256_fmadd_ps(a3, b3, dot3); + + norm_a0 = _mm256_fmadd_ps(a0, a0, norm_a0); + norm_a1 = _mm256_fmadd_ps(a1, a1, norm_a1); + norm_a2 = _mm256_fmadd_ps(a2, a2, norm_a2); + norm_a3 = _mm256_fmadd_ps(a3, a3, norm_a3); + + norm_b0 = _mm256_fmadd_ps(b0, b0, norm_b0); + norm_b1 = _mm256_fmadd_ps(b1, b1, norm_b1); + norm_b2 = _mm256_fmadd_ps(b2, b2, norm_b2); + norm_b3 = _mm256_fmadd_ps(b3, b3, norm_b3); + } + + // combine and reduce vector accumulators + __m256 dot_total = _mm256_add_ps(_mm256_add_ps(dot0, dot1), _mm256_add_ps(dot2, dot3)); + __m256 norm_a_total = _mm256_add_ps(_mm256_add_ps(norm_a0, norm_a1), _mm256_add_ps(norm_a2, norm_a3)); + __m256 norm_b_total = _mm256_add_ps(_mm256_add_ps(norm_b0, norm_b1), _mm256_add_ps(norm_b2, norm_b3)); + + float dot_result = hsum_f32_8(dot_total); + float norm_a_result = hsum_f32_8(norm_a_total); + float norm_b_result = hsum_f32_8(norm_b_total); + + // Handle remaining tail with scalar loop + for (; i < elementCount; ++i) { + float ai = a[i]; + float bi = b[i]; + dot_result += ai * bi; + norm_a_result += ai * ai; + norm_b_result += bi * bi; + } + + float denom = sqrtf(norm_a_result) * sqrtf(norm_b_result); + if (denom == 0.0f) { + return 0.0f; + } + return dot_result / denom; +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +EXPORT float dotf32(const float *a, const float *b, size_t elementCount) { + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + size_t i = 0; + // Each __m256 holds 8 floats, so unroll 4x = 32 floats per loop + size_t unrolled_limit = 
elementCount & ~31UL; + for (; i < unrolled_limit; i += 32) { + acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i), acc0); + acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(a + i + 8), _mm256_loadu_ps(b + i + 8), acc1); + acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(a + i + 16), _mm256_loadu_ps(b + i + 16), acc2); + acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(a + i + 24), _mm256_loadu_ps(b + i + 24), acc3); + } + + // Combine all partial sums + __m256 total_sum = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3)); + float result = hsum_f32_8(total_sum); + + for (; i < elementCount; ++i) { + result += a[i] * b[i]; + } + + return result; +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) { + __m256 sum0 = _mm256_setzero_ps(); + __m256 sum1 = _mm256_setzero_ps(); + __m256 sum2 = _mm256_setzero_ps(); + __m256 sum3 = _mm256_setzero_ps(); + + size_t i = 0; + size_t unrolled_limit = elementCount & ~31UL; + // Each __m256 holds 8 floats, so unroll 4x = 32 floats per loop + for (; i < unrolled_limit; i += 32) { + __m256 d0 = _mm256_sub_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i)); + __m256 d1 = _mm256_sub_ps(_mm256_loadu_ps(a + i + 8), _mm256_loadu_ps(b + i + 8)); + __m256 d2 = _mm256_sub_ps(_mm256_loadu_ps(a + i + 16), _mm256_loadu_ps(b + i + 16)); + __m256 d3 = _mm256_sub_ps(_mm256_loadu_ps(a + i + 24), _mm256_loadu_ps(b + i + 24)); + + sum0 = _mm256_fmadd_ps(d0, d0, sum0); + sum1 = _mm256_fmadd_ps(d1, d1, sum1); + sum2 = _mm256_fmadd_ps(d2, d2, sum2); + sum3 = _mm256_fmadd_ps(d3, d3, sum3); + } + + // reduce all partial sums + __m256 total_sum = _mm256_add_ps(_mm256_add_ps(sum0, sum1), _mm256_add_ps(sum2, sum3)); + float result = hsum_f32_8(total_sum); + + for (; i < elementCount; ++i) { + float diff = a[i] - b[i]; + result += diff * diff; + } + + return 
result; +} diff --git a/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp b/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp index f851b2a13a9ea..dd062f8210c3c 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp +++ b/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp @@ -9,6 +9,7 @@ #include <stddef.h> #include <stdint.h> +#include <math.h> #include "vec.h" #ifdef _MSC_VER @@ -195,6 +196,151 @@ EXPORT int32_t sqr7u_2(int8_t* a, int8_t* b, size_t dims) { return res; } +// --- single precision floats + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +extern "C" +EXPORT float cosf32_2(const float *a, const float *b, size_t elementCount) { + __m512 dot0 = _mm512_setzero_ps(); + __m512 dot1 = _mm512_setzero_ps(); + __m512 dot2 = _mm512_setzero_ps(); + __m512 dot3 = _mm512_setzero_ps(); + + __m512 norm_a0 = _mm512_setzero_ps(); + __m512 norm_a1 = _mm512_setzero_ps(); + __m512 norm_a2 = _mm512_setzero_ps(); + __m512 norm_a3 = _mm512_setzero_ps(); + + __m512 norm_b0 = _mm512_setzero_ps(); + __m512 norm_b1 = _mm512_setzero_ps(); + __m512 norm_b2 = _mm512_setzero_ps(); + __m512 norm_b3 = _mm512_setzero_ps(); + + size_t i = 0; + // Each __m512 holds 16 floats, so unroll 4x = 64 floats per loop + size_t unrolled_limit = elementCount & ~63UL; + for (; i < unrolled_limit; i += 64) { + // Load and compute 4 blocks of 16 elements + __m512 a0 = _mm512_loadu_ps(a + i); + __m512 b0 = _mm512_loadu_ps(b + i); + __m512 a1 = _mm512_loadu_ps(a + i + 16); + __m512 b1 = _mm512_loadu_ps(b + i + 16); + __m512 a2 = _mm512_loadu_ps(a + i + 32); + __m512 b2 = _mm512_loadu_ps(b + i + 32); + __m512 a3 = _mm512_loadu_ps(a + i + 48); + __m512 b3 = _mm512_loadu_ps(b + i + 48); + + dot0 = _mm512_fmadd_ps(a0, b0, dot0); + dot1 = _mm512_fmadd_ps(a1, b1, dot1); + dot2 = _mm512_fmadd_ps(a2, b2, dot2); + dot3 = _mm512_fmadd_ps(a3, b3, dot3); + + norm_a0 = _mm512_fmadd_ps(a0, a0, norm_a0); + norm_a1 = _mm512_fmadd_ps(a1,
a1, norm_a1); + norm_a2 = _mm512_fmadd_ps(a2, a2, norm_a2); + norm_a3 = _mm512_fmadd_ps(a3, a3, norm_a3); + + norm_b0 = _mm512_fmadd_ps(b0, b0, norm_b0); + norm_b1 = _mm512_fmadd_ps(b1, b1, norm_b1); + norm_b2 = _mm512_fmadd_ps(b2, b2, norm_b2); + norm_b3 = _mm512_fmadd_ps(b3, b3, norm_b3); + } + + // combine and reduce vector accumulators + __m512 dot_total = _mm512_add_ps(_mm512_add_ps(dot0, dot1), _mm512_add_ps(dot2, dot3)); + __m512 norm_a_total = _mm512_add_ps(_mm512_add_ps(norm_a0, norm_a1), _mm512_add_ps(norm_a2, norm_a3)); + __m512 norm_b_total = _mm512_add_ps(_mm512_add_ps(norm_b0, norm_b1), _mm512_add_ps(norm_b2, norm_b3)); + + float dot_result = _mm512_reduce_add_ps(dot_total); + float norm_a_result = _mm512_reduce_add_ps(norm_a_total); + float norm_b_result = _mm512_reduce_add_ps(norm_b_total); + + // Handle remaining tail with scalar loop + for (; i < elementCount; ++i) { + float ai = a[i]; + float bi = b[i]; + dot_result += ai * bi; + norm_a_result += ai * ai; + norm_b_result += bi * bi; + } + + float denom = sqrtf(norm_a_result) * sqrtf(norm_b_result); + if (denom == 0.0f) { + return 0.0f; + } + return dot_result / denom; +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +extern "C" +EXPORT float dotf32_2(const float *a, const float *b, size_t elementCount) { + __m512 sum0 = _mm512_setzero_ps(); + __m512 sum1 = _mm512_setzero_ps(); + __m512 sum2 = _mm512_setzero_ps(); + __m512 sum3 = _mm512_setzero_ps(); + + size_t i = 0; + size_t unrolled_limit = elementCount & ~63UL; + // Each __m512 holds 16 floats, so unroll 4x = 64 floats per loop + for (; i < unrolled_limit; i += 64) { + sum0 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i), _mm512_loadu_ps(b + i), sum0); + sum1 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 16), _mm512_loadu_ps(b + i + 16), sum1); + sum2 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 32), _mm512_loadu_ps(b + i + 32), 
sum2); + sum3 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 48), _mm512_loadu_ps(b + i + 48), sum3); + } + + // reduce all partial sums + __m512 total_sum = _mm512_add_ps(_mm512_add_ps(sum0, sum1), _mm512_add_ps(sum2, sum3)); + float result = _mm512_reduce_add_ps(total_sum); + + for (; i < elementCount; ++i) { + result += a[i] * b[i]; + } + + return result; +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +extern "C" +EXPORT float sqrf32_2(const float *a, const float *b, size_t elementCount) { + __m512 sum0 = _mm512_setzero_ps(); + __m512 sum1 = _mm512_setzero_ps(); + __m512 sum2 = _mm512_setzero_ps(); + __m512 sum3 = _mm512_setzero_ps(); + + size_t i = 0; + size_t unrolled_limit = elementCount & ~63UL; + // Each __m512 holds 16 floats, so unroll 4x = 64 floats per loop + for (; i < unrolled_limit; i += 64) { + __m512 d0 = _mm512_sub_ps(_mm512_loadu_ps(a + i), _mm512_loadu_ps(b + i)); + __m512 d1 = _mm512_sub_ps(_mm512_loadu_ps(a + i + 16), _mm512_loadu_ps(b + i + 16)); + __m512 d2 = _mm512_sub_ps(_mm512_loadu_ps(a + i + 32), _mm512_loadu_ps(b + i + 32)); + __m512 d3 = _mm512_sub_ps(_mm512_loadu_ps(a + i + 48), _mm512_loadu_ps(b + i + 48)); + + sum0 = _mm512_fmadd_ps(d0, d0, sum0); + sum1 = _mm512_fmadd_ps(d1, d1, sum1); + sum2 = _mm512_fmadd_ps(d2, d2, sum2); + sum3 = _mm512_fmadd_ps(d3, d3, sum3); + } + + // reduce all partial sums + __m512 total_sum = _mm512_add_ps(_mm512_add_ps(sum0, sum1), _mm512_add_ps(sum2, sum3)); + float result = _mm512_reduce_add_ps(total_sum); + + for (; i < elementCount; ++i) { + float diff = a[i] - b[i]; + result += diff * diff; + } + + return result; +} + #ifdef __clang__ #pragma clang attribute pop #elif __GNUC__ diff --git a/libs/simdvec/native/src/vec/headers/vec.h b/libs/simdvec/native/src/vec/headers/vec.h index e27e9a3a68083..733aea3165659 100644 --- a/libs/simdvec/native/src/vec/headers/vec.h +++ 
b/libs/simdvec/native/src/vec/headers/vec.h @@ -20,3 +20,10 @@ EXPORT int vec_caps(); EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims); EXPORT int32_t sqr7u(int8_t *a, int8_t *b, size_t length); + +EXPORT float cosf32(const float *a, const float *b, size_t elementCount); + +EXPORT float dotf32(const float *a, const float *b, size_t elementCount); + +EXPORT float sqrf32(const float *a, const float *b, size_t elementCount); + From 4ed416dc6f8bc834916ed19319921d0e2f95afe8 Mon Sep 17 00:00:00 2001 From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com> Date: Fri, 4 Jul 2025 15:03:24 +0100 Subject: [PATCH 2/2] Update docs/changelog/130635.yaml --- docs/changelog/130635.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/130635.yaml diff --git a/docs/changelog/130635.yaml b/docs/changelog/130635.yaml new file mode 100644 index 0000000000000..ab1bd684641f6 --- /dev/null +++ b/docs/changelog/130635.yaml @@ -0,0 +1,5 @@ +pr: 130635 +summary: "Add low-level optimized Neon, AVX2, and AVX 512 float32 vector operations" +area: Vector Search +type: enhancement +issues: []