Skip to content

Commit 3a49c70

Browse files
authored
Implement rangeQuery for VecSimTieredIndex - [MOD-5164] (#360)
* implemented `rangeQuery` for VecSimTieredIndex, ... including needed utility functions * renaming `merge_results.h` and moving `filter_results` to it * fix build * first test and some fixes * improved test and added a parallel test * fix a bug where we safely get (from `safeGetEntryPoint`) the old entry point but then we get the new max level when trying to search using the old one * fix tests * Update comments * review fixes * after rebase fixes * added a general comment on tiered index's guarantees
1 parent 931074b commit 3a49c70

File tree

12 files changed

+518
-60
lines changed

12 files changed

+518
-60
lines changed

src/VecSim/algorithms/hnsw/hnsw.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ class HNSWIndex : public VecSimIndexAbstract<DistType>,
231231
// (this option is used currently for tests).
232232
virtual inline bool safeCheckIfLabelExistsInIndex(labelType label,
233233
bool also_done_processing = false) const = 0;
234-
inline idType safeGetEntryPointCopy() const;
234+
inline auto safeGetEntryPointState() const;
235235
inline void lockIndexDataGuard() const;
236236
inline void unlockIndexDataGuard() const;
237237
inline void lockNodeLinks(idType node_id) const;
@@ -1901,22 +1901,22 @@ void HNSWIndex<DataType, DistType>::appendVector(const void *vector_data, const
19011901
}
19021902

19031903
template <typename DataType, typename DistType>
1904-
idType HNSWIndex<DataType, DistType>::safeGetEntryPointCopy() const {
1904+
auto HNSWIndex<DataType, DistType>::safeGetEntryPointState() const {
19051905
std::shared_lock<std::shared_mutex> lock(index_data_guard_);
1906-
return entrypoint_node_;
1906+
return std::make_pair(entrypoint_node_, max_level_);
19071907
}
19081908

19091909
template <typename DataType, typename DistType>
19101910
idType HNSWIndex<DataType, DistType>::searchBottomLayerEP(const void *query_data, void *timeoutCtx,
19111911
VecSimQueryResult_Code *rc) const {
19121912
*rc = VecSim_QueryResult_OK;
19131913

1914-
idType curr_element = safeGetEntryPointCopy();
1914+
auto [curr_element, max_level] = safeGetEntryPointState();
19151915
if (curr_element == INVALID_ID)
19161916
return curr_element; // index is empty.
19171917

19181918
DistType cur_dist = this->dist_func(query_data, getDataByInternalId(curr_element), this->dim);
1919-
for (size_t level = max_level_; level > 0 && curr_element != INVALID_ID; level--) {
1919+
for (size_t level = max_level; level > 0 && curr_element != INVALID_ID; level--) {
19201920
greedySearchLevel<true>(query_data, level, curr_element, cur_dist, timeoutCtx, rc);
19211921
}
19221922
return curr_element;
@@ -2127,7 +2127,10 @@ VecSimQueryResult_List HNSWIndex<DataType, DistType>::rangeQuery(const void *que
21272127
}
21282128

21292129
idType bottom_layer_ep = searchBottomLayerEP(query_data, timeoutCtx, &rl.code);
2130-
if (VecSim_OK != rl.code) {
2130+
// Although we checked that the index is not empty (cur_element_count == 0), it might be
2131+
// that another thread deleted all the elements or didn't finish inserting the first element
2132+
// yet. Anyway, we observed that the index is empty, so we return an empty result list.
2133+
if (VecSim_OK != rl.code || bottom_layer_ep == INVALID_ID) {
21312134
rl.results = array_new<VecSimQueryResult>(0);
21322135
return rl;
21332136
}

src/VecSim/algorithms/hnsw/hnsw_tiered.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,6 @@ class TieredHNSWIndex : public VecSimTieredIndex<DataType, DistType> {
179179
void increaseCapacity() override {}
180180

181181
// TODO: Implement the actual methods instead of these temporary ones.
182-
VecSimQueryResult_List rangeQuery(const void *queryBlob, double radius,
183-
VecSimQueryParams *queryParams) override {
184-
return this->backendIndex->rangeQuery(queryBlob, radius, queryParams);
185-
}
186182
VecSimIndexInfo info() const override { return this->backendIndex->info(); }
187183
VecSimInfoIterator *infoIterator() const override { return this->backendIndex->infoIterator(); }
188184
VecSimBatchIterator *newBatchIterator(const void *queryBlob,
@@ -766,8 +762,7 @@ double TieredHNSWIndex<DataType, DistType>::getDistanceFrom(labelType label,
766762

767763
// Try to get the distance from the Main index.
768764
this->mainIndexGuard.lock_shared();
769-
auto hnsw = getHNSWIndex();
770-
auto hnsw_dist = hnsw->safeGetDistanceFrom(label, blob);
765+
auto hnsw_dist = getHNSWIndex()->safeGetDistanceFrom(label, blob);
771766
this->mainIndexGuard.unlock_shared();
772767

773768
// Return the minimum distance that is not NaN.
@@ -856,7 +851,6 @@ TieredHNSWIndex<DataType, DistType>::TieredHNSW_BatchIterator::getNextResults(
856851
auto tail = this->flat_iterator->getNextResults(
857852
n_res - VecSimQueryResult_Len(this->flat_results), BY_SCORE_THEN_ID);
858853
concat_results(this->flat_results, tail);
859-
VecSimQueryResult_Free(tail);
860854

861855
if (!isMulti) {
862856
// On single-value indexes, duplicates will never appear in the hnsw results before

src/VecSim/algorithms/hnsw/hnsw_tiered_tests_friends.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_writeInPlaceMode_Test)
2626
INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_switchWriteModes_Test)
2727
INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_bufferLimit_Test)
2828
INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_bufferLimitAsync_Test)
29+
INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_RangeSearch_Test)
30+
INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_parallelRangeSearch_Test)
2931

3032
INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_insertJobAsync_Test)
3133
INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_insertJobAsyncMulti_Test)

src/VecSim/utils/merge_results.h renamed to src/VecSim/utils/query_result_utils.h

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ template <bool withSet>
3333
VecSimQueryResult *merge_results(VecSimQueryResult *&first, const VecSimQueryResult *first_end,
3434
VecSimQueryResult *&second, const VecSimQueryResult *second_end,
3535
size_t limit) {
36-
VecSimQueryResult *results = array_new<VecSimQueryResult>(limit);
36+
// Allocate the merged results array with the minimum size needed.
37+
// Min of the limit and the sum of the lengths of the two arrays.
38+
VecSimQueryResult *results = array_new<VecSimQueryResult>(
39+
std::min(limit, (size_t)(first_end - first) + (size_t)(second_end - second)));
3740
// Will hold the ids of the results we've already added to the merged results.
3841
// Will be used only if withSet is true.
3942
std::unordered_set<size_t> ids;
@@ -92,9 +95,54 @@ VecSimQueryResult_List merge_result_lists(VecSimQueryResult_List first,
9295
return mergedResults;
9396
}
9497

98+
// Concatenate the results of two queries into the results of the first query, consuming the second.
9599
static inline void concat_results(VecSimQueryResult_List &first, VecSimQueryResult_List &second) {
96100
auto &dst = first.results;
97101
auto &src = second.results;
98102

99103
dst = array_concat(dst, src);
104+
VecSimQueryResult_Free(second);
105+
}
106+
107+
// Sorts the results by id and removes duplicates.
108+
// Assumes that a result can appear at most twice in the results list.
109+
// @returns the number of unique results. This should be set to be the new length of the results
110+
template <bool IsMulti>
111+
void filter_results_by_id(VecSimQueryResult_List results) {
112+
if (VecSimQueryResult_Len(results) < 2) {
113+
return;
114+
}
115+
sort_results_by_id(results);
116+
117+
size_t i, cur_end;
118+
for (i = 0, cur_end = 0; i < VecSimQueryResult_Len(results) - 1; i++, cur_end++) {
119+
const VecSimQueryResult *cur_res = results.results + i;
120+
const VecSimQueryResult *next_res = cur_res + 1;
121+
if (VecSimQueryResult_GetId(cur_res) == VecSimQueryResult_GetId(next_res)) {
122+
if (IsMulti) {
123+
// On multi value index, scores might be different and we want to keep the lower
124+
// score.
125+
if (VecSimQueryResult_GetScore(cur_res) < VecSimQueryResult_GetScore(next_res)) {
126+
results.results[cur_end] = *cur_res;
127+
} else {
128+
results.results[cur_end] = *next_res;
129+
}
130+
} else {
131+
// On single value index, scores are the same so we can keep any of the results.
132+
results.results[cur_end] = *cur_res;
133+
}
134+
// Assuming every id can appear at most twice, we can skip the next comparison between
135+
// the current and the next result.
136+
i++;
137+
} else {
138+
results.results[cur_end] = *cur_res;
139+
}
140+
}
141+
// If the last result is unique, we need to add it to the results.
142+
if (i == VecSimQueryResult_Len(results) - 1) {
143+
results.results[cur_end] = results.results[i];
144+
// Logically, we should increment cur_end and i here, but we don't need to because it won't
145+
// affect the rest of the function.
146+
}
147+
array_pop_back_n(results.results, i - cur_end);
100148
}

src/VecSim/utils/vec_utils.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,17 @@ void sort_results_by_score_then_id(VecSimQueryResult_List rl) {
6565
(__compar_fn_t)cmpVecSimQueryResultByScoreThenId);
6666
}
6767

68+
void sort_results(VecSimQueryResult_List rl, VecSimQueryResult_Order order) {
69+
switch (order) {
70+
case BY_ID:
71+
return sort_results_by_id(rl);
72+
case BY_SCORE:
73+
return sort_results_by_score(rl);
74+
case BY_SCORE_THEN_ID:
75+
return sort_results_by_score_then_id(rl);
76+
}
77+
}
78+
6879
VecSimResolveCode validate_positive_integer_param(VecSimRawParam rawParam, long long *val) {
6980
char *ep; // For checking that strtoll used all rawParam.valLen chars.
7081
errno = 0;

src/VecSim/utils/vec_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ void sort_results_by_score(VecSimQueryResult_List results);
8383

8484
void sort_results_by_score_then_id(VecSimQueryResult_List results);
8585

86+
void sort_results(VecSimQueryResult_List results, VecSimQueryResult_Order order);
87+
8688
VecSimResolveCode validate_positive_integer_param(VecSimRawParam rawParam, long long *val);
8789

8890
VecSimResolveCode validate_positive_double_param(VecSimRawParam rawParam, double *val);

src/VecSim/vec_sim.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,14 +221,7 @@ extern "C" VecSimQueryResult_List VecSimIndex_RangeQuery(VecSimIndex *index, con
221221
if (radius < 0) {
222222
throw std::runtime_error("radius must be non-negative");
223223
}
224-
VecSimQueryResult_List results = index->rangeQueryWrapper(queryBlob, radius, queryParams);
225-
226-
if (order == BY_SCORE) {
227-
sort_results_by_score(results);
228-
} else {
229-
sort_results_by_id(results);
230-
}
231-
return results;
224+
return index->rangeQueryWrapper(queryBlob, radius, queryParams, order);
232225
}
233226

234227
extern "C" void VecSimIndex_Free(VecSimIndex *index) {

src/VecSim/vec_sim_index.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,16 @@ struct VecSimIndexAbstract : public VecSimIndexInterface {
8888
inline VecSimMetric getMetric() const { return metric; }
8989
inline size_t getDataSize() const { return data_size; }
9090

91+
virtual VecSimQueryResult_List rangeQuery(const void *queryBlob, double radius,
92+
VecSimQueryParams *queryParams) = 0;
93+
VecSimQueryResult_List rangeQuery(const void *queryBlob, double radius,
94+
VecSimQueryParams *queryParams,
95+
VecSimQueryResult_Order order) override {
96+
auto results = rangeQuery(queryBlob, radius, queryParams);
97+
sort_results(results, order);
98+
return results;
99+
}
100+
91101
void log(const char *fmt, ...) const {
92102
if (VecSimIndexInterface::logCallback) {
93103
// Format the message and call the callback
@@ -136,11 +146,12 @@ struct VecSimIndexAbstract : public VecSimIndexInterface {
136146
}
137147

138148
virtual VecSimQueryResult_List rangeQueryWrapper(const void *queryBlob, double radius,
139-
VecSimQueryParams *queryParams) override {
149+
VecSimQueryParams *queryParams,
150+
VecSimQueryResult_Order order) override {
140151
char processed_blob[this->data_size];
141152
const void *query_to_send = processBlob(queryBlob, processed_blob);
142153

143-
return this->rangeQuery(query_to_send, radius, queryParams);
154+
return this->rangeQuery(query_to_send, radius, queryParams, order);
144155
}
145156

146157
virtual VecSimBatchIterator *

src/VecSim/vec_sim_interface.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ struct VecSimIndexInterface : public VecsimBaseObject {
138138
* blob.
139139
*/
140140
virtual VecSimQueryResult_List rangeQueryWrapper(const void *queryBlob, double radius,
141-
VecSimQueryParams *queryParams) = 0;
141+
VecSimQueryParams *queryParams,
142+
VecSimQueryResult_Order order) = 0;
142143
/**
143144
* @brief Search for the vectors that are in a given range in the index with respect to a given
144145
* vector. The results can be ordered by their score or id.
@@ -153,7 +154,9 @@ struct VecSimIndexInterface : public VecsimBaseObject {
153154
* VecSimQueryResult_Iterator.
154155
*/
155156
virtual VecSimQueryResult_List rangeQuery(const void *queryBlob, double radius,
156-
VecSimQueryParams *queryParams) = 0;
157+
VecSimQueryParams *queryParams,
158+
VecSimQueryResult_Order order) = 0;
159+
157160
/**
158161
* @brief Return index information.
159162
*

src/VecSim/vec_sim_tiered_index.h

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "vec_sim_index.h"
44
#include "algorithms/brute_force/brute_force.h"
55
#include "VecSim/batch_iterator.h"
6-
#include "VecSim/utils/merge_results.h"
6+
#include "VecSim/utils/query_result_utils.h"
77

88
#include <shared_mutex>
99

@@ -20,6 +20,9 @@ struct AsyncJob : public VecsimBaseObject {
2020
: VecsimBaseObject(allocator), jobType(type), Execute(callback), index(index_ref) {}
2121
};
2222

23+
// All read operations (including KNN, range, batch iterators and get-distance-from) are guaranteed
24+
// to consider all vectors that were added to the index before the query was submitted. The results
25+
// may include vectors that were added after the query was submitted, with no guarantees.
2326
template <typename DataType, typename DistType>
2427
class VecSimTieredIndex : public VecSimIndexInterface {
2528
protected:
@@ -62,6 +65,9 @@ class VecSimTieredIndex : public VecSimIndexInterface {
6265

6366
VecSimQueryResult_List topKQuery(const void *queryBlob, size_t k,
6467
VecSimQueryParams *queryParams) override;
68+
VecSimQueryResult_List rangeQuery(const void *queryBlob, double radius,
69+
VecSimQueryParams *queryParams,
70+
VecSimQueryResult_Order order) override;
6571

6672
// Return the current state of the global write mode (async/in-place).
6773
static VecSimWriteMode getWriteMode() { return VecSimIndexInterface::asyncWriteMode; }
@@ -83,12 +89,13 @@ class VecSimTieredIndex : public VecSimIndexInterface {
8389
}
8490

8591
virtual VecSimQueryResult_List rangeQueryWrapper(const void *queryBlob, double radius,
86-
VecSimQueryParams *queryParams) override {
92+
VecSimQueryParams *queryParams,
93+
VecSimQueryResult_Order order) override {
8794
// Will be used only if a processing stage is needed
8895
char processed_blob[this->backendIndex->getDataSize()];
8996
const void *query_to_send = this->backendIndex->processBlob(queryBlob, processed_blob);
9097

91-
return this->rangeQuery(query_to_send, radius, queryParams);
98+
return this->rangeQuery(query_to_send, radius, queryParams, order);
9299
}
93100

94101
virtual VecSimBatchIterator *
@@ -151,3 +158,74 @@ VecSimTieredIndex<DataType, DistType>::topKQuery(const void *queryBlob, size_t k
151158
}
152159
}
153160
}
161+
162+
template <typename DataType, typename DistType>
163+
VecSimQueryResult_List
164+
VecSimTieredIndex<DataType, DistType>::rangeQuery(const void *queryBlob, double radius,
165+
VecSimQueryParams *queryParams,
166+
VecSimQueryResult_Order order) {
167+
this->flatIndexGuard.lock_shared();
168+
169+
// If the flat buffer is empty, we can simply query the main index.
170+
if (this->frontendIndex->indexSize() == 0) {
171+
// Release the flat lock and acquire the main lock.
172+
this->flatIndexGuard.unlock_shared();
173+
174+
// Simply query the main index and return the results while holding the lock.
175+
this->mainIndexGuard.lock_shared();
176+
auto res = this->backendIndex->rangeQuery(queryBlob, radius, queryParams);
177+
this->mainIndexGuard.unlock_shared();
178+
179+
// We could have passed the order to the main index, but we can sort them here after
180+
// unlocking it instead.
181+
sort_results(res, order);
182+
return res;
183+
} else {
184+
// No luck... first query the flat buffer and release the lock.
185+
auto flat_results = this->frontendIndex->rangeQuery(queryBlob, radius, queryParams);
186+
this->flatIndexGuard.unlock_shared();
187+
188+
// If the query failed (currently only on timeout), return the error code and the partial
189+
// results.
190+
if (flat_results.code != VecSim_QueryResult_OK) {
191+
return flat_results;
192+
}
193+
194+
// Lock the main index and query it.
195+
this->mainIndexGuard.lock_shared();
196+
auto main_results = this->backendIndex->rangeQuery(queryBlob, radius, queryParams);
197+
this->mainIndexGuard.unlock_shared();
198+
199+
// Merge the results and return, avoiding duplicates.
200+
// At this point, the return code of the FLAT index is OK, and the return code of the MAIN
201+
// index is either OK or TIMEOUT. Make sure to return the return code of the MAIN index.
202+
if (BY_SCORE == order) {
203+
sort_results_by_score_then_id(main_results);
204+
sort_results_by_score_then_id(flat_results);
205+
206+
// Keep the return code of the main index.
207+
auto code = main_results.code;
208+
209+
// Merge the sorted results with no limit (all the results are valid).
210+
VecSimQueryResult_List ret;
211+
if (this->backendIndex->isMultiValue()) {
212+
ret = merge_result_lists<true>(main_results, flat_results, -1);
213+
} else {
214+
ret = merge_result_lists<false>(main_results, flat_results, -1);
215+
}
216+
// Restore the return code and return.
217+
ret.code = code;
218+
return ret;
219+
220+
} else { // BY_ID
221+
// Notice that we don't modify the return code of the main index in any step.
222+
concat_results(main_results, flat_results);
223+
if (this->backendIndex->isMultiValue()) {
224+
filter_results_by_id<true>(main_results);
225+
} else {
226+
filter_results_by_id<false>(main_results);
227+
}
228+
return main_results;
229+
}
230+
}
231+
}

0 commit comments

Comments
 (0)