Skip to content
This repository was archived by the owner on Aug 16, 2023. It is now read-only.

Commit 8629c01

Browse files
committed
Fix cosine bruteforce
Signed-off-by: zh Wang <[email protected]>
1 parent 62c0a4b commit 8629c01

File tree

5 files changed

+196
-21
lines changed

5 files changed

+196
-21
lines changed

src/common/comp/brute_force.cc

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ expected<DataSetPtr>
3333
BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
3434
const BitsetView& bitset) {
3535
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
36-
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
37-
if (is_cosine) {
38-
Normalize(*base_dataset);
39-
}
40-
4136
auto xb = base_dataset->GetTensor();
4237
auto nb = base_dataset->GetRows();
4338
auto dim = base_dataset->GetDim();
@@ -71,13 +66,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
7166
}
7267
case faiss::METRIC_INNER_PRODUCT: {
7368
auto cur_query = (float*)xq + dim * index;
74-
if (is_cosine) {
75-
NormalizeVec(cur_query, dim);
76-
}
7769
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
7870
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
7971
break;
8072
}
73+
case faiss::METRIC_COSINE: {
74+
auto cur_query = (float*)xq + dim * index;
75+
NormalizeVec(cur_query, dim);
76+
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
77+
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
78+
break;
79+
}
8180
case faiss::METRIC_Jaccard: {
8281
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
8382
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
@@ -123,11 +122,6 @@ Status
123122
BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis,
124123
const Json& config, const BitsetView& bitset) {
125124
std::string metric_str = config[meta::METRIC_TYPE].get<std::string>();
126-
bool is_cosine = IsMetricType(metric_str, metric::COSINE);
127-
if (is_cosine) {
128-
Normalize(*base_dataset);
129-
}
130-
131125
auto xb = base_dataset->GetTensor();
132126
auto nb = base_dataset->GetRows();
133127
auto dim = base_dataset->GetDim();
@@ -167,13 +161,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
167161
}
168162
case faiss::METRIC_INNER_PRODUCT: {
169163
auto cur_query = (float*)xq + dim * index;
170-
if (is_cosine) {
171-
NormalizeVec(cur_query, dim);
172-
}
173164
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
174165
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
175166
break;
176167
}
168+
case faiss::METRIC_COSINE: {
169+
auto cur_query = (float*)xq + dim * index;
170+
NormalizeVec(cur_query, dim);
171+
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
172+
faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
173+
break;
174+
}
177175
case faiss::METRIC_Jaccard: {
178176
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
179177
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
@@ -262,12 +260,16 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
262260
case faiss::METRIC_INNER_PRODUCT: {
263261
is_ip = true;
264262
auto cur_query = (float*)xq + dim * index;
265-
if (is_cosine) {
266-
NormalizeVec(cur_query, dim);
267-
}
268263
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
269264
break;
270265
}
266+
case faiss::METRIC_COSINE: {
267+
is_ip = true;
268+
auto cur_query = (float*)xq + dim * index;
269+
NormalizeVec(cur_query, dim);
270+
faiss::range_search_cosine(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
271+
break;
272+
}
271273
case faiss::METRIC_Jaccard: {
272274
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
273275
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(

src/common/metric.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Str2FaissMetricType(std::string metric) {
2727
static const std::unordered_map<std::string, faiss::MetricType> metric_map = {
2828
{metric::L2, faiss::MetricType::METRIC_L2},
2929
{metric::IP, faiss::MetricType::METRIC_INNER_PRODUCT},
30-
{metric::COSINE, faiss::MetricType::METRIC_INNER_PRODUCT},
30+
{metric::COSINE, faiss::MetricType::METRIC_COSINE},
3131
{metric::HAMMING, faiss::MetricType::METRIC_Hamming},
3232
{metric::JACCARD, faiss::MetricType::METRIC_Jaccard},
3333
{metric::SUBSTRUCTURE, faiss::MetricType::METRIC_Substructure},

thirdparty/faiss/faiss/MetricType.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ enum MetricType {
2929
METRIC_Hamming,
3030
METRIC_Substructure, ///< Tversky case alpha = 0, beta = 1
3131
METRIC_Superstructure, ///< Tversky case alpha = 1, beta = 0
32-
32+
METRIC_COSINE,
3333
/// some additional metrics defined in scipy.spatial.distance
3434
METRIC_Canberra = 20,
3535
METRIC_BrayCurtis,

thirdparty/faiss/faiss/utils/distances.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <cmath>
1515
#include <cstdio>
1616
#include <cstring>
17+
#include "simd/hook.h"
1718

1819
#include <omp.h>
1920

@@ -284,6 +285,44 @@ void exhaustive_L2sqr_seq(
284285
}
285286
}
286287

288+
namespace {
289+
float fvec_cosine(const float* x, const float* y, size_t d) {
290+
return fvec_inner_product(x, y, d) / sqrtf(fvec_norm_L2sqr(y, d));
291+
}
292+
} // namespace
293+
294+
template <class ResultHandler>
295+
void exhaustive_cosine_seq(
296+
const float* x,
297+
const float* y,
298+
size_t d,
299+
size_t nx,
300+
size_t ny,
301+
ResultHandler& res,
302+
const BitsetView bitset) {
303+
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
304+
int nt = std::min(int(nx), omp_get_max_threads());
305+
306+
#pragma omp parallel num_threads(nt)
307+
{
308+
SingleResultHandler resi(res);
309+
#pragma omp for
310+
for (int64_t i = 0; i < nx; i++) {
311+
const float* x_i = x + i * d;
312+
const float* y_j = y;
313+
resi.begin(i);
314+
for (size_t j = 0; j < ny; j++) {
315+
if (bitset.empty() || !bitset.test(j)) {
316+
float disij = fvec_cosine(x_i, y_j, d);
317+
resi.add_result(disij, j);
318+
}
319+
y_j += d;
320+
}
321+
resi.end();
322+
}
323+
}
324+
}
325+
287326
/** Find the nearest neighbors for nx queries in a set of ny vectors */
288327
template <class ResultHandler>
289328
void exhaustive_inner_product_blas(
@@ -426,6 +465,76 @@ void exhaustive_L2sqr_blas(
426465
}
427466
}
428467

468+
template <class ResultHandler>
469+
void exhaustive_cosine_blas(
470+
const float* x,
471+
const float* y,
472+
size_t d,
473+
size_t nx,
474+
size_t ny,
475+
ResultHandler& res,
476+
const BitsetView bitset = nullptr) {
477+
// BLAS does not like empty matrices
478+
if (nx == 0 || ny == 0)
479+
return;
480+
481+
/* block sizes */
482+
const size_t bs_x = distance_compute_blas_query_bs;
483+
const size_t bs_y = distance_compute_blas_database_bs;
484+
// const size_t bs_x = 16, bs_y = 16;
485+
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
486+
std::unique_ptr<float[]> y_norms(new float[nx]);
487+
std::unique_ptr<float[]> del2;
488+
489+
fvec_norms_L2(y_norms.get(), x, d, nx);
490+
491+
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
492+
size_t i1 = i0 + bs_x;
493+
if (i1 > nx)
494+
i1 = nx;
495+
496+
res.begin_multiple(i0, i1);
497+
498+
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
499+
size_t j1 = j0 + bs_y;
500+
if (j1 > ny)
501+
j1 = ny;
502+
/* compute the actual dot products */
503+
{
504+
float one = 1, zero = 0;
505+
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
506+
sgemm_("Transpose",
507+
"Not transpose",
508+
&nyi,
509+
&nxi,
510+
&di,
511+
&one,
512+
y + j0 * d,
513+
&di,
514+
x + i0 * d,
515+
&di,
516+
&zero,
517+
ip_block.get(),
518+
&nyi);
519+
}
520+
#pragma omp parallel for
521+
for (int64_t i = i0; i < i1; i++) {
522+
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
523+
524+
for (size_t j = j0; j < j1; j++) {
525+
float ip = *ip_line;
526+
float dis = ip / y_norms[j];
527+
*ip_line = dis;
528+
ip_line++;
529+
}
530+
}
531+
res.add_results(j0, j1, ip_block.get(), bitset);
532+
}
533+
res.end_multiple();
534+
InterruptCallback::check();
535+
}
536+
}
537+
429538
template <class DistanceCorrection, class ResultHandler>
430539
static void knn_jaccard_blas(
431540
const float* x,
@@ -577,6 +686,34 @@ void knn_L2sqr(
577686
}
578687
}
579688

689+
void knn_cosine(
690+
const float* x,
691+
const float* y,
692+
size_t d,
693+
size_t nx,
694+
size_t ny,
695+
float_minheap_array_t* ha,
696+
const BitsetView bitset) {
697+
if (ha->k < distance_compute_min_k_reservoir) {
698+
HeapResultHandler<CMin<float, int64_t>> res(
699+
ha->nh, ha->val, ha->ids, ha->k);
700+
if (nx < distance_compute_blas_threshold) {
701+
exhaustive_L2sqr_IP_seq(x, y, d, nx, ny, res, fvec_cosine, bitset);
702+
} else {
703+
exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
704+
}
705+
} else {
706+
ReservoirResultHandler<CMin<float, int64_t>> res(
707+
ha->nh, ha->val, ha->ids, ha->k);
708+
if (nx < distance_compute_blas_threshold) {
709+
exhaustive_L2sqr_IP_seq(
710+
x, y, d, nx, ny, res, fvec_inner_product, bitset);
711+
} else {
712+
exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
713+
}
714+
}
715+
}
716+
580717
struct NopDistanceCorrection {
581718
float operator()(float dis, size_t /*qno*/, size_t /*bno*/) const {
582719
return dis;
@@ -640,6 +777,23 @@ void range_search_inner_product(
640777
}
641778
}
642779

780+
void range_search_cosine(
781+
const float* x,
782+
const float* y,
783+
size_t d,
784+
size_t nx,
785+
size_t ny,
786+
float radius,
787+
RangeSearchResult* res,
788+
const BitsetView bitset) {
789+
RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
790+
if (nx < distance_compute_blas_threshold) {
791+
exhaustive_cosine_seq(x, y, d, nx, ny, resh, bitset);
792+
} else {
793+
exhaustive_cosine_blas(x, y, d, nx, ny, resh, bitset);
794+
}
795+
}
796+
643797
/***************************************************************************
644798
* compute a subset of distances
645799
***************************************************************************/

thirdparty/faiss/faiss/utils/distances.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,15 @@ void knn_L2sqr(
199199
const float* y_norm2 = nullptr,
200200
const BitsetView bitset = nullptr);
201201

202+
void knn_cosine(
203+
const float* x,
204+
const float* y,
205+
size_t d,
206+
size_t nx,
207+
size_t ny,
208+
float_minheap_array_t* ha,
209+
const BitsetView bitset);
210+
202211
void knn_jaccard(
203212
const float* x,
204213
const float* y,
@@ -265,6 +274,16 @@ void range_search_inner_product(
265274
RangeSearchResult* result,
266275
const BitsetView bitset = nullptr);
267276

277+
void range_search_cosine(
278+
const float* x,
279+
const float* y,
280+
size_t d,
281+
size_t nx,
282+
size_t ny,
283+
float radius,
284+
RangeSearchResult* result,
285+
const BitsetView bitset = nullptr);
286+
268287
/***************************************************************************
269288
* PQ tables computations
270289
***************************************************************************/

0 commit comments

Comments
 (0)