Skip to content
This repository was archived by the owner on Aug 16, 2023. It is now read-only.

Commit 20ea105

Browse files
committed
Fix bruteforce cosine
Signed-off-by: zh Wang <[email protected]>
1 parent 62c0a4b commit 20ea105

File tree

5 files changed

+247
-28
lines changed

5 files changed

+247
-28
lines changed

src/common/comp/brute_force.cc

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ expected<DataSetPtr>
3333
BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
3434
const BitsetView& bitset) {
3535
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
36-
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
37-
if (is_cosine) {
38-
Normalize(*base_dataset);
39-
}
4036

4137
auto xb = base_dataset->GetTensor();
4238
auto nb = base_dataset->GetRows();
@@ -54,6 +50,13 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
5450
auto labels = new int64_t[nq * topk];
5551
auto distances = new float[nq * topk];
5652

53+
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
54+
std::unique_ptr<float[]> norms = nullptr;
55+
if (is_cosine) {
56+
norms = std::make_unique<float[]>(nb);
57+
faiss::fvec_norms_L2(norms.get(), (const float*)xb, dim, nb);
58+
}
59+
5760
auto pool = ThreadPool::GetGlobalThreadPool();
5861
std::vector<folly::Future<Status>> futs;
5962
futs.reserve(nq);
@@ -71,13 +74,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
7174
}
7275
case faiss::METRIC_INNER_PRODUCT: {
7376
auto cur_query = (float*)xq + dim * index;
74-
if (is_cosine) {
75-
NormalizeVec(cur_query, dim);
76-
}
7777
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
7878
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
7979
break;
8080
}
81+
case faiss::METRIC_COSINE: {
82+
auto cur_query = (float*)xq + dim * index;
83+
NormalizeVec(cur_query, dim);
84+
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
85+
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, norms.get(), bitset);
86+
break;
87+
}
8188
case faiss::METRIC_Jaccard: {
8289
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
8390
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
@@ -123,10 +130,6 @@ Status
123130
BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis,
124131
const Json& config, const BitsetView& bitset) {
125132
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
126-
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
127-
if (is_cosine) {
128-
Normalize(*base_dataset);
129-
}
130133

131134
auto xb = base_dataset->GetTensor();
132135
auto nb = base_dataset->GetRows();
@@ -150,6 +153,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
150153

151154
auto faiss_metric_type = metric_type.value();
152155

156+
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
157+
std::unique_ptr<float[]> norms = nullptr;
158+
if (is_cosine) {
159+
norms = std::make_unique<float[]>(nb);
160+
faiss::fvec_norms_L2(norms.get(), (const float*)xb, dim, nb);
161+
}
162+
153163
auto pool = ThreadPool::GetGlobalThreadPool();
154164
std::vector<folly::Future<Status>> futs;
155165
futs.reserve(nq);
@@ -167,13 +177,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
167177
}
168178
case faiss::METRIC_INNER_PRODUCT: {
169179
auto cur_query = (float*)xq + dim * index;
170-
if (is_cosine) {
171-
NormalizeVec(cur_query, dim);
172-
}
173180
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
174181
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
175182
break;
176183
}
184+
case faiss::METRIC_COSINE: {
185+
auto cur_query = (float*)xq + dim * index;
186+
NormalizeVec(cur_query, dim);
187+
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
188+
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, norms.get(), bitset);
189+
break;
190+
}
177191
case faiss::METRIC_Jaccard: {
178192
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
179193
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
@@ -221,11 +235,6 @@ expected<DataSetPtr>
221235
BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
222236
const BitsetView& bitset) {
223237
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
224-
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
225-
if (is_cosine) {
226-
Normalize(*base_dataset);
227-
}
228-
229238
auto xb = base_dataset->GetTensor();
230239
auto nb = base_dataset->GetRows();
231240
auto dim = base_dataset->GetDim();
@@ -241,6 +250,12 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
241250
float range_filter = cfg.range_filter.value();
242251

243252
ASSIGN_OR_RETURN(faiss::MetricType, faiss_metric_type, Str2FaissMetricType(cfg.metric_type.value()));
253+
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
254+
std::unique_ptr<float[]> norms = nullptr;
255+
if (is_cosine) {
256+
norms = std::make_unique<float[]>(nb);
257+
faiss::fvec_norms_L2(norms.get(), (const float*)xb, dim, nb);
258+
}
244259
auto pool = ThreadPool::GetGlobalThreadPool();
245260

246261
std::vector<std::vector<int64_t>> result_id_array(nq);
@@ -262,12 +277,16 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
262277
case faiss::METRIC_INNER_PRODUCT: {
263278
is_ip = true;
264279
auto cur_query = (float*)xq + dim * index;
265-
if (is_cosine) {
266-
NormalizeVec(cur_query, dim);
267-
}
268280
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
269281
break;
270282
}
283+
case faiss::METRIC_COSINE: {
284+
auto cur_query = (float*)xq + dim * index;
285+
NormalizeVec(cur_query, dim);
286+
faiss::range_search_cosine(cur_query, (const float*)xb, dim, 1, nb, radius, &res, norms.get(),
287+
bitset);
288+
break;
289+
}
271290
case faiss::METRIC_Jaccard: {
272291
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
273292
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(

src/common/metric.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Str2FaissMetricType(std::string metric) {
2727
static const std::unordered_map<std::string, faiss::MetricType> metric_map = {
2828
{metric::L2, faiss::MetricType::METRIC_L2},
2929
{metric::IP, faiss::MetricType::METRIC_INNER_PRODUCT},
30-
{metric::COSINE, faiss::MetricType::METRIC_INNER_PRODUCT},
30+
{metric::COSINE, faiss::MetricType::METRIC_COSINE},
3131
{metric::HAMMING, faiss::MetricType::METRIC_Hamming},
3232
{metric::JACCARD, faiss::MetricType::METRIC_Jaccard},
3333
{metric::SUBSTRUCTURE, faiss::MetricType::METRIC_Substructure},

thirdparty/faiss/faiss/MetricType.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ enum MetricType {
3030
METRIC_Substructure, ///< Tversky case alpha = 0, beta = 1
3131
METRIC_Superstructure, ///< Tversky case alpha = 1, beta = 0
3232

33+
METRIC_COSINE,
34+
3335
/// some additional metrics defined in scipy.spatial.distance
3436
METRIC_Canberra = 20,
3537
METRIC_BrayCurtis,

0 commit comments

Comments
 (0)