diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java index c98532d8dd8f5..710b303f67fca 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java @@ -170,6 +170,33 @@ public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException { } } + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void scoreThreeUpperBitsFromMemorySegmentOnlyVector(Blackhole bh) throws IOException { + for (int j = 0; j < numQueries; j++) { + in.seek(0); + for (int i = 0; i < numVectors; i++) { + float qDist = scorer.quantizeScoreThreeUpperBit(binaryQueries[j]); + in.readFloats(corrections, 0, corrections.length); + int addition = Short.toUnsignedInt(in.readShort()); + float score = scorer.score( + result.lowerInterval(), + result.upperInterval(), + result.quantizedComponentSum(), + result.additionalCorrection(), + VectorSimilarityFunction.EUCLIDEAN, + centroidDp, + corrections[0], + corrections[1], + addition, + corrections[2], + qDist + ); + bh.consume(score); + } + } + } + @Benchmark @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOException { @@ -199,6 +226,35 @@ public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOExceptio } } + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void scoreThreeUpperBitsFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOException { + for (int j = 0; j < numQueries; j++) { + in.seek(0); + for (int i = 0; i < numVectors; i += 16) { + scorer.quantizeScoreThreeUpperBitBulk(binaryQueries[j], ES91OSQVectorsScorer.BULK_SIZE, scratchScores); + for (int k = 0; k < ES91OSQVectorsScorer.BULK_SIZE; k++) { + in.readFloats(corrections, 0, corrections.length); + int addition = Short.toUnsignedInt(in.readShort()); + float score = scorer.score( + result.lowerInterval(), + result.upperInterval(), + result.quantizedComponentSum(), + result.additionalCorrection(), + VectorSimilarityFunction.EUCLIDEAN, + centroidDp, + corrections[0], + corrections[1], + addition, + corrections[2], + scratchScores[k] + ); + bh.consume(score); + } + } + } + } + @Benchmark @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) public void scoreFromMemorySegmentAllBulk(Blackhole bh) throws IOException { diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java index 58df8bb03e0cb..f45fad46b57ec 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java @@ -80,6 +80,54 @@ public long quantizeScore(byte[] q) throws IOException { return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); } + public long quantizeScoreThreeUpperBit(byte[] q) throws IOException { + assert q.length == length * 4; + final int size = length; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int r = 0; + for (final int upperBound = size & -Long.BYTES; r < upperBound; r += Long.BYTES) { + final long value = in.readLong(); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + size) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 2 * size) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 3 * size) & value); + } + for (final int upperBound = size & -Integer.BYTES; r < upperBound; r += Integer.BYTES) { + final int value = in.readInt(); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + size) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 2 * size) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 3 * size) & value); + } + for (; r < size; r++) { + final byte value = in.readByte(); + subRet1 += Integer.bitCount((q[r + size] & value) & 0xFF); + subRet2 += Integer.bitCount((q[r + 2 * size] & value) & 0xFF); + subRet3 += Integer.bitCount((q[r + 3 * size] & value) & 0xFF); + } + return (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + public long quantizeScoreLowerBit(byte[] q) throws IOException { + assert q.length == length * 4; + final int size = length; + long subRet = 0; + int r = 0; + for (final int upperBound = size & -Long.BYTES; r < upperBound; r += Long.BYTES) { + final long value = in.readLong(); + subRet += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r) & value); + } + for (final int upperBound = size & -Integer.BYTES; r < upperBound; r += Integer.BYTES) { + final int value = in.readInt(); + subRet += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r) & value); + } + for (; r < size; r++) { + final byte value = in.readByte(); + subRet += Integer.bitCount((q[r] & value) & 0xFF); + } + return subRet; + } + /** * compute the quantize distance between the provided quantized query and the quantized vectors * that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is @@ -91,6 +139,12 @@ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOExce } } + public void quantizeScoreThreeUpperBitBulk(byte[] q, int count, float[] scores) throws IOException { + for (int i = 0; i < count; i++) { + scores[i] = quantizeScoreThreeUpperBit(q); + } + } + /** * Computes the score by applying the necessary corrections to the provided quantized distance. */ diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index 4be6ede34530a..9eda517ac87a2 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -170,6 +170,178 @@ private long quantizeScore128(byte[] q) throws IOException { return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); } + @Override + public long quantizeScoreThreeUpperBit(byte[] q) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + return quantizeScoreThreeUpperBit256(q); + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + return quantizeScoreThreeUpperBit128(q); + } + } + return super.quantizeScoreThreeUpperBit(q); + } + + private long quantizeScoreThreeUpperBit256(byte[] q) throws IOException { + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(length); + var sum1 = LongVector.zero(LONG_SPECIES_256); + var sum2 = LongVector.zero(LONG_SPECIES_256); + var sum3 = LongVector.zero(LONG_SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum1 = LongVector.zero(LONG_SPECIES_128); + var sum2 = LongVector.zero(LONG_SPECIES_128); + var sum3 = LongVector.zero(LONG_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // tail as bytes + in.seek(offset); + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + return (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + private long quantizeScoreThreeUpperBit128(byte[] q) throws IOException { + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + + var sum1 = IntVector.zero(INT_SPECIES_128); + var sum2 = IntVector.zero(INT_SPECIES_128); + var sum3 = IntVector.zero(INT_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // tail as bytes + in.seek(offset); + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF); + subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF); + subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF); + } + return (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + @Override + public long quantizeScoreLowerBit(byte[] q) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + return quantizeScoreLowerBit256(q); + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + return quantizeScoreLowerBit128(q); + } + } + return super.quantizeScore(q); + } + + private long quantizeScoreLowerBit256(byte[] q) throws IOException { + long subRet0 = 0; + int i = 0; + long offset = in.getFilePointer(); + if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(length); + var sum0 = LongVector.zero(LONG_SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + } + + if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum0 = LongVector.zero(LONG_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + } + // tail as bytes + in.seek(offset); + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + } + return subRet0; + } + + private long quantizeScoreLowerBit128(byte[] q) throws IOException { + long subRet0 = 0; + int i = 0; + long offset = in.getFilePointer(); + + var sum0 = IntVector.zero(INT_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + // tail as bytes + in.seek(offset); + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); + } + return subRet0; + } + @Override public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { assert q.length == length * 4; @@ -294,6 +466,113 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO } } + public void quantizeScoreThreeUpperBitBulk(byte[] q, int count, float[] scores) throws IOException { + assert q.length == length * 4; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + quantizeScoreThreeUpperBit256Bulk(q, count, scores); + return; + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + quantizeScoreThreeUpperBit128Bulk(q, count, scores); + return; + } + } + super.quantizeScoreBulk(q, count, scores); + } + + private void quantizeScoreThreeUpperBit128Bulk(byte[] q, int count, float[] scores) throws IOException { + for (int iter = 0; iter < count; iter++) { + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + + var sum1 = IntVector.zero(INT_SPECIES_128); + var sum2 = IntVector.zero(INT_SPECIES_128); + var sum3 = IntVector.zero(INT_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts(); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // tail as bytes + in.seek(offset); + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF); + subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF); + subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF); + } + scores[iter] = (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + } + + private void quantizeScoreThreeUpperBit256Bulk(byte[] q, int count, float[] scores) throws IOException { + for (int iter = 0; iter < count; iter++) { + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(length); + var sum1 = LongVector.zero(LONG_SPECIES_256); + var sum2 = LongVector.zero(LONG_SPECIES_256); + var sum3 = LongVector.zero(LONG_SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum1 = LongVector.zero(LONG_SPECIES_128); + var sum2 = LongVector.zero(LONG_SPECIES_128); + var sum3 = LongVector.zero(LONG_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(length); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // tail as bytes + in.seek(offset); + for (; i < length; i++) { + int dValue = in.readByte() & 0xFF; + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); + } + scores[iter] = (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + } + @Override public void scoreBulk( byte[] q, diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java index b6a95c3c66bae..34dc238d41e2c 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java @@ -44,8 +44,12 @@ public void testQuantizeScore() throws Exception { final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions); final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions); for (int i = 0; i < numVectors; i++) { + long filePointer = in.getFilePointer(); assertEquals(defaultScorer.quantizeScore(query), panamaScorer.quantizeScore(query)); assertEquals(in.getFilePointer(), slice.getFilePointer()); + in.seek(filePointer); + slice.seek(filePointer); + assertEquals(defaultScorer.quantizeScoreThreeUpperBit(query), panamaScorer.quantizeScoreThreeUpperBit(query)); } assertEquals((long) length * numVectors, slice.getFilePointer()); } @@ -124,26 +128,7 @@ public void testScore() throws Exception { centroidDp, scores2 ); - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - if (scores1[j] == scores2[j]) { - continue; - } - if (scores1[j] > (maxDims * Byte.MAX_VALUE)) { - float diff = Math.abs(scores1[j] - scores2[j]); - assertThat( - "defaultScores: " + scores1[j] + " bulkScores: " + scores2[j], - diff / scores1[j], - lessThan(1e-5f) - ); - assertThat( - "defaultScores: " + scores1[j] + " bulkScores: " + scores2[j], - diff / scores2[j], - lessThan(1e-5f) - ); - } else { - assertEquals(scores1[j], scores2[j], 1e-2f); - } - } + assertScores(scores1, scores2, maxDims); assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer()); assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer()); } @@ -151,4 +136,19 @@ public void testScore() throws Exception { } } } + + private void assertScores(float[] scores1, float[] scores2, int maxDims) { + for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { + if (scores1[j] == scores2[j]) { + continue; + } + if (scores1[j] > (maxDims * Byte.MAX_VALUE)) { + float diff = Math.abs(scores1[j] - scores2[j]); + assertThat("defaultScores: " + scores1[j] + " bulkScores: " + scores2[j], diff / scores1[j], lessThan(1e-5f)); + assertThat("defaultScores: " + scores1[j] + " bulkScores: " + scores2[j], diff / scores2[j], lessThan(1e-5f)); + } else { + assertEquals(scores1[j], scores2[j], 1e-2f); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index 304cc57284227..6f17e7fa72447 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -164,6 +164,9 @@ public Map getOffHeapByteSize(FieldInfo fieldInfo) { } private static class MemorySegmentPostingsVisitor implements PostingVisitor { + // At the beginning, most documents will be competitive so approximating to three bits is not effective. + // this value indicates how many times we need to multiply KnnCollect.k() to start applying the three bits approximation. + static long APPROXIMATION_OFFSET = 30; final long quantizedByteLength; final IndexInput indexInput; final float[] target; @@ -192,6 +195,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor { final OptimizedScalarQuantizer quantizer; final float[] correctiveValues = new float[3]; final long quantizedVectorByteSize; + int lowerBitCount; MemorySegmentPostingsVisitor( float[] target, @@ -230,16 +234,57 @@ public int resetPostingsScorer(long offset) throws IOException { return vectors; } - void scoreIndividually(int offset) throws IOException { - // score individually, first the quantized byte chunk + void scoreIndividuallyPartialScore(int offset, float minScore) throws IOException { + // read in all corrections + indexInput.seek(slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize)); + indexInput.readFloats(correctionsLower, 0, BULK_SIZE); + indexInput.readFloats(correctionsUpper, 0, BULK_SIZE); + for (int j = 0; j < BULK_SIZE; j++) { + correctionsSum[j] = Short.toUnsignedInt(indexInput.readShort()); + } + indexInput.readFloats(correctionsAdd, 0, BULK_SIZE); + // Now apply corrections for (int j = 0; j < BULK_SIZE; j++) { - int doc = docIdsScratch[j + offset]; + int doc = docIdsScratch[offset + j]; if (doc != -1) { - indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize)); - float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); - scores[j] = qcDist; + float maxScore = osqVectorsScorer.score( + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctionsLower[j], + correctionsUpper[j], + correctionsSum[j], + correctionsAdd[j], + scores[j] + Math.min(correctionsSum[j], lowerBitCount) + ); + if (maxScore > minScore) { + indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize)); + scores[j] += osqVectorsScorer.quantizeScoreLowerBit(quantizedQueryScratch); + scores[j] = osqVectorsScorer.score( + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctionsLower[j], + correctionsUpper[j], + correctionsSum[j], + correctionsAdd[j], + scores[j] + ); + assert scores[j] >= minScore; + } else { + docIdsScratch[offset + j] = -1; + } } } + } + + void scoreIndividuallyFullScore(int offset) throws IOException { // read in all corrections indexInput.seek(slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize)); indexInput.readFloats(correctionsLower, 0, BULK_SIZE); @@ -290,18 +335,42 @@ public int visit(KnnCollector knnCollector) throws IOException { quantizeQueryIfNecessary(); indexInput.seek(slicePos + i * quantizedByteLength); if (docsToScore < BULK_SIZE / 2) { - scoreIndividually(i); + if (knnCollector.visitedCount() < APPROXIMATION_OFFSET * knnCollector.k()) { + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[j + i]; + if (doc != -1) { + indexInput.seek(slicePos + (i * quantizedByteLength) + (j * quantizedVectorByteSize)); + scores[j] = osqVectorsScorer.quantizeScore(quantizedQueryScratch); + } + } + scoreIndividuallyFullScore(i); + } else { + // score individually, first the quantized byte chunk + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[j + i]; + if (doc != -1) { + indexInput.seek(slicePos + (i * quantizedByteLength) + (j * quantizedVectorByteSize)); + scores[j] = osqVectorsScorer.quantizeScoreThreeUpperBit(quantizedQueryScratch); + } + } + scoreIndividuallyPartialScore(i, knnCollector.minCompetitiveSimilarity()); + } } else { - osqVectorsScorer.scoreBulk( - quantizedQueryScratch, - queryCorrections.lowerInterval(), - queryCorrections.upperInterval(), - queryCorrections.quantizedComponentSum(), - queryCorrections.additionalCorrection(), - fieldInfo.getVectorSimilarityFunction(), - centroidDp, - scores - ); + if (knnCollector.visitedCount() < APPROXIMATION_OFFSET * knnCollector.k()) { + osqVectorsScorer.scoreBulk( + quantizedQueryScratch, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + scores + ); + } else { + osqVectorsScorer.quantizeScoreThreeUpperBitBulk(quantizedQueryScratch, BULK_SIZE, scores); + scoreIndividuallyPartialScore(i, knnCollector.minCompetitiveSimilarity()); + } } for (int j = 0; j < BULK_SIZE; j++) { int doc = docIdsScratch[i + j]; @@ -317,24 +386,61 @@ public int visit(KnnCollector knnCollector) throws IOException { if (needsScoring.test(doc)) { quantizeQueryIfNecessary(); indexInput.seek(slicePos + i * quantizedByteLength); - float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); - indexInput.readFloats(correctiveValues, 0, 3); - final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort()); - float score = osqVectorsScorer.score( - queryCorrections.lowerInterval(), - queryCorrections.upperInterval(), - queryCorrections.quantizedComponentSum(), - queryCorrections.additionalCorrection(), - fieldInfo.getVectorSimilarityFunction(), - centroidDp, - correctiveValues[0], - correctiveValues[1], - quantizedComponentSum, - correctiveValues[2], - qcDist - ); + if (knnCollector.visitedCount() < APPROXIMATION_OFFSET * knnCollector.k()) { + float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); + indexInput.readFloats(correctiveValues, 0, 3); + final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort()); + float score = osqVectorsScorer.score( + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctiveValues[0], + correctiveValues[1], + quantizedComponentSum, + correctiveValues[2], + qcDist + ); + knnCollector.collect(doc, score); + } else { + float qcDist = osqVectorsScorer.quantizeScoreThreeUpperBit(quantizedQueryScratch); + indexInput.readFloats(correctiveValues, 0, 3); + final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort()); + float maxScore = osqVectorsScorer.score( + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctiveValues[0], + correctiveValues[1], + quantizedComponentSum, + correctiveValues[2], + qcDist + Math.min(quantizedComponentSum, lowerBitCount) + ); + if (maxScore > knnCollector.minCompetitiveSimilarity()) { + indexInput.seek(slicePos + i * quantizedByteLength); + qcDist += osqVectorsScorer.quantizeScoreLowerBit(quantizedQueryScratch); + float score = osqVectorsScorer.score( + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctiveValues[0], + correctiveValues[1], + quantizedComponentSum, + correctiveValues[2], + qcDist + ); + knnCollector.collect(doc, score); + } + } scoredDocs++; - knnCollector.collect(doc, score); } } if (scoredDocs > 0) { @@ -350,10 +456,35 @@ private void quantizeQueryIfNecessary() { VectorUtil.l2normalize(scratch); } queryCorrections = quantizer.scalarQuantize(scratch, quantizationScratch, (byte) 4, centroid); - transposeHalfByte(quantizationScratch, quantizedQueryScratch); + this.lowerBitCount = transposeHalfByte(quantizationScratch, quantizedQueryScratch); quantized = true; } } + + private static int transposeHalfByte(int[] q, byte[] quantQueryByte) { + int lowerBitCount = 0; + for (int i = 0; i < q.length;) { + assert q[i] >= 0 && q[i] <= 15; + int lowerByte = 0; + int lowerMiddleByte = 0; + int upperMiddleByte = 0; + int upperByte = 0; + for (int j = 7; j >= 0 && i < q.length; j--) { + lowerByte |= (q[i] & 1) << j; + lowerMiddleByte |= ((q[i] >> 1) & 1) << j; + upperMiddleByte |= ((q[i] >> 2) & 1) << j; + upperByte |= ((q[i] >> 3) & 1) << j; + i++; + } + int index = ((i + 7) / 8) - 1; + lowerBitCount += Integer.bitCount(lowerByte & 0xFF); + quantQueryByte[index] = (byte) lowerByte; + quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte; + quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte; + quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte; + } + return lowerBitCount; + } } }