Skip to content

Commit 282ea6a

Browse files
committed
added knn log ctx
1 parent 75281d5 commit 282ea6a

File tree

8 files changed

+168
-35
lines changed

8 files changed

+168
-35
lines changed

src/VecSim/vec_sim_index.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,14 @@ struct VecSimIndexAbstract : public VecSimIndexInterface {
9898
delete[] buf;
9999
}
100100
}
101+
102+
#ifdef BUILD_TESTS
103+
// Set new log context to be sent to the log callback.
104+
// Returns the previous logctx.
105+
inline void *setLogCtx(void *new_logCtx) {
106+
void *prev_logCtx = this->logCallbackCtx;
107+
this->logCallbackCtx = new_logCtx;
108+
return prev_logCtx;
109+
}
110+
#endif
101111
};

src/VecSim/vec_sim_interface.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,11 @@ void Vecsim_Log(void *ctx, const char *message) { std::cout << message << std::e
1414
timeoutCallbackFunction VecSimIndexInterface::timeoutCallback = [](void *ctx) { return 0; };
1515
logCallbackFunction VecSimIndexInterface::logCallback = Vecsim_Log;
1616
VecSimWriteMode VecSimIndexInterface::asyncWriteMode = VecSim_WriteAsync;
17+
18+
#ifdef BUILD_TESTS
19+
// Log sink that discards every message; both arguments are ignored.
static inline void Vecsim_Log_DO_NOTHING(void * /*ctx*/, const char * /*message*/) {}
20+
21+
void VecSimIndexInterface::resetLogCallbackFunction() {
22+
VecSimIndexInterface::logCallback = Vecsim_Log_DO_NOTHING;
23+
}
24+
#endif

src/VecSim/vec_sim_interface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,10 @@ struct VecSimIndexInterface : public VecsimBaseObject {
186186
VecSimIndexInterface::logCallback = callback;
187187
}
188188

189+
#ifdef BUILD_TESTS
190+
static void resetLogCallbackFunction();
191+
#endif
192+
189193
/**
190194
* @brief Allow 3rd party to set the write mode for tiered index - async insert/delete using
191195
* background jobs, or insert/delete inplace.

src/VecSim/vec_sim_tiered_index.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,20 @@ class VecSimTieredIndex : public VecSimIndexInterface {
6262
VecSimQueryParams *queryParams) override;
6363

6464
static VecSimWriteMode getWriteMode() { return asyncWriteMode; }
65+
66+
#ifdef BUILD_TESTS
67+
// Test-only accessor: exposes the frontend index held by this tiered index.
VecSimIndexAbstract<DistType> *getFlatbufferIndex() {
    return this->frontendIndex;
}
68+
#endif
6569
};
6670

6771
template <typename DataType, typename DistType>
6872
VecSimQueryResult_List
6973
VecSimTieredIndex<DataType, DistType>::topKQuery(const void *queryBlob, size_t k,
7074
VecSimQueryParams *queryParams) {
7175
this->flatIndexGuard.lock_shared();
72-
76+
#ifdef BUILD_TESTS
77+
this->getFlatbufferIndex()->log("");
78+
#endif
7379
// If the flat buffer is empty, we can simply query the main index.
7480
if (this->frontendIndex->indexSize() == 0) {
7581
// Release the flat lock and acquire the main lock.

src/python_bindings/bindings.cpp

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,19 @@ class PyHNSWLibIndex : public PyVecSimIndex {
381381
}
382382
};
383383

384+
template <typename DistType>
385+
struct KNNLogCtx {
386+
VecSimIndexAbstract<DistType> *flat_index;
387+
size_t curr_flat_size;
388+
KNNLogCtx() : flat_index(nullptr), curr_flat_size(0) {}
389+
};
390+
384391
class PyTIEREDIndex : public PyVecSimIndex {
392+
private:
393+
// Returns the flat buffer of the wrapped tiered index.
// Assumes this->index holds a VecSimTieredIndex<float, float>; nothing here checks it.
VecSimIndexAbstract<float> *getFlatBuffer() {
    auto *tiered = reinterpret_cast<VecSimTieredIndex<float, float> *>(this->index.get());
    return tiered->getFlatbufferIndex();
}
385397

386398
protected:
387399
JobQueue jobQueue; // External queue that holds the jobs.
@@ -395,6 +407,8 @@ class PyTIEREDIndex : public PyVecSimIndex {
395407
bool run_thread;
396408
std::bitset<MAX_POOL_SIZE> executions_status;
397409

410+
KNNLogCtx<float> knnLogCtx;
411+
398412
TieredIndexParams TieredIndexParams_Init() {
399413
TieredIndexParams ret = {
400414
.jobQueue = &this->jobQueue,
@@ -409,14 +423,16 @@ class PyTIEREDIndex : public PyVecSimIndex {
409423
}
410424

411425
public:
412-
explicit PyTIEREDIndex(size_t BufferLimit = 20000000)
413-
: submitCb(submit_callback), memoryCtx(0), UpdateMemCb(update_mem_callback), flatBufferLimit(BufferLimit),
414-
run_thread(true) {
426+
explicit PyTIEREDIndex(size_t BufferLimit = 1000)
427+
: submitCb(submit_callback), memoryCtx(0), UpdateMemCb(update_mem_callback),
428+
flatBufferLimit(BufferLimit), run_thread(true) {
415429

416430
for (size_t i = 0; i < THREAD_POOL_SIZE; i++) {
417431
ThreadParams params(run_thread, executions_status, i, jobQueue);
418432
thread_pool.emplace_back(thread_main_loop, params);
419433
}
434+
435+
ResetLogCB();
420436
}
421437

422438
virtual ~PyTIEREDIndex() = 0;
@@ -438,13 +454,34 @@ class PyTIEREDIndex : public PyVecSimIndex {
438454
}
439455
}
440456

457+
static void log_flat_buffer_size(void *ctx, const char *msg) {
458+
auto *knnLogCtx = reinterpret_cast<KNNLogCtx<float> *>(ctx);
459+
knnLogCtx->curr_flat_size = knnLogCtx->flat_index->indexLabelCount();
460+
}
461+
void SetKNNLogCtx() {
462+
knnLogCtx.flat_index = getFlatBuffer();
463+
knnLogCtx.curr_flat_size = 0;
464+
knnLogCtx.flat_index->setLogCtx(&knnLogCtx);
465+
this->index->setLogCallbackFunction(log_flat_buffer_size);
466+
}
467+
size_t getFlatIndexSize(const char *mode = "None") {
468+
if (!strcmp(mode, "insert_and_knn")) {
469+
return knnLogCtx.curr_flat_size;
470+
}
441471

472+
return getFlatBuffer()->indexLabelCount();
473+
}
474+
475+
void ResetLogCB() { this->index->resetLogCallbackFunction(); }
442476
// Number of worker threads in the pool (THREAD_POOL_SIZE).
static size_t GetThreadsNum() {
    return THREAD_POOL_SIZE;
}
443477

444-
size_t getBufferLimit() {return flatBufferLimit; }
478+
// Flat buffer limit this index was constructed with.
size_t getBufferLimit() {
    return flatBufferLimit;
}
445479
};
446480

447-
PyTIEREDIndex::~PyTIEREDIndex() { thread_pool_terminate(jobQueue, run_thread); }
481+
// Teardown order: thread_pool_terminate is called on the job queue and run flag
// first, and only then is the log callback switched back to the no-op sink.
PyTIEREDIndex::~PyTIEREDIndex() {
    thread_pool_terminate(this->jobQueue, this->run_thread);
    this->ResetLogCB();
}
448485
class PyTIERED_HNSWIndex : public PyTIEREDIndex {
449486
public:
450487
explicit PyTIERED_HNSWIndex(const HNSWParams &hnsw_params,
@@ -578,8 +615,11 @@ PYBIND11_MODULE(VecSim, m) {
578615

579616
py::class_<PyTIEREDIndex, PyVecSimIndex>(m, "TIEREDIndex")
580617
.def("wait_for_index", &PyTIERED_HNSWIndex::WaitForIndex, py::arg("waiting_duration") = 10)
618+
.def("get_curr_bf_size", &PyTIERED_HNSWIndex::getFlatIndexSize, py::arg("mode") = "None")
581619
.def("get_buffer_limit", &PyTIERED_HNSWIndex::getBufferLimit)
582-
.def_static("get_threads_num", &PyTIEREDIndex::GetThreadsNum);
620+
.def_static("get_threads_num", &PyTIEREDIndex::GetThreadsNum)
621+
.def("reset_log", &PyTIERED_HNSWIndex::ResetLogCB)
622+
.def("start_knn_log", &PyTIERED_HNSWIndex::SetKNNLogCtx);
583623

584624
py::class_<PyTIERED_HNSWIndex, PyTIEREDIndex>(m, "TIERED_HNSWIndex")
585625
.def(

src/python_bindings/tiered_index_mock.h

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*
1+
/*
22
*Copyright Redis Ltd. 2021 - present
33
*Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
44
*the Server Side Public License v1 (SSPLv1).
@@ -21,21 +21,6 @@ typedef struct RefManagedJob {
2121
std::weak_ptr<VecSimIndex> index_weak_ref;
2222
} RefManagedJob;
2323

24-
struct SearchJobMock : public AsyncJob {
25-
void *query; // The query vector. ownership is passed to the job in the constructor.
26-
size_t k; // The number of results to return.
27-
size_t n; // The number of vectors in the index (might be useful for the mock)
28-
size_t dim; // The dimension of the vectors in the index (might be useful for the mock)
29-
std::atomic_int &successful_searches; // A reference to a shared counter that counts the number
30-
// of successful searches.
31-
SearchJobMock(std::shared_ptr<VecSimAllocator> allocator, JobCallback searchCB,
32-
VecSimIndex *index_, void *query_, size_t k_, size_t n_, size_t dim_,
33-
std::atomic_int &successful_searches_)
34-
: AsyncJob(allocator, HNSW_SEARCH_JOB, searchCB, index_), query(query_), k(k_), n(n_),
35-
dim(dim_), successful_searches(successful_searches_) {}
36-
~SearchJobMock() { this->allocator->free_allocation(query); }
37-
};
38-
3924
using JobQueue = std::queue<RefManagedJob>;
4025
int submit_callback(void *job_queue, AsyncJob **jobs, size_t len, void *index_ctx);
4126
int update_mem_callback(void *mem_ctx, size_t mem);

tests/flow/common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,13 @@ def create_hnsw_index(dim, num_elements, metric, data_type, ef_construction=200,
4141

4242
return HNSWIndex(hnsw_params)
4343

44+
def bytes_to_mega(bytes, ndigits = 3):
    """Convert a byte count to megabytes (10**6 bytes), rounded to `ndigits` decimals."""
    megabytes = bytes / 10**6
    return round(megabytes, ndigits)
46+
47+
def round_(f_value, ndigits = 2):
    """Round `f_value` to `ndigits` decimals (two by default)."""
    rounded = round(f_value, ndigits)
    return rounded
49+
50+
51+
def round_ms(f_value, ndigits = 2):
    """Convert a duration in seconds to milliseconds, rounded to `ndigits` decimals."""
    milliseconds = f_value * 1000
    return round(milliseconds, ndigits)
4453

tests/flow/test_bm_hnsw_tiered_dataset.py

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,19 @@ def __init__(self, data_size = 0, initialCap = 0, M = 32, ef_c = 512, ef_r = 10,
7575

7676
data = load_data("dbpedia-768")
7777
self.num_elements = data_size if data_size != 0 else data.shape[0]
78-
self.initialCap = initialCap if initialCap != 0 else 2 * self.num_elements
78+
#self.initialCap = initialCap if initialCap != 0 else 2 * self.num_elements
79+
self.initialCap = initialCap if initialCap != 0 else self.num_elements
7980

8081
self.data = data[:self.num_elements]
8182
self.dim = len(self.data[0])
8283
self.metric = metric
83-
self.type = data_type
84+
self.data_type = data_type
8485
self.is_multi = is_multi
8586

8687
self.hnsw_params = create_hnsw_params(dim=self.dim,
8788
num_elements=self.initialCap,
8889
metric=self.metric,
89-
data_type=self.type,
90+
data_type=self.data_type,
9091
ef_construction=ef_c,
9192
m=M,
9293
ef_runtime=ef_r,
@@ -102,22 +103,18 @@ def create_tiered(self):
102103

103104
def create_hnsw(self):
    """Build a new HNSWIndex from the parameters stored in self.hnsw_params."""
    params = self.hnsw_params
    return HNSWIndex(params)
105-
106-
def set_num_vectors_per_label(self, num_per_label = 1):
107-
self.num_per_label = num_per_label
108106

109107
def init_and_populate_flat_index(self):
110108
bfparams = BFParams()
111109
bfparams.initialCapacity = self.num_elements
112110
bfparams.dim =self.dim
113-
bfparams.type =self.type
111+
bfparams.type =self.data_type
114112
bfparams.metric =self.metric
115113
bfparams.multi = self.is_multi
116114
self.flat_index = BFIndex(bfparams)
117115

118116
for i, vector in enumerate(self.data):
119-
for _ in range(self.num_per_label):
120-
self.flat_index.add_vector(vector, i)
117+
self.flat_index.add_vector(vector, i)
121118

122119
return self.flat_index
123120

@@ -129,6 +126,16 @@ def init_and_populate_hnsw_index(self):
129126
self.hnsw_index = hnsw_index
130127
return hnsw_index
131128

129+
def populate_index(self, index):
    """Add every vector in self.data to `index`, using its position as the label.

    Returns a (start, duration, end) tuple: the wall-clock time before the first
    insertion, the total seconds spent inside the add_vector calls, and the
    wall-clock time after the last insertion.
    """
    began_at = time.time()
    time_in_add = 0
    label = 0
    for vector in self.data:
        call_began = time.time()
        index.add_vector(vector, label)
        time_in_add += time.time() - call_began
        label += 1
    finished_at = time.time()
    return (began_at, time_in_add, finished_at)
138+
132139
def generate_random_vectors(self, num_vectors):
133140
vectors = 0
134141
np_file_path = os.path.join(f'np_{num_vectors}vec_dim{self.dim}.npy')
@@ -154,7 +161,12 @@ def insert_in_batch(self, index, data, data_first_idx, batch_size, first_label):
154161
duration += time.time() - start_add
155162
end = time.time()
156163
return (duration, end)
164+
165+
def generate_queries(self, num_queries):
    """Draw `num_queries` random query vectors of dimension self.dim.

    A generator seeded with 47 is created on every call and stored on self.rng.
    The result is converted with np.float32 when self.data_type equals
    VecSimType_FLOAT32 and returned as generated otherwise.
    """
    self.rng = np.random.default_rng(seed=47)

    shape = (num_queries, self.dim)
    queries = self.rng.random(shape)
    if self.data_type == VecSimType_FLOAT32:
        return np.float32(queries)
    return queries
158170

159171
def create_dbpedia():
160172
indices_ctx = DBPediaIndexCtx()
@@ -192,7 +204,7 @@ def create_tiered():
192204
create_tiered()
193205

194206
def create_dbpedia_graph():
195-
indices_ctx = DBPediaIndexCtx(data_size = 100000)
207+
indices_ctx = DBPediaIndexCtx()
196208

197209
threads_num = TIEREDIndex.get_threads_num()
198210
print(f"thread num = {threads_num}")
@@ -283,9 +295,68 @@ def create_hnsw():
283295
print(f"Start hnsw creation")
284296

285297
create_hnsw()
298+
299+
def search_insert(is_multi: bool, num_per_label = 1):
    """Query a tiered index repeatedly while its vectors are still being indexed.

    Builds a 1000-vector tiered index plus a flat index holding the same data, then
    runs one KNN query per second until the HNSW label count reaches the number of
    elements. For each query it prints the latency and the flat-buffer size sampled
    during that query, checks that the flat buffer never grows, and accumulates the
    overlap with the flat index's results. Finally prints the total indexing time,
    the average recall and the queries-per-second rate.

    is_multi is forwarded to DBPediaIndexCtx. num_per_label is accepted for
    interface compatibility but is not used by this function.
    """
    indices_ctx = DBPediaIndexCtx(data_size=1000, mode=CreationMode.CREATE_TIERED_INDEX, is_multi=is_multi)
    index = indices_ctx.tiered_index

    num_elements = indices_ctx.num_elements

    query_data = indices_ctx.generate_queries(num_queries=1)

    # Add vectors to the flat index.
    bf_index = indices_ctx.init_and_populate_flat_index()

    # Start background insertion to the tiered index.
    index_start, _, _ = indices_ctx.populate_index(index)

    correct = 0
    k = 10
    searches_number = 0

    # config knn log
    index.start_knn_log()

    # run knn query every 1 s.
    total_tiered_search_time = 0
    prev_bf_size = num_elements
    while index.hnsw_label_count() < num_elements:
        # For each run get the query time and the flat buffer size seen by the query.
        query_start = time.time()
        tiered_labels, _ = index.knn_query(query_data, k)
        query_dur = time.time() - query_start
        total_tiered_search_time += query_dur

        # The logged size is refreshed by the query itself, so it has to be read
        # AFTER knn_query; reading it before returns the previous query's value
        # (or the initial 0 on the first iteration).
        bf_curr_size = index.get_curr_bf_size(mode = 'insert_and_knn')

        print(f"query time = {round_ms(query_dur)} ms")

        # BF size should never grow. It may stay equal between two samples if no
        # job finished in between, so the comparison is not strict.
        print(f"bf size = {bf_curr_size}")
        assert bf_curr_size <= prev_bf_size

        # Run the query also in the bf index to get the ground truth results.
        bf_labels, _ = bf_index.knn_query(query_data, k)
        correct += len(np.intersect1d(tiered_labels[0], bf_labels[0]))
        time.sleep(1)
        searches_number += 1
        prev_bf_size = bf_curr_size

    index.reset_log()

    # HNSW labels count updates before the job is done, so we need to wait for the queue to be empty.
    index.wait_for_index(1)
    index_dur = time.time() - index_start
    print(f"indexing during search in tiered took {round_(index_dur)} s")

    # Indexing can finish before the first query runs; there is nothing to average then.
    if searches_number == 0 or total_tiered_search_time == 0:
        print("No query ran while indexing was in progress; skipping recall and QPS report")
        return

    # Measure recall.
    recall = float(correct)/(k*searches_number)
    print("Average recall is:", round_(recall, 3))
    print("tiered query per seconds: ", round_(searches_number/total_tiered_search_time))
286355

287356
def test_main():
288357
print("Test creation")
289-
# create_dbpedia()
290-
create_dbpedia_graph()
358+
create_dbpedia()
359+
# create_dbpedia_graph()
360+
print(f"\nStart insert & search test")
361+
# search_insert(is_multi=False)
291362

0 commit comments

Comments
 (0)