Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,17 @@
import org.apache.lucene.util.quantization.ScalarQuantizer;

/**
* Reads vectors from the index segments along with index data structures supporting KNN search.
* Reads vectors from the index segments along with index data structures
* supporting KNN search.
*
* @lucene.experimental
*/
public final class Lucene99HnswVectorsReader extends KnnVectorsReader
implements QuantizedVectorsReader, HnswGraphProvider {

private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsFormat.class);
// Number of ordinals to score at a time when scoring exhaustively rather than using HNSW.
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsFormat.class);
// Number of ordinals to score at a time when scoring exhaustively rather than
// using HNSW.
private static final int EXHAUSTIVE_BULK_SCORE_ORDS = 64;

private final FlatVectorsReader flatVectorsReader;
Expand All @@ -87,41 +88,38 @@ public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatV
this.fields = new IntObjectHashMap<>();
this.flatVectorsReader = flatVectorsReader;
this.fieldInfos = state.fieldInfos;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
String metaFileName = IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
int versionMeta = -1;
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
Throwable priorE = null;
try {
versionMeta =
CodecUtil.checkIndexHeader(
meta,
Lucene99HnswVectorsFormat.META_CODEC_NAME,
Lucene99HnswVectorsFormat.VERSION_START,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
versionMeta = CodecUtil.checkIndexHeader(
meta,
Lucene99HnswVectorsFormat.META_CODEC_NAME,
Lucene99HnswVectorsFormat.VERSION_START,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
readFields(meta);
} catch (Throwable exception) {
priorE = exception;
} finally {
CodecUtil.checkFooter(meta, priorE);
}
this.version = versionMeta;
this.vectorIndex =
openDataInput(
state,
versionMeta,
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
state.context.withHints(
// Even though this input is referred to an `indexIn`, it doesn't qualify as
// FileTypeHint#INDEX since it's a large file
FileTypeHint.DATA,
FileDataHint.KNN_VECTORS,
DataAccessHint.RANDOM,
PreloadHint.INSTANCE));
this.vectorIndex = openDataInput(
state,
versionMeta,
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
state.context.withHints(
// Even though this input is referred to an `indexIn`, it doesn't qualify as
// FileTypeHint#INDEX since it's a large file
FileTypeHint.DATA,
FileDataHint.KNN_VECTORS,
DataAccessHint.RANDOM,
PreloadHint.INSTANCE));
} catch (Throwable t) {
IOUtils.closeWhileSuppressingExceptions(t, this);
throw t;
Expand Down Expand Up @@ -154,18 +152,16 @@ private static IndexInput openDataInput(
String codecName,
IOContext context)
throws IOException {
String fileName =
IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
IndexInput in = state.directory.openInput(fileName, context);
try {
int versionVectorData =
CodecUtil.checkIndexHeader(
in,
codecName,
Lucene99HnswVectorsFormat.VERSION_START,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
int versionVectorData = CodecUtil.checkIndexHeader(
in,
codecName,
Lucene99HnswVectorsFormat.VERSION_START,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
if (versionMeta != versionVectorData) {
throw new CorruptIndexException(
"Format versions mismatch: meta="
Expand Down Expand Up @@ -213,12 +209,11 @@ private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
// to avoid an undesirable dependency on the declaration and order of values
// in VectorSimilarityFunction. The list values and order must be identical
// to that of {@link o.a.l.c.l.Lucene94FieldInfosFormat#SIMILARITY_FUNCTIONS}.
public static final List<VectorSimilarityFunction> SIMILARITY_FUNCTIONS =
List.of(
VectorSimilarityFunction.EUCLIDEAN,
VectorSimilarityFunction.DOT_PRODUCT,
VectorSimilarityFunction.COSINE,
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
public static final List<VectorSimilarityFunction> SIMILARITY_FUNCTIONS = List.of(
VectorSimilarityFunction.EUCLIDEAN,
VectorSimilarityFunction.DOT_PRODUCT,
VectorSimilarityFunction.COSINE,
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);

public static VectorSimilarityFunction readSimilarityFunction(DataInput input)
throws IOException {
Expand Down Expand Up @@ -330,8 +325,7 @@ private void search(
return;
}
final RandomVectorScorer scorer = scorerSupplier.get();
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final KnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
HnswGraph graph = getGraph(fieldEntry);
boolean doHnsw = knnCollector.k() < scorer.maxOrd();
Expand All @@ -340,7 +334,8 @@ private void search(
// The approximate number of vectors that would be visited if we did not filter
int unfilteredVisit = HnswGraphSearcher.expectedVisitedNodes(knnCollector.k(), graph.size());
if (acceptDocs instanceof BitSet bitSet) {
// Use approximate cardinality as this is good enough, but ensure we don't exceed the graph
// Use approximate cardinality as this is good enough, but ensure we don't
// exceed the graph
// size as that is illogical
filteredDocCount = Math.min(bitSet.approximateCardinality(), graph.size());
if (unfilteredVisit >= filteredDocCount) {
Expand All @@ -351,7 +346,8 @@ private void search(
HnswGraphSearcher.search(
scorer, collector, getGraph(fieldEntry), acceptedOrds, filteredDocCount);
} else {
// if k is larger than the number of vectors we expect to visit in an HNSW search,
// if k is larger than the number of vectors we expect to visit in an HNSW
// search,
// we can just iterate over all vectors and collect them.
int[] ords = new int[EXHAUSTIVE_BULK_SCORE_ORDS];
float[] scores = new float[EXHAUSTIVE_BULK_SCORE_ORDS];
Expand Down Expand Up @@ -440,7 +436,8 @@ private record FieldEntry(
int dimension,
int size,
int[][] nodesByLevel,
// for each level the start offsets in vectorIndex file from where to read neighbours
// for each level the start offsets in vectorIndex file from where to read
// neighbours
DirectMonotonicReader.Meta offsetsMeta,
long offsetsOffset,
int offsetsBlockShift,
Expand Down Expand Up @@ -523,16 +520,13 @@ private final class OffHeapHnswGraph extends HnswGraph {
private final int[] currentNeighborsBuffer;

OffHeapHnswGraph(FieldEntry entry, IndexInput vectorIndex) throws IOException {
this.dataIn =
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
this.dataIn = vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
this.nodesByLevel = entry.nodesByLevel;
this.numLevels = entry.numLevels;
this.entryNode = numLevels > 1 ? nodesByLevel[numLevels - 1][0] : 0;
this.size = entry.size();
final RandomAccessInput addressesData =
vectorIndex.randomAccessSlice(entry.offsetsOffset, entry.offsetsLength);
this.graphLevelNodeOffsets =
DirectMonotonicReader.getInstance(entry.offsetsMeta, addressesData);
final RandomAccessInput addressesData = vectorIndex.randomAccessSlice(entry.offsetsOffset, entry.offsetsLength);
this.graphLevelNodeOffsets = DirectMonotonicReader.getInstance(entry.offsetsMeta, addressesData);
this.currentNeighborsBuffer = new int[entry.M * 2];
this.maxConn = entry.M;
graphLevelNodeIndexOffsets = new long[numLevels];
Expand All @@ -546,26 +540,28 @@ private final class OffHeapHnswGraph extends HnswGraph {

@Override
public void seek(int level, int targetOrd) throws IOException {
int targetIndex =
level == 0
? targetOrd
: Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
int targetIndex = level == 0
? targetOrd
: Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
assert targetIndex >= 0
: "seek level=" + level + " target=" + targetOrd + " not found: " + targetIndex;
// unsafe; no bounds checking
dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
arcCount = dataIn.readVInt();
assert arcCount <= currentNeighborsBuffer.length : "too many neighbors: " + arcCount;
if (arcCount > 0) {
int sum = 0;
if (version >= VERSION_GROUPVARINT) {
GroupVIntUtil.readGroupVInts(dataIn, currentNeighborsBuffer, arcCount);
for (int i = 1; i < arcCount; i++) {
currentNeighborsBuffer[i] = currentNeighborsBuffer[i - 1] + currentNeighborsBuffer[i];
// Faster prefix sum computation (see #14979)
for (int i = 0; i < arcCount; i++) {
sum += currentNeighborsBuffer[i];
currentNeighborsBuffer[i] = sum;
}
} else {
currentNeighborsBuffer[0] = dataIn.readVInt();
for (int i = 1; i < arcCount; i++) {
currentNeighborsBuffer[i] = currentNeighborsBuffer[i - 1] + dataIn.readVInt();
for (int i = 0; i < arcCount; i++) {
sum += dataIn.readVInt();
currentNeighborsBuffer[i] = sum;
}
}
}
Expand Down
Loading