Skip to content

Commit 0fdca34

Browse files
authored
Normalize query vector in batch search cosine (#144)
* Allocate and copy query blob in batch iterator creation + normalize it in case of cosine
1 parent b341bb7 commit 0fdca34

File tree

12 files changed

+167
-15
lines changed

12 files changed

+167
-15
lines changed

src/VecSim/algorithms/brute_force/bf_batch_iterator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ VecSimQueryResult *BF_BatchIterator::selectBasedSearch(size_t n_res) {
116116
return results;
117117
}
118118

119-
BF_BatchIterator::BF_BatchIterator(const void *query_vector, const BruteForceIndex *bf_index,
119+
BF_BatchIterator::BF_BatchIterator(void *query_vector, const BruteForceIndex *bf_index,
120120
std::shared_ptr<VecSimAllocator> allocator)
121121
: VecSimBatchIterator(query_vector, allocator), index(bf_index), scores_valid_start_pos(0) {
122122
BF_BatchIterator::next_id++;

src/VecSim/algorithms/brute_force/bf_batch_iterator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class BF_BatchIterator : public VecSimBatchIterator {
2020
void swapScores(const unordered_map<size_t, size_t> &TopCandidatesIndices, size_t res_num);
2121

2222
public:
23-
BF_BatchIterator(const void *query_vector, const BruteForceIndex *index,
23+
BF_BatchIterator(void *query_vector, const BruteForceIndex *index,
2424
std::shared_ptr<VecSimAllocator> allocator);
2525

2626
inline const BruteForceIndex *getIndex() const { return index; };

src/VecSim/algorithms/brute_force/brute_force.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <memory>
1010
#include <cstring>
1111
#include <queue>
12+
#include <cassert>
1213

1314
using namespace std;
1415

@@ -229,7 +230,7 @@ VecSimQueryResult_List BruteForceIndex::topKQuery(const void *queryBlob, size_t
229230
return results;
230231
}
231232

232-
VecSimIndexInfo BruteForceIndex::info() {
233+
VecSimIndexInfo BruteForceIndex::info() const {
233234

234235
VecSimIndexInfo info;
235236
info.algo = VecSimAlgo_BF;
@@ -281,7 +282,16 @@ VecSimInfoIterator *BruteForceIndex::infoIterator() {
281282
}
282283

283284
VecSimBatchIterator *BruteForceIndex::newBatchIterator(const void *queryBlob) {
284-
return new (this->allocator) BF_BatchIterator(queryBlob, this, this->allocator);
285+
// As this is the only supported type, we always allocate 4 bytes for every element in the
286+
// vector.
287+
assert(this->vecType == VecSimType_FLOAT32);
288+
auto *queryBlobCopy = this->allocator->allocate(sizeof(float) * this->dim);
289+
memcpy(queryBlobCopy, queryBlob, dim * sizeof(float));
290+
if (metric == VecSimMetric_Cosine) {
291+
float_vector_normalize((float *)queryBlobCopy, dim);
292+
}
293+
// Ownership of queryBlobCopy moves to BF_BatchIterator that will free it at the end.
294+
return new (this->allocator) BF_BatchIterator(queryBlobCopy, this, this->allocator);
285295
}
286296

287297
bool BruteForceIndex::preferAdHocSearch(size_t subsetSize, size_t k) {

src/VecSim/algorithms/brute_force/brute_force.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class BruteForceIndex : public VecSimIndex {
2323
VecSimQueryParams *qparams) override;
2424
virtual VecSimQueryResult_List topKQuery(const void *queryBlob, size_t k,
2525
VecSimQueryParams *queryParams) override;
26-
virtual VecSimIndexInfo info() override;
26+
virtual VecSimIndexInfo info() const override;
2727
virtual VecSimInfoIterator *infoIterator() override;
2828
virtual VecSimBatchIterator *newBatchIterator(const void *queryBlob) override;
2929
bool preferAdHocSearch(size_t subsetSize, size_t k) override;

src/VecSim/algorithms/hnsw/hnsw_batch_iterator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ candidatesMaxHeap HNSW_BatchIterator::scanGraph(candidatesMinHeap &candidates,
127127
return top_candidates;
128128
}
129129

130-
HNSW_BatchIterator::HNSW_BatchIterator(const void *query_vector, HNSWIndex *index_wrapper,
130+
HNSW_BatchIterator::HNSW_BatchIterator(void *query_vector, HNSWIndex *index_wrapper,
131131
std::shared_ptr<VecSimAllocator> allocator)
132132
: VecSimBatchIterator(query_vector, std::move(allocator)), index_wrapper(index_wrapper),
133133
depleted(false), top_candidates_extras(this->allocator), candidates(this->allocator) {

src/VecSim/algorithms/hnsw/hnsw_batch_iterator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class HNSW_BatchIterator : public VecSimBatchIterator {
3636
inline bool hasVisitedNode(idType node_id) const;
3737

3838
public:
39-
HNSW_BatchIterator(const void *query_vector, HNSWIndex *index,
39+
HNSW_BatchIterator(void *query_vector, HNSWIndex *index,
4040
std::shared_ptr<VecSimAllocator> allocator);
4141

4242
VecSimQueryResult_List getNextResults(size_t n_res, VecSimQueryResult_Order order) override;

src/VecSim/algorithms/hnsw/hnsw_wrapper.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ VecSimQueryResult_List HNSWIndex::topKQuery(const void *query_data, size_t k,
133133
}
134134
}
135135

136-
VecSimIndexInfo HNSWIndex::info() {
136+
VecSimIndexInfo HNSWIndex::info() const {
137137

138138
VecSimIndexInfo info;
139139
info.algo = VecSimAlgo_HNSWLIB;
@@ -152,7 +152,16 @@ VecSimIndexInfo HNSWIndex::info() {
152152
}
153153

154154
VecSimBatchIterator *HNSWIndex::newBatchIterator(const void *queryBlob) {
155-
return new (this->allocator) HNSW_BatchIterator(queryBlob, this, this->allocator);
155+
// As this is the only supported type, we always allocate 4 bytes for every element in the
156+
// vector.
157+
assert(this->vecType == VecSimType_FLOAT32);
158+
auto *queryBlobCopy = this->allocator->allocate(sizeof(float) * this->dim);
159+
memcpy(queryBlobCopy, queryBlob, dim * sizeof(float));
160+
if (metric == VecSimMetric_Cosine) {
161+
float_vector_normalize((float *)queryBlobCopy, dim);
162+
}
163+
// Ownership of queryBlobCopy moves to HNSW_BatchIterator that will free it at the end.
164+
return new (this->allocator) HNSW_BatchIterator(queryBlobCopy, this, this->allocator);
156165
}
157166

158167
VecSimInfoIterator *HNSWIndex::infoIterator() {

src/VecSim/algorithms/hnsw/hnsw_wrapper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class HNSWIndex : public VecSimIndex {
2020
VecSimQueryParams *qparams) override;
2121
virtual VecSimQueryResult_List topKQuery(const void *queryBlob, size_t k,
2222
VecSimQueryParams *queryParams) override;
23-
virtual VecSimIndexInfo info() override;
23+
virtual VecSimIndexInfo info() const override;
2424
virtual VecSimInfoIterator *infoIterator() override;
2525
virtual VecSimBatchIterator *newBatchIterator(const void *queryBlob) override;
2626
bool preferAdHocSearch(size_t subsetSize, size_t k) override;

src/VecSim/batch_iterator.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
*/
1010
struct VecSimBatchIterator : public VecsimBaseObject {
1111
private:
12-
const void *query_vector;
12+
void *query_vector;
1313
size_t returned_results_count;
1414

1515
public:
16-
explicit VecSimBatchIterator(const void *query_vector,
17-
std::shared_ptr<VecSimAllocator> allocator)
16+
explicit VecSimBatchIterator(void *query_vector, std::shared_ptr<VecSimAllocator> allocator)
1817
: VecsimBaseObject(allocator), query_vector(query_vector), returned_results_count(0){};
1918

2019
inline const void *getQueryBlob() const { return query_vector; }
@@ -35,5 +34,5 @@ struct VecSimBatchIterator : public VecsimBaseObject {
3534
// Reset the iterator to the initial state, before any results has been returned.
3635
virtual void reset() = 0;
3736

38-
virtual ~VecSimBatchIterator() = default;
37+
virtual ~VecSimBatchIterator() { allocator->free_allocation(this->query_vector); };
3938
};

src/VecSim/vec_sim_index.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct VecSimIndex : public VecsimBaseObject {
9090
*
9191
* @return Index general and specific meta-data.
9292
*/
93-
virtual VecSimIndexInfo info() = 0;
93+
virtual VecSimIndexInfo info() const = 0;
9494

9595
/**
9696
* @brief Returns an index information in an iterable structure.

0 commit comments

Comments
 (0)