diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java
index 5cca4e700dc6..2b8c2e5da1e8 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java
@@ -64,16 +64,17 @@ import org.apache.lucene.util.quantization.ScalarQuantizer;
 
 /**
- * Reads vectors from the index segments along with index data structures supporting KNN search.
+ * Reads vectors from the index segments along with index data structures
+ * supporting KNN search.
  *
  * @lucene.experimental
  */
 public final class Lucene99HnswVectorsReader extends KnnVectorsReader
     implements QuantizedVectorsReader, HnswGraphProvider {
 
-  private static final long SHALLOW_SIZE =
-      RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsFormat.class);
-  // Number of ordinals to score at a time when scoring exhaustively rather than using HNSW.
+  private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsFormat.class);
+  // Number of ordinals to score at a time when scoring exhaustively rather than
+  // using HNSW.
   private static final int EXHAUSTIVE_BULK_SCORE_ORDS = 64;
 
   private final FlatVectorsReader flatVectorsReader;
@@ -87,21 +88,19 @@ public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatV
     this.fields = new IntObjectHashMap<>();
     this.flatVectorsReader = flatVectorsReader;
     this.fieldInfos = state.fieldInfos;
-    String metaFileName =
-        IndexFileNames.segmentFileName(
-            state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
+    String metaFileName = IndexFileNames.segmentFileName(
+        state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
     int versionMeta = -1;
     try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
       Throwable priorE = null;
       try {
-        versionMeta =
-            CodecUtil.checkIndexHeader(
-                meta,
-                Lucene99HnswVectorsFormat.META_CODEC_NAME,
-                Lucene99HnswVectorsFormat.VERSION_START,
-                Lucene99HnswVectorsFormat.VERSION_CURRENT,
-                state.segmentInfo.getId(),
-                state.segmentSuffix);
+        versionMeta = CodecUtil.checkIndexHeader(
+            meta,
+            Lucene99HnswVectorsFormat.META_CODEC_NAME,
+            Lucene99HnswVectorsFormat.VERSION_START,
+            Lucene99HnswVectorsFormat.VERSION_CURRENT,
+            state.segmentInfo.getId(),
+            state.segmentSuffix);
         readFields(meta);
       } catch (Throwable exception) {
         priorE = exception;
@@ -109,19 +108,18 @@ public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatV
         CodecUtil.checkFooter(meta, priorE);
       }
       this.version = versionMeta;
-      this.vectorIndex =
-          openDataInput(
-              state,
-              versionMeta,
-              Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
-              Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
-              state.context.withHints(
-                  // Even though this input is referred to as `indexIn`, it doesn't qualify as
-                  // FileTypeHint#INDEX since it's a large file
-                  FileTypeHint.DATA,
-                  FileDataHint.KNN_VECTORS,
-                  DataAccessHint.RANDOM,
-                  PreloadHint.INSTANCE));
+      this.vectorIndex = openDataInput(
+          state,
+          versionMeta,
+          Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
+          Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
+          state.context.withHints(
+              // Even though this input is referred to as `indexIn`, it doesn't qualify as
+              // FileTypeHint#INDEX since it's a large file
+              FileTypeHint.DATA,
+              FileDataHint.KNN_VECTORS,
+              DataAccessHint.RANDOM,
+              PreloadHint.INSTANCE));
     } catch (Throwable t) {
       IOUtils.closeWhileSuppressingExceptions(t, this);
       throw t;
@@ -154,18 +152,16 @@ private static IndexInput openDataInput(
       String codecName,
       IOContext context)
       throws IOException {
-    String fileName =
-        IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
+    String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
     IndexInput in = state.directory.openInput(fileName, context);
     try {
-      int versionVectorData =
-          CodecUtil.checkIndexHeader(
-              in,
-              codecName,
-              Lucene99HnswVectorsFormat.VERSION_START,
-              Lucene99HnswVectorsFormat.VERSION_CURRENT,
-              state.segmentInfo.getId(),
-              state.segmentSuffix);
+      int versionVectorData = CodecUtil.checkIndexHeader(
+          in,
+          codecName,
+          Lucene99HnswVectorsFormat.VERSION_START,
+          Lucene99HnswVectorsFormat.VERSION_CURRENT,
+          state.segmentInfo.getId(),
+          state.segmentSuffix);
       if (versionMeta != versionVectorData) {
         throw new CorruptIndexException(
             "Format versions mismatch: meta="
@@ -213,12 +209,11 @@ private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
   // to avoid an undesirable dependency on the declaration and order of values
   // in VectorSimilarityFunction. The list values and order must be identical
   // to that of {@link o.a.l.c.l.Lucene94FieldInfosFormat#SIMILARITY_FUNCTIONS}.
-  public static final List<VectorSimilarityFunction> SIMILARITY_FUNCTIONS =
-      List.of(
-          VectorSimilarityFunction.EUCLIDEAN,
-          VectorSimilarityFunction.DOT_PRODUCT,
-          VectorSimilarityFunction.COSINE,
-          VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
+  public static final List<VectorSimilarityFunction> SIMILARITY_FUNCTIONS = List.of(
+      VectorSimilarityFunction.EUCLIDEAN,
+      VectorSimilarityFunction.DOT_PRODUCT,
+      VectorSimilarityFunction.COSINE,
+      VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
 
   public static VectorSimilarityFunction readSimilarityFunction(DataInput input)
       throws IOException {
@@ -330,8 +325,7 @@ private void search(
       return;
     }
     final RandomVectorScorer scorer = scorerSupplier.get();
-    final KnnCollector collector =
-        new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
+    final KnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
     final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
     HnswGraph graph = getGraph(fieldEntry);
     boolean doHnsw = knnCollector.k() < scorer.maxOrd();
@@ -340,7 +334,8 @@ private void search(
       // The approximate number of vectors that would be visited if we did not filter
       int unfilteredVisit = HnswGraphSearcher.expectedVisitedNodes(knnCollector.k(), graph.size());
       if (acceptDocs instanceof BitSet bitSet) {
-        // Use approximate cardinality as this is good enough, but ensure we don't exceed the graph
+        // Use approximate cardinality as this is good enough, but ensure we don't
+        // exceed the graph
         // size as that is illogical
         filteredDocCount = Math.min(bitSet.approximateCardinality(), graph.size());
         if (unfilteredVisit >= filteredDocCount) {
@@ -351,7 +346,8 @@ private void search(
       HnswGraphSearcher.search(
           scorer, collector, getGraph(fieldEntry), acceptedOrds, filteredDocCount);
     } else {
-      // if k is larger than the number of vectors we expect to visit in an HNSW search,
+      // if k is larger than the number of vectors we expect to visit in an HNSW
+      // search,
       // we can just iterate over all vectors and collect them.
       int[] ords = new int[EXHAUSTIVE_BULK_SCORE_ORDS];
       float[] scores = new float[EXHAUSTIVE_BULK_SCORE_ORDS];
@@ -440,7 +436,8 @@ private record FieldEntry(
       int dimension,
       int size,
       int[][] nodesByLevel,
-      // for each level the start offsets in vectorIndex file from where to read neighbours
+      // for each level the start offsets in vectorIndex file from where to read
+      // neighbours
       DirectMonotonicReader.Meta offsetsMeta,
       long offsetsOffset,
       int offsetsBlockShift,
@@ -523,16 +520,13 @@ private final class OffHeapHnswGraph extends HnswGraph {
     private final int[] currentNeighborsBuffer;
 
     OffHeapHnswGraph(FieldEntry entry, IndexInput vectorIndex) throws IOException {
-      this.dataIn =
-          vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
+      this.dataIn = vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
       this.nodesByLevel = entry.nodesByLevel;
       this.numLevels = entry.numLevels;
       this.entryNode = numLevels > 1 ? nodesByLevel[numLevels - 1][0] : 0;
       this.size = entry.size();
-      final RandomAccessInput addressesData =
-          vectorIndex.randomAccessSlice(entry.offsetsOffset, entry.offsetsLength);
-      this.graphLevelNodeOffsets =
-          DirectMonotonicReader.getInstance(entry.offsetsMeta, addressesData);
+      final RandomAccessInput addressesData = vectorIndex.randomAccessSlice(entry.offsetsOffset, entry.offsetsLength);
+      this.graphLevelNodeOffsets = DirectMonotonicReader.getInstance(entry.offsetsMeta, addressesData);
       this.currentNeighborsBuffer = new int[entry.M * 2];
       this.maxConn = entry.M;
       graphLevelNodeIndexOffsets = new long[numLevels];
@@ -546,10 +540,9 @@ private final class OffHeapHnswGraph extends HnswGraph {
 
     @Override
     public void seek(int level, int targetOrd) throws IOException {
-      int targetIndex =
-          level == 0
-              ? targetOrd
-              : Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
+      int targetIndex = level == 0
+          ? targetOrd
+          : Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
       assert targetIndex >= 0
           : "seek level=" + level + " target=" + targetOrd + " not found: " + targetIndex;
       // unsafe; no bounds checking
@@ -557,15 +550,18 @@ public void seek(int level, int targetOrd) throws IOException {
       arcCount = dataIn.readVInt();
       assert arcCount <= currentNeighborsBuffer.length : "too many neighbors: " + arcCount;
       if (arcCount > 0) {
+        int sum = 0;
         if (version >= VERSION_GROUPVARINT) {
           GroupVIntUtil.readGroupVInts(dataIn, currentNeighborsBuffer, arcCount);
-          for (int i = 1; i < arcCount; i++) {
-            currentNeighborsBuffer[i] = currentNeighborsBuffer[i - 1] + currentNeighborsBuffer[i];
+          // Faster prefix sum computation (see #14979)
+          for (int i = 0; i < arcCount; i++) {
+            sum += currentNeighborsBuffer[i];
+            currentNeighborsBuffer[i] = sum;
           }
         } else {
-          currentNeighborsBuffer[0] = dataIn.readVInt();
-          for (int i = 1; i < arcCount; i++) {
-            currentNeighborsBuffer[i] = currentNeighborsBuffer[i - 1] + dataIn.readVInt();
+          for (int i = 0; i < arcCount; i++) {
+            sum += dataIn.readVInt();
+            currentNeighborsBuffer[i] = sum;
           }
         }
       }
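
Note on the seek() hunk at -557: neighbor ordinals are stored delta-encoded, and both loops restore the absolute values by prefix sum. The new version keeps the running total in a local `sum` instead of re-reading `currentNeighborsBuffer[i - 1]`, replacing a loop-carried array load with a register dependency (see #14979). A minimal, self-contained sketch (hypothetical class name, not Lucene code) showing the removed and added loops compute the same thing:

// PrefixSumEquivalence.java -- hypothetical demo, not part of the patch.
import java.util.Arrays;

public class PrefixSumEquivalence {
  public static void main(String[] args) {
    int[] deltas = {3, 2, 5, 1, 4}; // delta-encoded neighbor ordinals

    // Removed loop: each iteration re-reads the previous array slot.
    int[] before = deltas.clone();
    for (int i = 1; i < before.length; i++) {
      before[i] = before[i - 1] + before[i];
    }

    // Added loop: the running total lives in a local, so each iteration
    // depends on a register rather than on the preceding array store.
    int[] after = deltas.clone();
    int sum = 0;
    for (int i = 0; i < after.length; i++) {
      sum += after[i];
      after[i] = sum;
    }

    System.out.println(Arrays.toString(before)); // [3, 5, 10, 11, 15]
    System.out.println(Arrays.equals(before, after)); // true
  }
}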
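
The -340 and -351 hunks only rewrap comments, but the context lines they expose show the routing decision around them: `doHnsw` starts as `knnCollector.k() < scorer.maxOrd()`, and when the filter is a BitSet the reader compares the expected HNSW visit count against the filtered candidate count. A hedged sketch of that decision, with hypothetical names (`preferExhaustive` is not a Lucene API; `expectedVisited` stands in for the result of `HnswGraphSearcher.expectedVisitedNodes`):

// SearchRouting.java -- hypothetical simplification of the search() context, not Lucene code.
public class SearchRouting {
  // Mirrors `doHnsw = knnCollector.k() < scorer.maxOrd()` and the later
  // `unfilteredVisit >= filteredDocCount` check from the patch context.
  static boolean preferExhaustive(int k, int maxOrd, int expectedVisited, int filteredDocCount) {
    if (k >= maxOrd) {
      return true; // asking for at least as many results as there are vectors
    }
    // If the graph walk would touch no fewer nodes than pass the filter,
    // bulk-scoring the filtered set directly is the cheaper path.
    return expectedVisited >= filteredDocCount;
  }

  public static void main(String[] args) {
    // e.g. k=10 over 1M vectors, ~20k expected visits, only 5k docs pass the filter
    System.out.println(preferExhaustive(10, 1_000_000, 20_000, 5_000)); // true
    System.out.println(preferExhaustive(10, 1_000_000, 20_000, 500_000)); // false
  }
}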