@@ -33,10 +33,6 @@ expected<DataSetPtr>
3333BruteForce::Search (const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
3434 const BitsetView& bitset) {
3535 std::string metric_str = config[meta::METRIC_TYPE].get <std::string>();
36- bool is_cosine = IsMetricType (metric_str, metric::COSINE);
37- if (is_cosine) {
38- Normalize (*base_dataset);
39- }
4036
4137 auto xb = base_dataset->GetTensor ();
4238 auto nb = base_dataset->GetRows ();
@@ -54,6 +50,13 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
5450 auto labels = new int64_t [nq * topk];
5551 auto distances = new float [nq * topk];
5652
53+ bool is_cosine = IsMetricType (metric_str, metric::COSINE);
54+ std::unique_ptr<float []> norms = nullptr ;
55+ if (is_cosine) {
56+ norms = std::make_unique<float []>(nb);
57+ faiss::fvec_norms_L2 (norms.get (), (const float *)xb, dim, nb);
58+ }
59+
5760 auto pool = ThreadPool::GetGlobalThreadPool ();
5861 std::vector<folly::Future<Status>> futs;
5962 futs.reserve (nq);
@@ -71,13 +74,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
7174 }
7275 case faiss::METRIC_INNER_PRODUCT: {
7376 auto cur_query = (float *)xq + dim * index;
74- if (is_cosine) {
75- NormalizeVec (cur_query, dim);
76- }
7777 faiss::float_minheap_array_t buf{(size_t )1 , (size_t )topk, cur_labels, cur_distances};
7878 faiss::knn_inner_product (cur_query, (const float *)xb, dim, 1 , nb, &buf, bitset);
7979 break ;
8080 }
81+ case faiss::METRIC_COSINE: {
82+ auto cur_query = (float *)xq + dim * index;
83+ NormalizeVec (cur_query, dim);
84+ faiss::float_minheap_array_t buf{(size_t )1 , (size_t )topk, cur_labels, cur_distances};
85+ faiss::knn_cosine (cur_query, (const float *)xb, dim, 1 , nb, &buf, norms.get (), bitset);
86+ break ;
87+ }
8188 case faiss::METRIC_Jaccard: {
8289 auto cur_query = (const uint8_t *)xq + (dim / 8 ) * index;
8390 faiss::float_maxheap_array_t res = {size_t (1 ), size_t (topk), cur_labels, cur_distances};
@@ -123,10 +130,6 @@ Status
123130BruteForce::SearchWithBuf (const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t * ids, float * dis,
124131 const Json& config, const BitsetView& bitset) {
125132 std::string metric_str = config[meta::METRIC_TYPE].get <std::string>();
126- bool is_cosine = IsMetricType (metric_str, metric::COSINE);
127- if (is_cosine) {
128- Normalize (*base_dataset);
129- }
130133
131134 auto xb = base_dataset->GetTensor ();
132135 auto nb = base_dataset->GetRows ();
@@ -150,6 +153,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
150153
151154 auto faiss_metric_type = metric_type.value ();
152155
156+ bool is_cosine = IsMetricType (metric_str, metric::COSINE);
157+ std::unique_ptr<float []> norms = nullptr ;
158+ if (is_cosine) {
159+ norms = std::make_unique<float []>(nb);
160+ faiss::fvec_norms_L2 (norms.get (), (const float *)xb, dim, nb);
161+ }
162+
153163 auto pool = ThreadPool::GetGlobalThreadPool ();
154164 std::vector<folly::Future<Status>> futs;
155165 futs.reserve (nq);
@@ -167,13 +177,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
167177 }
168178 case faiss::METRIC_INNER_PRODUCT: {
169179 auto cur_query = (float *)xq + dim * index;
170- if (is_cosine) {
171- NormalizeVec (cur_query, dim);
172- }
173180 faiss::float_minheap_array_t buf{(size_t )1 , (size_t )topk, cur_labels, cur_distances};
174181 faiss::knn_inner_product (cur_query, (const float *)xb, dim, 1 , nb, &buf, bitset);
175182 break ;
176183 }
184+ case faiss::METRIC_COSINE: {
185+ auto cur_query = (float *)xq + dim * index;
186+ NormalizeVec (cur_query, dim);
187+ faiss::float_minheap_array_t buf{(size_t )1 , (size_t )topk, cur_labels, cur_distances};
188+ faiss::knn_cosine (cur_query, (const float *)xb, dim, 1 , nb, &buf, norms.get (), bitset);
189+ break ;
190+ }
177191 case faiss::METRIC_Jaccard: {
178192 auto cur_query = (const uint8_t *)xq + (dim / 8 ) * index;
179193 faiss::float_maxheap_array_t res = {size_t (1 ), size_t (topk), cur_labels, cur_distances};
@@ -221,11 +235,6 @@ expected<DataSetPtr>
221235BruteForce::RangeSearch (const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
222236 const BitsetView& bitset) {
223237 std::string metric_str = config[meta::METRIC_TYPE].get <std::string>();
224- bool is_cosine = IsMetricType (metric_str, metric::COSINE);
225- if (is_cosine) {
226- Normalize (*base_dataset);
227- }
228-
229238 auto xb = base_dataset->GetTensor ();
230239 auto nb = base_dataset->GetRows ();
231240 auto dim = base_dataset->GetDim ();
@@ -241,6 +250,12 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
241250 float range_filter = cfg.range_filter .value ();
242251
243252 ASSIGN_OR_RETURN (faiss::MetricType, faiss_metric_type, Str2FaissMetricType (cfg.metric_type .value ()));
253+ bool is_cosine = IsMetricType (metric_str, metric::COSINE);
254+ std::unique_ptr<float []> norms = nullptr ;
255+ if (is_cosine) {
256+ norms = std::make_unique<float []>(nb);
257+ faiss::fvec_norms_L2 (norms.get (), (const float *)xb, dim, nb);
258+ }
244259 auto pool = ThreadPool::GetGlobalThreadPool ();
245260
246261 std::vector<std::vector<int64_t >> result_id_array (nq);
@@ -262,12 +277,16 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
262277 case faiss::METRIC_INNER_PRODUCT: {
263278 is_ip = true ;
264279 auto cur_query = (float *)xq + dim * index;
265- if (is_cosine) {
266- NormalizeVec (cur_query, dim);
267- }
268280 faiss::range_search_inner_product (cur_query, (const float *)xb, dim, 1 , nb, radius, &res, bitset);
269281 break ;
270282 }
283+ case faiss::METRIC_COSINE: {
284+ auto cur_query = (float *)xq + dim * index;
285+ NormalizeVec (cur_query, dim);
286+ faiss::range_search_cosine (cur_query, (const float *)xb, dim, 1 , nb, radius, &res, norms.get (),
287+ bitset);
288+ break ;
289+ }
271290 case faiss::METRIC_Jaccard: {
272291 auto cur_query = (const uint8_t *)xq + (dim / 8 ) * index;
273292 faiss::binary_range_search<faiss::CMin<float , int64_t >, float >(
0 commit comments