diff --git a/.github/workflows/build_and_test.yaml b/.github/workflows/build_and_test.yaml
index 412ca378..5f81f0d2 100644
--- a/.github/workflows/build_and_test.yaml
+++ b/.github/workflows/build_and_test.yaml
@@ -24,6 +24,7 @@ jobs:
           git config --global --add safe.directory '*'
           eval "$(conda shell.bash hook)"
           conda activate quake-env
+          conda install libarrow-all=19.0.1 -c conda-forge
           mkdir -p build
           cd build
           cmake -DCMAKE_BUILD_TYPE=${{ env.BUILD_TYPE }} \
@@ -46,6 +47,7 @@ jobs:
           git config --global --add safe.directory '*'
           eval "$(conda shell.bash hook)"
           conda activate quake-env
+          conda install libarrow-all=19.0.1 -c conda-forge
           pip install --no-use-pep517 .
           pip install pytest
           python -m pytest test/python
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index ed8ebf58..37bbc2a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,4 @@
-__pycache__
\ No newline at end of file
+__pycache__
+.vscode
+build/
+quake.egg-info/
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
index 265bcb27..1147a27b 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -6,4 +6,4 @@
     url = https://github.com/pybind/pybind11.git
 [submodule "src/cpp/third_party/concurrentqueue"]
     path = src/cpp/third_party/concurrentqueue
-    url = https://github.com/cameron314/concurrentqueue.git
+    url = https://github.com/cameron314/concurrentqueue.git
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 40c4183e..f7f2db8e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -93,11 +93,13 @@ endif()
 # Find Required Packages
 # ---------------------------------------------------------------
 find_package(Torch REQUIRED)
+find_package(Arrow REQUIRED)
 find_package(Python3 COMPONENTS Development Interpreter REQUIRED)
 
 message(STATUS "Torch include dir: ${TORCH_INCLUDE_DIRS}")
 message(STATUS "Torch libraries: ${TORCH_LIBRARIES}")
 message(STATUS "Python include dir: ${Python3_INCLUDE_DIRS}")
+message(STATUS "Arrow include dir: ${ARROW_INCLUDE_DIR}")
 
 set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS})
 
@@ -115,6 +117,7 @@ target_include_directories(${PROJECT_NAME}
         ${TORCH_INCLUDE_DIRS}
         ${project_INCLUDE_DIR}
         ${project_THIRD_PARTY_DIR}/concurrentqueue/
+        ${ARROW_INCLUDE_DIR}
         faiss
 )
 
@@ -150,6 +153,7 @@ else()
 endif()
 
 target_link_libraries(${PROJECT_NAME} PUBLIC ${LINK_LIBS})
+target_link_libraries(${PROJECT_NAME} PUBLIC Arrow::arrow_shared)
 
 IF(CMAKE_BUILD_TYPE MATCHES Debug AND QUAKE_USE_TSAN)
     message("Using thread sanitizer")
diff --git a/environments/ubuntu-latest/conda.yaml b/environments/ubuntu-latest/conda.yaml
index d4e95710..b6e94d9a 100644
--- a/environments/ubuntu-latest/conda.yaml
+++ b/environments/ubuntu-latest/conda.yaml
@@ -10,6 +10,7 @@ dependencies:
   - faiss-cpu
   - matplotlib
   - pytest
+  - libarrow-all=19.0.1
   - pip
   - pip:
       - sphinx
diff --git a/src/cpp/include/clustering.h b/src/cpp/include/clustering.h
index 5cdc6c82..905fd68e 100644
--- a/src/cpp/include/clustering.h
+++ b/src/cpp/include/clustering.h
@@ -28,7 +28,9 @@ shared_ptr<Clustering> kmeans(Tensor vectors,
                               int n_clusters,
                               MetricType metric_type,
                               int niter = 5,
-                              Tensor initial_centroids = Tensor());
+                              std::shared_ptr<arrow::Table> attributes_table = nullptr,
+                              Tensor initial_centroids = Tensor()
+                              );
 
 
 /**
diff --git a/src/cpp/include/common.h b/src/cpp/include/common.h
index 9dc9d58e..e068a836 100644
--- a/src/cpp/include/common.h
+++ b/src/cpp/include/common.h
@@ -33,6 +33,12 @@
 #include
 #include
 
+#include <climits>
+#include <arrow/api.h>
+#include <arrow/table.h>
+#include <arrow/builder.h>
+#include <arrow/compute/api.h>
+
 #ifdef QUAKE_USE_NUMA
 #include <numa.h>
 #include <numaif.h>
@@ -81,6 +87,7 @@
 constexpr bool DEFAULT_PRECOMPUTED = true; ///< Default flag to use precomputation.
 constexpr float DEFAULT_INITIAL_SEARCH_FRACTION = 0.02f; ///< Default initial fraction of partitions to search.
 constexpr float DEFAULT_RECOMPUTE_THRESHOLD = 0.001f; ///< Default threshold to trigger recomputation of search parameters.
 constexpr int DEFAULT_APS_FLUSH_PERIOD_US = 100; ///< Default period (in microseconds) for flushing the APS buffer.
+constexpr int DEFAULT_PRICE_THRESHOLD = INT_MAX; ///< Default price threshold; INT_MAX disables price filtering.
 
 // Default constants for maintenance policy parameters
 constexpr const char* DEFAULT_MAINTENANCE_POLICY = "query_cost"; ///< Default maintenance policy type.
@@ -164,6 +171,12 @@ inline string metric_type_to_str(faiss::MetricType metric) {
     }
 }
 
+enum class FilteringType {
+    PRE_FILTERING,   ///< Filter candidates before scanning a partition.
+    POST_FILTERING,  ///< Scan first, then drop candidates that fail the predicate.
+    IN_FILTERING     ///< Filter while scanning (default).
+};
+
 /**
  * @brief Parameters for the search operation
 */
@@ -178,6 +191,8 @@ struct SearchParams {
    float recompute_threshold = DEFAULT_RECOMPUTE_THRESHOLD;
    float initial_search_fraction = DEFAULT_INITIAL_SEARCH_FRACTION;
    int aps_flush_period_us = DEFAULT_APS_FLUSH_PERIOD_US;
+   int price_threshold = DEFAULT_PRICE_THRESHOLD;
+   FilteringType filtering_type = FilteringType::IN_FILTERING;
 
    SearchParams() = default;
 };
@@ -250,6 +265,7 @@ struct Clustering {
     Tensor partition_ids;
     vector<Tensor> vectors;
     vector<Tensor> vector_ids;
+    vector<std::shared_ptr<arrow::Table>> attributes_tables;
 
     int64_t ntotal() const {
         int64_t n = 0;
diff --git a/src/cpp/include/dynamic_inverted_list.h b/src/cpp/include/dynamic_inverted_list.h
index e940127e..a94d34c5 100644
--- a/src/cpp/include/dynamic_inverted_list.h
+++ b/src/cpp/include/dynamic_inverted_list.h
@@ -135,6 +135,7 @@ namespace faiss {
          * @param n_entry Number of entries to add.
          * @param ids Pointer to the vector IDs.
          * @param codes Pointer to the encoded vectors.
+         * @param attributes_table Arrow table with the attribute rows for the new entries.
          * @return Number of entries added.
         * @throws std::runtime_error if the partition does not exist.
         */
@@ -142,7 +143,28 @@
        size_t add_entries(
                size_t list_no,
                size_t n_entry,
                const idx_t *ids,
-               const uint8_t *codes) override;
+               const uint8_t *codes,
+               std::shared_ptr<arrow::Table> attributes_table
+       );
+
+       /**
+        * @brief Append new entries (codes and IDs) to a partition.
+        *
+        * @param list_no Partition number.
+        * @param n_entry Number of entries to add.
+        * @param ids Pointer to the vector IDs.
+        * @param codes Pointer to the encoded vectors.
+        * @return Number of entries added.
+        * @throws std::runtime_error if the partition does not exist.
+        */
+       size_t add_entries(
+               size_t list_no,
+               size_t n_entry,
+               const idx_t *ids,
+               const uint8_t *codes) override;
 
        /**
         * @brief Update existing entries in a partition.
diff --git a/src/cpp/include/index_partition.h b/src/cpp/include/index_partition.h
index 63b022c6..98e79c65 100644
--- a/src/cpp/include/index_partition.h
+++ b/src/cpp/include/index_partition.h
@@ -27,6 +27,7 @@ class IndexPartition {
 
     uint8_t* codes_ = nullptr; ///< Pointer to the encoded vectors (raw memory block)
     idx_t* ids_ = nullptr; ///< Pointer to the vector IDs
+    std::shared_ptr<arrow::Table> attributes_table_ = {}; ///< Attribute rows for the vectors in this partition
 
     std::unordered_map<idx_t, int64_t> id_to_index_; ///< Map of vector ID to index
 
@@ -88,7 +89,7 @@ class IndexPartition {
      * @param new_ids Pointer to the new vector IDs.
      * @param new_codes Pointer to the new encoded vectors.
      */
-    void append(int64_t n_entry, const idx_t* new_ids, const uint8_t* new_codes);
+    void append(int64_t n_entry, const idx_t* new_ids, const uint8_t* new_codes, std::shared_ptr<arrow::Table> attributes_table = nullptr);
 
     /**
      * @brief Update existing entries in place.
@@ -111,6 +112,15 @@ class IndexPartition {
      */
     void remove(int64_t index);
 
+    /**
+     * @brief Remove the attribute row associated with an entry. Used in conjunction with remove().
+     *
+     * Removes the row by building a boolean mask and filtering the Arrow table.
+     *
+     * @param target_id ID of the vector whose attribute row should be removed.
+     */
+    void removeAttribute(int64_t target_id);
+
     /**
      * @brief Resize the partition.
      *
diff --git a/src/cpp/include/list_scanning.h b/src/cpp/include/list_scanning.h
index d33e540a..2e9b130e 100644
--- a/src/cpp/include/list_scanning.h
+++ b/src/cpp/include/list_scanning.h
@@ -148,6 +148,10 @@ class TypedTopKBuffer {
         partitions_scanned_.fetch_add(1, std::memory_order_relaxed);
     }
 
+    void remove(int rejected_index) {
+        // Overwrite the rejected entry with the last entry and shrink the buffer.
+        topk_[rejected_index] = topk_[--curr_offset_];
+    }
+
     DistanceType flush() {
         std::lock_guard<std::mutex> buffer_lock(buffer_mutex_);
         if (curr_offset_ > k_) {
@@ -280,11 +284,22 @@ inline void scan_list_with_ids_l2(const float *query_vec,
                                   const float *list_vecs,
                                   const int64_t *list_ids,
                                   int list_size,
                                   int d,
-                                  TopkBuffer &buffer) {
+                                  TopkBuffer &buffer,
+                                  bool* bitmap = nullptr) {
     const float *vec = list_vecs;
-    for (int l = 0; l < list_size; l++) {
-        buffer.add(sqrt(faiss::fvec_L2sqr(query_vec, vec, d)), list_ids[l]);
-        vec += d;
+
+    if (bitmap == nullptr) {
+        for (int l = 0; l < list_size; l++) {
+            buffer.add(sqrt(faiss::fvec_L2sqr(query_vec, vec, d)), list_ids[l]);
+            vec += d;
+        }
+    } else {
+        for (int l = 0; l < list_size; l++) {
+            if (bitmap[l]) {
+                buffer.add(sqrt(faiss::fvec_L2sqr(query_vec, vec, d)), list_ids[l]);
+            }
+            vec += d;
+        }
     }
 }
 
@@ -295,7 +310,8 @@ inline void scan_list(const float *query_vec,
                       const float *list_vecs,
                       const int64_t *list_ids,
                       int list_size,
                       int d,
                       TopkBuffer &buffer,
-                      faiss::MetricType metric = faiss::METRIC_L2) {
+                      faiss::MetricType metric = faiss::METRIC_L2,
+                      bool* bitmap = nullptr) {
     // Dispatch based on metric type and whether list_ids is provided.
     if (metric == faiss::METRIC_INNER_PRODUCT) {
         if (list_ids == nullptr)
             scan_list_no_ids_ip(query_vec, list_vecs, list_size, d, buffer);
         else
             scan_list_with_ids_ip(query_vec, list_vecs, list_ids, list_size, d, buffer);
     } else {
         if (list_ids == nullptr)
             scan_list_no_ids_l2(query_vec, list_vecs, list_size, d, buffer);
         else
-            scan_list_with_ids_l2(query_vec, list_vecs, list_ids, list_size, d, buffer);
+            scan_list_with_ids_l2(query_vec, list_vecs, list_ids, list_size, d, buffer, bitmap);
     }
 }
 
diff --git a/src/cpp/include/partition_manager.h b/src/cpp/include/partition_manager.h
index 9a5018a2..9764aeb8 100644
--- a/src/cpp/include/partition_manager.h
+++ b/src/cpp/include/partition_manager.h
@@ -9,6 +9,7 @@
 
 #include
 #include
+#include <arrow/api.h>
 
 class QuakeIndex;
 
@@ -56,7 +57,15 @@ class PartitionManager {
      * @param assignments Tensor of shape [num_vectors] containing partition IDs. If not provided, vectors are assigned using the parent index.
      * @return Timing information for the operation.
      */
-    shared_ptr<ModifyTimingInfo> add(const Tensor &vectors, const Tensor &vector_ids, const Tensor &assignments = Tensor(), bool check_uniques = true);
+    shared_ptr<ModifyTimingInfo> add(const Tensor &vectors, const Tensor &vector_ids, const Tensor &assignments = Tensor(), bool check_uniques = true, std::shared_ptr<arrow::Table> attributes_table = {});
+
+    /**
+     * @brief Filter the appropriate row from the attribute table.
+     * @param table Arrow table of attributes.
+     * @param vector_id ID of the row to keep.
+     * @return Table containing only the row belonging to vector_id.
+     */
+    std::shared_ptr<arrow::Table> filterRowById(std::shared_ptr<arrow::Table> table, int64_t vector_id);
 
     /**
      * @brief Remove vectors by ID from the index.
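Taken together, the header changes above thread an optional arrow::Table of per-vector attributes (keyed by an "id" column) from the index entry points down to each IndexPartition, and SearchParams gains price_threshold / filtering_type knobs that the scan path consumes. The sketch below shows the intended end-to-end usage against these signatures. It is illustrative only: the build-parameter type name IndexBuildParams, the main() wrapper, and the synthetic price values are assumptions not confirmed by this extract, and the Arrow builder statuses are deliberately discarded for brevity.

// Hedged usage sketch for the attribute-filtering API introduced in this patch.
#include <arrow/api.h>
#include <torch/torch.h>
#include "quake_index.h"

// Build a two-column attributes table ("id", "price") aligned with the ids tensor.
static std::shared_ptr<arrow::Table> make_attributes(torch::Tensor ids) {
    arrow::Int64Builder id_builder;
    arrow::DoubleBuilder price_builder;
    for (int64_t i = 0; i < ids.size(0); i++) {
        (void) id_builder.Append(ids[i].item<int64_t>());
        (void) price_builder.Append(10.0 + i);  // synthetic price per vector (assumption)
    }
    std::shared_ptr<arrow::Array> id_array, price_array;
    (void) id_builder.Finish(&id_array);
    (void) price_builder.Finish(&price_array);
    auto schema = arrow::schema({arrow::field("id", arrow::int64()),
                                 arrow::field("price", arrow::float64())});
    return arrow::Table::Make(schema, {id_array, price_array});
}

int main() {
    auto vectors = torch::randn({1000, 64});
    auto ids = torch::arange(1000, torch::kInt64);

    auto index = std::make_shared<QuakeIndex>();
    // NOTE: IndexBuildParams is an assumed name; the concrete type is not
    // visible in this extract (its template argument was lost).
    auto build_params = std::make_shared<IndexBuildParams>();
    build_params->nlist = 16;
    build_params->metric = "l2";
    index->build(vectors, ids, build_params, make_attributes(ids));

    // Only vectors whose "price" attribute is <= 50 may appear in the result.
    auto search_params = std::make_shared<SearchParams>();
    search_params->k = 10;
    search_params->price_threshold = 50;
    search_params->filtering_type = FilteringType::PRE_FILTERING;
    auto result = index->search(torch::randn({1, 64}), search_params);
    return 0;
}

The design point worth noting: the table is split alongside the vectors, so each partition carries only its own attribute rows and the scan never consults a global table.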
diff --git a/src/cpp/include/quake_index.h b/src/cpp/include/quake_index.h
index c0cccc21..f39990b2 100644
--- a/src/cpp/include/quake_index.h
+++ b/src/cpp/include/quake_index.h
@@ -47,7 +47,7 @@ class QuakeIndex {
      * @param build_params Parameters for building the index.
      * @return Timing information for the build.
      */
-    shared_ptr<BuildTimingInfo> build(Tensor x, Tensor ids, shared_ptr<IndexBuildParams> build_params);
+    shared_ptr<BuildTimingInfo> build(Tensor x, Tensor ids, shared_ptr<IndexBuildParams> build_params, std::shared_ptr<arrow::Table> attributes_table = nullptr);
 
     /**
      * @brief Search for vectors in the index.
@@ -73,9 +74,10 @@ class QuakeIndex {
      * @brief Add vectors to the index.
      * @param x Tensor of shape [num_vectors, dimension].
      * @param ids Tensor of shape [num_vectors].
+     * @param attributes_table Attribute table with one row per vector ID.
      * @return Timing information for the add operation.
      */
-    shared_ptr<ModifyTimingInfo> add(Tensor x, Tensor ids);
+    shared_ptr<ModifyTimingInfo> add(Tensor x, Tensor ids, std::shared_ptr<arrow::Table> attributes_table = {});
 
     /**
      * @brief Remove vectors from the index.
diff --git a/src/cpp/src/clustering.cpp b/src/cpp/src/clustering.cpp
index 9b2d9cd3..94142589 100644
--- a/src/cpp/src/clustering.cpp
+++ b/src/cpp/src/clustering.cpp
@@ -5,16 +5,20 @@
 // - Use descriptive variable names
 
 #include "clustering.h"
-#include
-#include "faiss/Clustering.h"
 #include "index_partition.h"
 #include
+#include
+#include "faiss/Clustering.h"
+#include <arrow/api.h>
+#include <arrow/builder.h>
+#include <arrow/compute/api.h>
 
 shared_ptr<Clustering> kmeans(Tensor vectors,
                               Tensor ids,
                               int n_clusters,
                               MetricType metric_type,
                               int niter,
+                              std::shared_ptr<arrow::Table> attributes_table,
                               Tensor /* initial_centroids */) {
     // Ensure enough vectors are available and sizes match.
     assert(vectors.size(0) >= n_clusters * 2);
@@ -54,10 +58,59 @@ shared_ptr<Clustering> kmeans(Tensor vectors,
     // Partition vectors and ids by cluster.
     vector<Tensor> cluster_vectors(n_clusters);
     vector<Tensor> cluster_ids(n_clusters);
+    vector<std::shared_ptr<arrow::Table>> cluster_attributes_tables(n_clusters);
+
     for (int i = 0; i < n_clusters; i++) {
-        cluster_vectors[i] = vectors.index({assignments == i});
-        cluster_ids[i] = ids.index({assignments == i});
+        auto mask = (assignments == i);
+
+        // Vectors assigned to cluster i
+        cluster_vectors[i] = vectors.index({mask});
+        // Vector IDs assigned to cluster i
+        cluster_ids[i] = ids.index({mask});
+
+        if (attributes_table == nullptr) {
+            cluster_attributes_tables[i] = nullptr;
+            continue;
+        }
+
+        // IDs of the vectors in cluster i
+        auto cluster_ids_tensor = cluster_ids[i];
+        std::vector<int64_t> cluster_ids_vec(cluster_ids_tensor.data_ptr<int64_t>(),
+                                             cluster_ids_tensor.data_ptr<int64_t>() + cluster_ids_tensor.numel());
+
+        // Convert the cluster's IDs to an Arrow array.
+        arrow::Int64Builder id_builder;
+        id_builder.AppendValues(cluster_ids_vec);
+        std::shared_ptr<arrow::Array> cluster_ids_array;
+        id_builder.Finish(&cluster_ids_array);
+
+        // Get the "id" column from the attributes table.
+        std::shared_ptr<arrow::ChunkedArray> id_column = attributes_table->GetColumnByName("id");
+
+        auto lookup_options = std::make_shared<arrow::compute::SetLookupOptions>(cluster_ids_array);
+        // Apply the set lookup to find rows whose id belongs to this cluster.
+        auto result = arrow::compute::CallFunction(
+            "index_in",
+            {id_column->chunk(0)},
+            lookup_options.get()
+        );
+
+        auto index_array = std::static_pointer_cast<arrow::Int32Array>(result->make_array());
+
+        // Rows not in the cluster come back as null from index_in; they fail the
+        // not_equal test and are dropped by the filter below.
+        auto mask_result = arrow::compute::CallFunction(
+            "not_equal",
+            {index_array, arrow::MakeScalar(-1)}
+        );
+
+        // Convert the result to a boolean mask.
+        auto boolean_mask = std::static_pointer_cast<arrow::BooleanArray>(mask_result->make_array());
+
+        // Filter the table using the mask.
+        auto filtered_table_result = arrow::compute::Filter(attributes_table, boolean_mask);
+
+        cluster_attributes_tables[i] = filtered_table_result->table();
     }
+
+    Tensor partition_ids = torch::arange(n_clusters, torch::kInt64);
 
     shared_ptr<Clustering> clustering = std::make_shared<Clustering>();
     clustering->partition_ids = partition_ids;
     clustering->vectors = cluster_vectors;
     clustering->vector_ids = cluster_ids;
+    clustering->attributes_tables = cluster_attributes_tables;
 
     delete index_ptr;
     return clustering;
diff --git a/src/cpp/src/dynamic_inverted_list.cpp b/src/cpp/src/dynamic_inverted_list.cpp
index 0ea36450..26f67f3f 100644
--- a/src/cpp/src/dynamic_inverted_list.cpp
+++ b/src/cpp/src/dynamic_inverted_list.cpp
@@ -152,7 +152,9 @@ namespace faiss {
            size_t list_no,
            size_t n_entry,
            const idx_t *ids,
-           const uint8_t *codes) {
+           const uint8_t *codes,
+           shared_ptr<arrow::Table> attributes_table
+   ) {
        if (n_entry == 0) {
            return 0;
        }
@@ -168,10 +170,19 @@ namespace faiss {
            part->set_code_size(static_cast<int64_t>(code_size));
        }
 
-       part->append((int64_t) n_entry, ids, codes);
+       part->append((int64_t) n_entry, ids, codes, attributes_table);
        return n_entry;
    }
 
+   size_t DynamicInvertedLists::add_entries(
+           size_t list_no,
+           size_t n_entry,
+           const idx_t *ids,
+           const uint8_t *codes
+   ) {
+       // Forward to the attribute-aware overload with no attributes.
+       return add_entries(list_no, n_entry, ids, codes, nullptr);
+   }
+
    void DynamicInvertedLists::update_entries(
            size_t list_no,
            size_t offset,
diff --git a/src/cpp/src/index_partition.cpp b/src/cpp/src/index_partition.cpp
index 8604a939..fa11b1e5 100644
--- a/src/cpp/src/index_partition.cpp
+++ b/src/cpp/src/index_partition.cpp
@@ -5,6 +5,9 @@
 // - Use descriptive variable names
 
 #include
+#include <iostream>
+#include <arrow/api.h>
+#include <arrow/compute/api.h>
 
 IndexPartition::IndexPartition(int64_t num_vectors,
                                uint8_t* codes,
@@ -49,12 +52,20 @@ void IndexPartition::set_code_size(int64_t code_size) {
     code_size_ = code_size;
 }
 
-void IndexPartition::append(int64_t n_entry, const idx_t* new_ids, const uint8_t* new_codes) {
+void IndexPartition::append(int64_t n_entry, const idx_t* new_ids, const uint8_t* new_codes, std::shared_ptr<arrow::Table> attributes_table) {
     if (n_entry <= 0) return;
     ensure_capacity(num_vectors_ + n_entry);
     const size_t code_bytes = static_cast<size_t>(code_size_);
     std::memcpy(codes_ + num_vectors_ * code_bytes, new_codes, n_entry * code_bytes);
     std::memcpy(ids_ + num_vectors_, new_ids, n_entry * sizeof(idx_t));
+    // Append attributes_table to attributes_table_.
+    if (attributes_table_ == nullptr) {
+        attributes_table_ = attributes_table;
+    } else if (attributes_table != nullptr) {
+        // Concatenate the new attributes table with the existing one.
+        auto concatenated_table = arrow::ConcatenateTables({attributes_table_, attributes_table});
+        attributes_table_ = concatenated_table.ValueOrDie();
+    }
     num_vectors_ += n_entry;
 }
@@ -99,6 +110,52 @@ void IndexPartition::remove(int64_t index) {
+    // Drop the attribute row for the vector being removed, before its id slot
+    // is overwritten by the swap below.
+    removeAttribute(ids_[index]);
+
     ids_[index] = ids_[last_idx];
     num_vectors_--;
 }
+
+// https://github.com/apache/arrow/issues/44243
+// Arrow data is immutable, so a row cannot be deleted in place; instead we
+// build a new table that omits the target row.
+void IndexPartition::removeAttribute(int64_t target_id) {
+    if (attributes_table_ == nullptr) {
+        // No table means nothing to remove, so exit gracefully.
+        return;
+    }
+
+    int64_t original_size = attributes_table_->num_rows();
+    if (original_size == 0) {
+        std::cerr << "No attributes found in the table.\n";
+        return;
+    }
+
+    auto id_column = attributes_table_->GetColumnByName("id");
+    if (!id_column) {
+        std::cerr << "Column 'id' not found in table." << std::endl;
+        return;
+    }
+
+    // Create a filter expression (id != target_id).
+    auto column_data = id_column->chunk(0);
+    auto scalar_value = arrow::MakeScalar(target_id);
+    auto filter_expr = arrow::compute::CallFunction("not_equal", {column_data, scalar_value});
+
+    if (!filter_expr.ok()) {
+        std::cerr << "Error creating filter expression: " << filter_expr.status().ToString() << std::endl;
+        return;
+    }
+
+    // Apply the filter.
+    auto result = arrow::compute::Filter(attributes_table_, filter_expr.ValueOrDie());
+    if (!result.ok()) {
+        std::cerr << "Error filtering table: " << result.status().ToString() << std::endl;
+        return;
+    }
+
+    attributes_table_ = result.ValueOrDie().table();
+}
 
 void IndexPartition::resize(int64_t new_capacity) {
@@ -271,4 +328,4 @@ T* IndexPartition::allocate_memory(size_t num_elements, int numa_node) {
         throw std::bad_alloc();
     }
     return ptr;
-}
\ No newline at end of file
+}
diff --git a/src/cpp/src/partition_manager.cpp b/src/cpp/src/partition_manager.cpp
index 3b738856..8659f3f2 100644
--- a/src/cpp/src/partition_manager.cpp
+++ b/src/cpp/src/partition_manager.cpp
@@ -11,6 +11,9 @@
 #include
 #include
 #include "quake_index.h"
+#include <iostream>
+#include <arrow/api.h>
+#include <arrow/compute/api.h>
 
 using std::runtime_error;
 
@@ -77,6 +80,10 @@ void PartitionManager::init_partitions(
     for (int64_t i = 0; i < nlist; i++) {
         Tensor v = clustering->vectors[i];
         Tensor id = clustering->vector_ids[i];
+        std::shared_ptr<arrow::Table> attributes_table = nullptr;
+        if (i < (int64_t) clustering->attributes_tables.size()) {
+            attributes_table = clustering->attributes_tables[i];
+        }
         if (v.size(0) != id.size(0)) {
             throw runtime_error("[PartitionManager] init_partitions: mismatch in v.size(0) vs id.size(0).");
         }
@@ -103,7 +110,8 @@ void PartitionManager::init_partitions(
             partition_ids_accessor[i],
             count,
             id.data_ptr<int64_t>(),
-            as_uint8_ptr(v)
+            as_uint8_ptr(v),
+            attributes_table
         );
         if (debug_) {
             std::cout << "[PartitionManager] init_partitions: Added " << count
@@ -120,13 +128,49 @@ void PartitionManager::init_partitions(
     }
 }
 
+std::shared_ptr<arrow::Table> PartitionManager::filterRowById(
+    std::shared_ptr<arrow::Table> table,
+    int64_t target_id
+) {
+    if (table == nullptr) {
+        return nullptr;
+    }
+
+    auto id_column = table->GetColumnByName("id");
+    if (!id_column) {
+        std::cerr << "Column 'id' not found in table." << std::endl;
+        return nullptr;
+    }
+
+    // Create a filter expression (id == target_id).
+    arrow::Datum column_data = id_column->chunk(0);
+    arrow::Datum scalar_value = arrow::MakeScalar(target_id);
+    auto filter_expr = arrow::compute::CallFunction("equal", {column_data, scalar_value});
+
+    if (!filter_expr.ok()) {
+        std::cerr << "Error creating filter expression: " << filter_expr.status().ToString() << std::endl;
+        return nullptr;
+    }
+
+    // Apply the filter.
+    auto result = arrow::compute::Filter(table, filter_expr.ValueOrDie());
+    if (!result.ok()) {
+        std::cerr << "Error filtering table: " << result.status().ToString() << std::endl;
+        return nullptr;
+    }
+
+    return result.ValueOrDie().table();
+}
+
 shared_ptr<ModifyTimingInfo> PartitionManager::add(
     const Tensor &vectors,
     const Tensor &vector_ids,
     const Tensor &assignments,
-    bool check_uniques
+    bool check_uniques,
+    std::shared_ptr<arrow::Table> attributes_table
 ) {
-
     auto timing_info = std::make_shared<ModifyTimingInfo>();
 
     if (debug_) {
@@ -145,9 +189,19 @@ shared_ptr<ModifyTimingInfo> PartitionManager::add(
     if (!vectors.defined() || !vector_ids.defined()) {
         throw runtime_error("[PartitionManager] add: vectors or vector_ids is undefined.");
     }
+
     if (vectors.size(0) != vector_ids.size(0)) {
         throw runtime_error("[PartitionManager] add: mismatch in vectors.size(0) and vector_ids.size(0).");
     }
+
+    if (attributes_table != nullptr && attributes_table->num_rows() != vector_ids.size(0)) {
+        throw runtime_error("[PartitionManager] add: mismatch in attributes_table and vector_ids size.");
+    }
+
+    if (attributes_table != nullptr && !attributes_table->GetColumnByName("id")) {
+        throw runtime_error("[PartitionManager] add: attributes_table has no 'id' column.");
+    }
+
     int64_t n = vectors.size(0);
     if (n == 0) {
         if (debug_) {
@@ -249,12 +303,16 @@ shared_ptr<ModifyTimingInfo> PartitionManager::add(
                       << " into partition " << pid << std::endl;
         }
 
+        std::shared_ptr<arrow::Table> filtered_table_result = filterRowById(attributes_table, id_accessor[i]);
         partition_store_->add_entries(
             pid,
             /*n_entry=*/1,
             id_ptr + i,
-            code_ptr + i * code_size_bytes
+            code_ptr + i * code_size_bytes,
+            filtered_table_result
         );
     }
     auto e3 = std::chrono::high_resolution_clock::now();
     timing_info->modify_time_us = std::chrono::duration_cast<std::chrono::microseconds>(e3 - s3).count();
@@ -309,6 +367,7 @@ shared_ptr<ModifyTimingInfo> PartitionManager::remove(const Tensor &ids) {
     timing_info->find_partition_time_us = std::chrono::duration_cast<std::chrono::microseconds>(e2 - s2).count();
 
     auto s3 = std::chrono::high_resolution_clock::now();
+
     partition_store_->remove_vectors(to_remove);
     if (debug_) {
         std::cout << "[PartitionManager] remove: Completed removal." << std::endl;
@@ -737,4 +796,4 @@ void PartitionManager::load(const string &path) {
     if (debug_) {
         std::cout << "[PartitionManager] load: Load complete." << std::endl;
     }
-}
\ No newline at end of file
+}
diff --git a/src/cpp/src/quake_index.cpp b/src/cpp/src/quake_index.cpp
index 63f945f3..3959911d 100644
--- a/src/cpp/src/quake_index.cpp
+++ b/src/cpp/src/quake_index.cpp
@@ -26,7 +26,7 @@ QuakeIndex::~QuakeIndex() {
     maintenance_policy_params_ = nullptr;
 }
 
-shared_ptr<BuildTimingInfo> QuakeIndex::build(Tensor x, Tensor ids, shared_ptr<IndexBuildParams> build_params) {
+shared_ptr<BuildTimingInfo> QuakeIndex::build(Tensor x, Tensor ids, shared_ptr<IndexBuildParams> build_params, std::shared_ptr<arrow::Table> attributes_table) {
     build_params_ = build_params;
     metric_ = str_to_metric_type(build_params_->metric);
 
@@ -46,7 +46,8 @@
             build_params_->nlist,
             metric_,
-            build_params_->niter
+            build_params_->niter,
+            attributes_table
         );
         auto e1 = std::chrono::high_resolution_clock::now();
         timing_info->train_time_us = std::chrono::duration_cast<std::chrono::microseconds>(e1 - s1).count();
@@ -73,6 +74,7 @@
         clustering->centroids = x.mean(0, true);
         clustering->vectors = {x};
         clustering->vector_ids = {ids};
+        clustering->attributes_tables = {attributes_table};
         partition_manager_->init_partitions(parent_, clustering);
     }
 
@@ -118,12 +120,12 @@ Tensor QuakeIndex::get(Tensor ids) {
     return partition_manager_->get(ids);
 }
 
-shared_ptr<ModifyTimingInfo> QuakeIndex::add(Tensor x, Tensor ids) {
+shared_ptr<ModifyTimingInfo> QuakeIndex::add(Tensor x, Tensor ids, std::shared_ptr<arrow::Table> attributes_table) {
     if (!partition_manager_) {
         throw std::runtime_error("[QuakeIndex::add()] No partition manager. Build the index first.");
     }
 
-    auto modify_info = partition_manager_->add(x, ids);
+    auto modify_info = partition_manager_->add(x, ids, Tensor(), true, attributes_table);
     modify_info->n_vectors = x.size(0);
     return modify_info;
 }
diff --git a/src/cpp/src/query_coordinator.cpp b/src/cpp/src/query_coordinator.cpp
index 67112a1f..9c4fde0a 100644
--- a/src/cpp/src/query_coordinator.cpp
+++ b/src/cpp/src/query_coordinator.cpp
@@ -10,6 +10,10 @@
 #include
 #include
 #include
+#include <unordered_map>
+#include <arrow/api.h>
+#include <arrow/array.h>
+#include <arrow/compute/api.h>
 
 // Constructor
 QueryCoordinator::QueryCoordinator(shared_ptr<QuakeIndex> parent,
@@ -468,8 +472,26 @@ shared_ptr<SearchResult> QueryCoordinator::worker_scan(
     return search_result;
 }
 
-shared_ptr<SearchResult> QueryCoordinator::serial_scan(Tensor x, Tensor partition_ids,
-                                                       shared_ptr<SearchParams> search_params) {
+// Build a bitmap over a partition's entries marking those whose price
+// attribute passes the threshold. The caller owns the returned array.
+bool* create_bitmap(const std::unordered_map<int64_t, double> &id_to_price, int64_t* list_ids,
+                    int64_t num_ids, shared_ptr<SearchParams> search_params) {
+    bool* bitmap = new bool[num_ids];
+
+    for (int64_t i = 0; i < num_ids; i++) {
+        int64_t id = list_ids[i];
+        if (id_to_price.count(id) && id_to_price.at(id) <= search_params->price_threshold) {
+            bitmap[i] = true;
+        } else {
+            bitmap[i] = false;
+        }
+    }
+
+    return bitmap;
+}
+
+shared_ptr<SearchResult> QueryCoordinator::serial_scan(Tensor x, Tensor partition_ids_to_scan,
+                                                       shared_ptr<SearchParams> search_params) {
     if (!partition_manager_) {
         throw std::runtime_error("[QueryCoordinator::serial_scan] partition_manager_ is null.");
     }
@@ -501,10 +523,10 @@ shared_ptr<SearchResult> QueryCoordinator::serial_scan(
     bool use_aps = (search_params->recall_target > 0.0 && parent_);
 
     // Ensure partition_ids is 2D.
-    if (partition_ids.dim() == 1) {
-        partition_ids = partition_ids.unsqueeze(0).expand({num_queries, partition_ids.size(0)});
+    if (partition_ids_to_scan.dim() == 1) {
+        partition_ids_to_scan = partition_ids_to_scan.unsqueeze(0).expand({num_queries, partition_ids_to_scan.size(0)});
     }
-    auto partition_ids_accessor = partition_ids.accessor<int64_t, 2>();
+    auto partition_ids_accessor = partition_ids_to_scan.accessor<int64_t, 2>();
 
     float *x_ptr = x.data_ptr<float>();
 
     // Allocate per-query result vectors.
@@ -516,7 +538,7 @@ shared_ptr<SearchResult> QueryCoordinator::serial_scan(
         // Create a local TopK buffer for query q.
         auto topk_buf = std::make_shared<TopkBuffer>(k, is_descending);
         const float* query_vec = x_ptr + q * dimension;
-        int num_parts = partition_ids.size(1);
+        int num_parts = partition_ids_to_scan.size(1);
 
         vector<float> boundary_distances;
         vector<float> partition_probs;
@@ -525,10 +547,10 @@ shared_ptr<SearchResult> QueryCoordinator::serial_scan(
             query_radius = -1000000.0;
         }
 
-        Tensor partition_sizes = partition_manager_->get_partition_sizes(partition_ids[q]);
+        Tensor partition_sizes = partition_manager_->get_partition_sizes(partition_ids_to_scan[q]);
         if (use_aps) {
-            vector<int64_t> partition_ids_to_scan_vec = std::vector<int64_t>(partition_ids[q].data_ptr<int64_t>(),
-                partition_ids[q].data_ptr<int64_t>() + partition_ids[q].size(0));
+            vector<int64_t> partition_ids_to_scan_vec = std::vector<int64_t>(partition_ids_to_scan[q].data_ptr<int64_t>(),
+                partition_ids_to_scan[q].data_ptr<int64_t>() + partition_ids_to_scan[q].size(0));
 
             auto cluster_centroids = parent_->partition_manager_->get_vectors(partition_ids_to_scan_vec);
             boundary_distances = compute_boundary_distances(x[q],
                                                             cluster_centroids,
@@ -544,15 +566,48 @@ shared_ptr<SearchResult> QueryCoordinator::serial_scan(
             start_time = high_resolution_clock::now();
 
             float *list_vectors = (float *) partition_manager_->partition_store_->get_codes(pi);
             int64_t *list_ids = (int64_t *) partition_manager_->partition_store_->get_ids(pi);
+            std::shared_ptr<arrow::Table> partition_attributes_table =
+                partition_manager_->partition_store_->partitions_[pi]->attributes_table_;
             int64_t list_size = partition_manager_->partition_store_->list_size(pi);
 
+            std::shared_ptr<arrow::Int64Array> id_array = nullptr;
+            std::shared_ptr<arrow::DoubleArray> price_array = nullptr;
+
+            std::unordered_map<int64_t, double> id_to_price;
+
+            if (partition_attributes_table != nullptr) {
+                id_array = std::static_pointer_cast<arrow::Int64Array>(partition_attributes_table->GetColumnByName("id")->chunk(0));
+                price_array = std::static_pointer_cast<arrow::DoubleArray>(partition_attributes_table->GetColumnByName("price")->chunk(0));
+                for (int64_t i = 0; i < id_array->length(); i++) {
+                    id_to_price[id_array->Value(i)] = price_array->Value(i);
+                }
+            }
+
+            bool* bitmap = nullptr;
+
+            if (search_params->filtering_type == FilteringType::PRE_FILTERING) {
+                bitmap = create_bitmap(id_to_price, list_ids, list_size, search_params);
+            }
+
             scan_list(query_vec,
                       list_vectors,
                       list_ids,
                       partition_manager_->partition_store_->list_size(pi),
                       dimension,
                       *topk_buf,
-                      metric_);
+                      metric_,
+                      bitmap);
+
+            delete[] bitmap;
+
+            if (search_params->filtering_type == FilteringType::POST_FILTERING) {
+                // Drop buffered candidates whose price exceeds the threshold.
+                // remove() swaps the last entry into the vacated slot, so only
+                // advance i when the current entry survives.
+                int buffer_size = topk_buf->curr_offset_;
+                for (int i = 0; i < buffer_size; ) {
+                    auto vector_id = topk_buf->topk_[i].second;
+                    if (id_to_price.count(vector_id) && id_to_price[vector_id] > search_params->price_threshold) {
+                        topk_buf->remove(i);
+                        buffer_size--;
+                    } else {
+                        i++;
+                    }
+                }
+            }
 
             float curr_radius = topk_buf->get_kth_distance();
             float percent_change = abs(curr_radius - query_radius) / curr_radius;
@@ -647,6 +702,45 @@ shared_ptr<SearchResult> QueryCoordinator::search(Tensor x, shared_ptr<SearchParams> search_params) {
+    // TODO: global post-filtering pass over the merged results; disabled for
+    // now because the coordinator does not hold a global attributes table.
+    // if (search_params->filtering_type == FilteringType::POST_FILTERING &&
+    //     search_params->price_threshold != INT_MAX && parent_ != nullptr) {
+    //     auto id_data = search_result->ids.data_ptr<int64_t>();
+    //     auto distance_data = search_result->distances.data_ptr<float>();
+    //     int64_t num_results = search_result->ids.size(0);
+    //
+    //     std::vector<int64_t> filtered_ids;
+    //     std::vector<float> filtered_distances;
+    //
+    //     std::shared_ptr<arrow::ChunkedArray> id_column = attributes_table->GetColumnByName("id");
+    //     std::shared_ptr<arrow::ChunkedArray> price_column = attributes_table->GetColumnByName("price");
+    //
+    //     for (int64_t i = 0; i < num_results; i++) {
+    //         int64_t id = id_data[i];
+    //
+    //         // Search for this ID in the attribute table.
+    //         auto id_scalar = arrow::MakeScalar(id);
+    //         auto price_scalar = arrow::MakeScalar(search_params->price_threshold);
+    //
+    //         auto equal_condition = arrow::compute::CallFunction("equal", {id_column->chunk(0), id_scalar});
+    //         auto less_equal_condition = arrow::compute::CallFunction("less_equal", {price_column->chunk(0), price_scalar});
+    //
+    //         auto combined_condition = arrow::compute::CallFunction("and", {equal_condition.ValueOrDie(), less_equal_condition.ValueOrDie()});
+    //
+    //         auto filter_result = arrow::compute::Filter(attributes_table, combined_condition.ValueOrDie());
+    //
+    //         if (filter_result.ok()) {
+    //             filtered_ids.push_back(id);
+    //             filtered_distances.push_back(distance_data[i]);
+    //         }
+    //     }
+    //
+    //     search_result->ids = torch::tensor(filtered_ids, torch::kInt64);
+    //     search_result->distances = torch::tensor(filtered_distances, torch::kFloat);
+    // }
+
     search_result->timing_info->parent_info = parent_timing_info;
 
     auto end = high_resolution_clock::now();
diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt
index e21a8a3f..a45605ec 100644
--- a/test/cpp/CMakeLists.txt
+++ b/test/cpp/CMakeLists.txt
@@ -6,7 +6,14 @@
 ADD_EXECUTABLE(quake_tests ${SRCS})
 TARGET_LINK_LIBRARIES(quake_tests
         ${PROJECT_NAME}
-        gtest gtest_main
+        gtest gtest_main
+        Arrow::arrow_shared
+)
+
+# Include directories propagated from dependencies
+target_include_directories(quake_tests
+        PRIVATE
+        ${ARROW_INCLUDE_DIR} # Arrow headers (if not already propagated by Arrow::arrow_shared)
 )
 
 add_test(NAME quake_tests COMMAND quake_tests WORKING_DIRECTORY ${QUAKE_TEST_HOME})
\ No newline at end of file
diff --git a/test/cpp/benchmark.cpp b/test/cpp/benchmark.cpp
index 71426869..8e666c4b 100644
--- a/test/cpp/benchmark.cpp
+++ b/test/cpp/benchmark.cpp
@@ -43,6 +43,36 @@ static Tensor generate_ids(int64_t num, int64_t start = 0) {
     return torch::arange(start, start + num, torch::kInt64);
 }
 
+static std::shared_ptr<arrow::Table> generate_data_frame(int64_t num_vectors, torch::Tensor ids) {
+    arrow::MemoryPool* pool = arrow::default_memory_pool();
+
+    // Builders for the "price" and "id" columns
+    arrow::DoubleBuilder price_builder(pool);
+    arrow::Int64Builder id_builder(pool);
+
+    // Append values to the builders
+    for (int64_t i = 0; i < num_vectors; i++) {
+        price_builder.Append(static_cast<double>(i) * 1.5); // Price column
+        id_builder.Append(ids[i].item<int64_t>());          // ID column from the input tensor
+    }
+
+    // Finalize the arrays
+    std::shared_ptr<arrow::Array> price_array;
+    std::shared_ptr<arrow::Array> id_array;
+    price_builder.Finish(&price_array);
+    id_builder.Finish(&id_array);
+
+    // Define the schema with two fields: "id" and "price"
+    std::vector<std::shared_ptr<arrow::Field>> schema_vector = {
+        arrow::field("id", arrow::int64()),
+        arrow::field("price", arrow::float64()),
+    };
+    auto schema = std::make_shared<arrow::Schema>(schema_vector);
+
+    // Create and return the table with both columns
+    return arrow::Table::Make(schema, {id_array, price_array});
+}
+
 //
 // ===== Quake BENCHMARK FIXTURES =====
 //
@@ -53,14 +83,16 @@ class QuakeSerialFlatBenchmark : public ::testing::Test {
     std::shared_ptr<QuakeIndex> index_;
     Tensor data_;
     Tensor ids_;
+    std::shared_ptr<arrow::Table> attributes_table_;
 
     void SetUp() override {
         data_ = generate_data(NUM_VECTORS, DIM);
         ids_ = generate_ids(NUM_VECTORS);
+        attributes_table_ = generate_data_frame(NUM_VECTORS, ids_);
         index_ = std::make_shared<QuakeIndex>();
         auto build_params = std::make_shared<IndexBuildParams>();
         build_params->nlist = 1; // flat index
         build_params->metric = "l2";
-        index_->build(data_, ids_, build_params);
+        index_->build(data_, ids_, build_params, attributes_table_);
     }
 };
 
@@ -70,16 +102,17 @@ class QuakeWorkerFlatBenchmark : public ::testing::Test {
     std::shared_ptr<QuakeIndex> index_;
     Tensor data_;
     Tensor ids_;
+    std::shared_ptr<arrow::Table> attributes_table_;
 
     void SetUp() override {
         data_ = generate_data(NUM_VECTORS, DIM);
         ids_ = generate_ids(NUM_VECTORS);
+        attributes_table_ = generate_data_frame(NUM_VECTORS, ids_);
         index_ = std::make_shared<QuakeIndex>();
         auto build_params = std::make_shared<IndexBuildParams>();
         build_params->nlist = 1; // flat index
         build_params->metric = "l2";
-        // Use as many workers as hardware concurrency
         build_params->num_workers = N_WORKERS;
-        index_->build(data_, ids_, build_params);
+        index_->build(data_, ids_, build_params, attributes_table_);
     }
 };
 
@@ -90,15 +123,17 @@ class QuakeSerialIVFBenchmark : public ::testing::Test {
     std::shared_ptr<QuakeIndex> index_;
     Tensor data_;
     Tensor ids_;
+    std::shared_ptr<arrow::Table> attributes_table_;
 
     void SetUp() override {
         data_ = generate_data(NUM_VECTORS, DIM);
         ids_ = generate_ids(NUM_VECTORS);
+        attributes_table_ = generate_data_frame(NUM_VECTORS, ids_);
        index_ = std::make_shared<QuakeIndex>();
        auto build_params = std::make_shared<IndexBuildParams>();
        build_params->nlist = N_LIST; // IVF index
        build_params->metric = "l2";
        build_params->niter = 3;
-       index_->build(data_, ids_, build_params);
+       index_->build(data_, ids_, build_params, attributes_table_);
    }
 };
 
@@ -108,16 +143,18 @@ class QuakeWorkerIVFBenchmark : public ::testing::Test {
     std::shared_ptr<QuakeIndex> index_;
     Tensor data_;
     Tensor ids_;
+    std::shared_ptr<arrow::Table> attributes_table_;
 
     void SetUp() override {
         data_ = generate_data(NUM_VECTORS, DIM);
         ids_ = generate_ids(NUM_VECTORS);
+        attributes_table_ = generate_data_frame(NUM_VECTORS, ids_);
         index_ = std::make_shared<QuakeIndex>();
         auto build_params = std::make_shared<IndexBuildParams>();
         build_params->nlist = N_LIST; // IVF index
         build_params->metric = "l2";
         build_params->niter = 3;
         build_params->num_workers = N_WORKERS;
-        index_->build(data_, ids_, build_params);
+        index_->build(data_, ids_, build_params, attributes_table_);
     }
 };
 
diff --git a/test/cpp/quake_index.cpp b/test/cpp/quake_index.cpp
index f6cd7bc6..3794911e 100644
--- a/test/cpp/quake_index.cpp
+++ b/test/cpp/quake_index.cpp
@@ -7,6 +7,13 @@
 #include
 #include "quake_index.h"
 #include
+#include <arrow/api.h>
+#include <arrow/array.h>
+#include <arrow/builder.h>
+#include <arrow/table.h>
+#include <arrow/type.h>
+#include <arrow/memory_pool.h>
+#include <arrow/compute/api.h>
 
 // Helper functions for random data
 static torch::Tensor generate_random_data(int64_t num_vectors, int64_t dim) {
@@ -17,6 +24,36 @@ static torch::Tensor generate_sequential_ids(int64_t count, int64_t start = 0) {
     return torch::arange(start, start + count, torch::kInt64);
 }
 
+static std::shared_ptr<arrow::Table> generate_data_frame(int64_t num_vectors, torch::Tensor ids) {
+    arrow::MemoryPool* pool = arrow::default_memory_pool();
+
+    // Builders for the "price" and "id" columns
+    arrow::DoubleBuilder price_builder(pool);
+    arrow::Int64Builder id_builder(pool);
+
+    // Append values to the builders
+    for (int64_t i = 0; i < num_vectors; i++) {
+        price_builder.Append(static_cast<double>(i) * 1.5); // Price column
+        id_builder.Append(ids[i].item<int64_t>());          // ID column from the input tensor
+    }
+
+    // Finalize the arrays
+    std::shared_ptr<arrow::Array> price_array;
+    std::shared_ptr<arrow::Array> id_array;
+    price_builder.Finish(&price_array);
+    id_builder.Finish(&id_array);
+
+    // Define the schema with two fields: "id" and "price"
+    std::vector<std::shared_ptr<arrow::Field>> schema_vector = {
+        arrow::field("id", arrow::int64()),
+        arrow::field("price", arrow::float64()),
+    };
+    auto schema = std::make_shared<arrow::Schema>(schema_vector);
+
+    // Create and return the table with both columns
+    return arrow::Table::Make(schema, {id_array, price_array});
+}
+
 class QuakeIndexTest : public ::testing::Test {
 protected:
     // Example parameters
@@ -32,6 +69,9 @@ class QuakeIndexTest : public ::testing::Test {
     // Query vectors
     torch::Tensor query_vectors_;
 
+    // Arrow data
+    std::shared_ptr<arrow::Table> attributes_table;
+
     void SetUp() override {
         // Generate random data
         data_vectors_ = generate_random_data(num_vectors_, dimension_);
@@ -40,6 +80,9 @@ class QuakeIndexTest : public ::testing::Test {
 
         // Queries
         query_vectors_ = generate_random_data(num_queries_, dimension_);
+
+        // Arrow data
+        attributes_table = generate_data_frame(num_vectors_, data_ids_);
     }
 };
 
@@ -64,7 +107,7 @@ TEST_F(QuakeIndexTest, BuildTest) {
     build_params->metric = "l2";
     build_params->niter = 5; // small kmeans iteration
 
-    auto timing_info = index.build(data_vectors_, data_ids_, build_params);
+    auto timing_info = index.build(data_vectors_, data_ids_, build_params, attributes_table);
 
     // Check that we created partition_manager_, parent_, etc.
     EXPECT_NE(index.partition_manager_, nullptr);
@@ -106,7 +149,7 @@ TEST_F(QuakeIndexTest, SearchPartitionedTest) {
     auto build_params = std::make_shared<IndexBuildParams>();
     build_params->nlist = nlist_;
     build_params->metric = "l2";
-    index.build(data_vectors_, data_ids_, build_params);
+    index.build(data_vectors_, data_ids_, build_params, attributes_table);
 
     // Create a search_params object (if you need special fields, set them up)
     auto search_params = std::make_shared<SearchParams>();
@@ -133,7 +176,7 @@ TEST_F(QuakeIndexTest, SearchFlatTest) {
     // Build
     auto build_params = std::make_shared<IndexBuildParams>();
     build_params->metric = "l2";
-    index.build(data_vectors_, data_ids_, build_params);
+    index.build(data_vectors_, data_ids_, build_params, attributes_table);
 
     // Create a search_params object (if you need special fields, set them up)
     auto search_params = std::make_shared<SearchParams>();
@@ -185,8 +228,9 @@ TEST_F(QuakeIndexTest, AddTest) {
 
     Tensor add_vectors = generate_random_data(10, dimension_);
     Tensor add_ids = generate_sequential_ids(10, 1000);
+    auto attr_table = generate_data_frame(10, add_ids);
 
-    auto modify_info = index.add(add_vectors, add_ids);
+    auto modify_info = index.add(add_vectors, add_ids, attr_table);
     EXPECT_EQ(modify_info->n_vectors, 10);
     EXPECT_GE(modify_info->modify_time_us, 0);
 }
@@ -198,7 +242,7 @@ TEST_F(QuakeIndexTest, RemoveTest) {
     // Build
     auto build_params = std::make_shared<IndexBuildParams>();
     build_params->nlist = nlist_;
-    index.build(data_vectors_, data_ids_, build_params);
+    index.build(data_vectors_, data_ids_, build_params, attributes_table);
 
     // remove half of them
     int64_t remove_count = num_vectors_ / 2;
@@ -260,6 +304,7 @@ TEST(QuakeIndexStressTest, LargeBuildTest) {
     int64_t num_vectors = 1e6; // 1 million vectors
     auto data_vectors = generate_random_data(num_vectors, dimension);
     auto data_ids = generate_sequential_ids(num_vectors, 0);
+    auto data_frames = generate_data_frame(num_vectors, data_ids);
 
     QuakeIndex index;
 
@@ -270,7 +315,7 @@ TEST(QuakeIndexStressTest, LargeBuildTest) {
     build_params->niter = 5;
 
     auto t0 = std::chrono::high_resolution_clock::now();
-    auto timing_info = index.build(data_vectors, data_ids, build_params);
+    auto timing_info = index.build(data_vectors, data_ids, build_params, data_frames);
     auto t1 = std::chrono::high_resolution_clock::now();
 
     // Check that the build completed and that we didn't crash
@@ -297,6 +342,7 @@ TEST(QuakeIndexStressTest, RepeatedBuildSearchTest) {
     // Pre-generate data
     auto data_vectors = generate_random_data(num_vectors, dimension);
     auto data_ids = generate_sequential_ids(num_vectors, 1000);
+    auto data_frames = generate_data_frame(num_vectors, data_ids);
     auto query_vectors = generate_random_data(num_queries, dimension);
 
     for (int i = 0; i < iteration_count; i++) {
@@ -307,7 +353,7 @@ TEST(QuakeIndexStressTest, RepeatedBuildSearchTest) {
         build_params->niter = 3;
 
         // Build index
-        index.build(data_vectors, data_ids, build_params);
+        index.build(data_vectors, data_ids, build_params, data_frames);
 
         // Query index
         auto search_params = std::make_shared<SearchParams>();
@@ -420,6 +466,7 @@ TEST(QuakeIndexStressTest, HighDimensionTest) {
     int64_t num_vectors = 5000;
     auto data_vectors = generate_random_data(num_vectors, dimension);
     auto data_ids = generate_sequential_ids(num_vectors);
+    auto data_frames = generate_data_frame(num_vectors, data_ids);
 
     QuakeIndex index;
     auto build_params = std::make_shared<IndexBuildParams>();
@@ -429,7 +476,7 @@ TEST(QuakeIndexStressTest, HighDimensionTest) {
     build_params->niter = 3;
 
     // If your system doesn’t have enough memory for bigger tests, reduce num_vectors or dimension.
-    auto timing_info = index.build(data_vectors, data_ids, build_params);
+    auto timing_info = index.build(data_vectors, data_ids, build_params, data_frames);
 
     ASSERT_NE(timing_info, nullptr);
     EXPECT_EQ(index.ntotal(), num_vectors);
@@ -491,4 +538,4 @@ TEST(QuakeIndexStressTest, SearchAddRemoveMaintenanceTest) {
     }
 
     SUCCEED();
-}
\ No newline at end of file
+}
diff --git a/test/cpp/query_coordinator.cpp b/test/cpp/query_coordinator.cpp
index d272f2d0..33a468da 100644
--- a/test/cpp/query_coordinator.cpp
+++ b/test/cpp/query_coordinator.cpp
@@ -27,6 +27,45 @@ class QueryCoordinatorTest : public ::testing::Test {
     std::shared_ptr<PartitionManager> partition_manager_;
     MetricType metric_ = faiss::METRIC_L2;
 
+    static torch::Tensor generate_random_data(int64_t num_vectors, int64_t dim) {
+        return torch::randn({num_vectors, dim}, torch::kFloat32);
+    }
+
+    static torch::Tensor generate_sequential_ids(int64_t count, int64_t start = 0) {
+        return torch::arange(start, start + count, torch::kInt64);
+    }
+
+    static std::shared_ptr<arrow::Table> generate_data_frame(int64_t num_vectors, torch::Tensor ids) {
+        arrow::MemoryPool* pool = arrow::default_memory_pool();
+
+        // Builders for the "price" and "id" columns
+        arrow::DoubleBuilder price_builder(pool);
+        arrow::Int64Builder id_builder(pool);
+
+        // Append values to the builders
+        for (int64_t i = 0; i < num_vectors; i++) {
+            price_builder.Append(static_cast<double>(i)); // Price column
+            id_builder.Append(ids[i].item<int64_t>());    // ID column from the input tensor
+        }
+
+        // Finalize the arrays
+        std::shared_ptr<arrow::Array> price_array;
+        std::shared_ptr<arrow::Array> id_array;
+        price_builder.Finish(&price_array);
+        id_builder.Finish(&id_array);
+
+        // Define the schema with two fields: "id" and "price"
+        std::vector<std::shared_ptr<arrow::Field>> schema_vector = {
+            arrow::field("id", arrow::int64()),
+            arrow::field("price", arrow::float64()),
+        };
+        auto schema = std::make_shared<arrow::Schema>(schema_vector);
+
+        // Create and return the table with both columns
+        return arrow::Table::Make(schema, {id_array, price_array});
+    }
+
     void SetUp() override {
         // Create dummy vectors and IDs
@@ -169,6 +208,66 @@ TEST_F(QueryCoordinatorTest, WorkerInitializationTest) {
     ASSERT_TRUE(coordinator->workers_initialized_);
 }
 
+TEST_F(QueryCoordinatorTest, PreFilteringTest) {
+    auto index = std::make_shared<QuakeIndex>();
+    auto build_params = std::make_shared<IndexBuildParams>();
+    build_params->nlist = 1;
+    build_params->metric = "l2";
+    int64_t num_vectors = 10;
+    auto data_vectors = generate_random_data(num_vectors, dimension_);
+    auto data_ids = generate_sequential_ids(num_vectors, 0);
+    auto attributes_table = generate_data_frame(num_vectors, data_ids);
+    index->build(data_vectors, data_ids, build_params, attributes_table);
+    auto coordinator = std::make_shared<QueryCoordinator>(
+        index->parent_,
+        index->partition_manager_,
+        nullptr,
+        faiss::METRIC_L2
+    );
+    auto search_params = std::make_shared<SearchParams>();
+    search_params->k = 2;
+    search_params->price_threshold = 1;
+    search_params->filtering_type = FilteringType::PRE_FILTERING;
+    auto result_worker = coordinator->search(torch::randn({1, dimension_}, torch::kFloat32), search_params);
+    // Only ids 0 and 1 have price <= 1, so they are the only admissible results.
+    vector<int64_t> expected_result = {0, 1};
+    ASSERT_TRUE(result_worker != nullptr);
+    ASSERT_EQ(result_worker->ids.sizes(), (std::vector<int64_t>{1, 2}));
+    ASSERT_EQ(result_worker->distances.sizes(), (std::vector<int64_t>{1, 2}));
+    std::vector<int64_t> result_worker_vector(result_worker->ids.data_ptr<int64_t>(),
+                                              result_worker->ids.data_ptr<int64_t>() + result_worker->ids.numel());
+    sort(result_worker_vector.begin(), result_worker_vector.end());
+    ASSERT_EQ(expected_result, result_worker_vector);
+}
+
+TEST_F(QueryCoordinatorTest, PostFilteringTest) {
+    auto index = std::make_shared<QuakeIndex>();
+    auto build_params = std::make_shared<IndexBuildParams>();
+    build_params->nlist = 1;
+    build_params->metric = "l2";
+    int64_t num_vectors = 10;
+    auto data_vectors = generate_random_data(num_vectors, dimension_);
+    auto data_ids = generate_sequential_ids(num_vectors, 0);
+    auto attributes_table = generate_data_frame(num_vectors, data_ids);
+    index->build(data_vectors, data_ids, build_params, attributes_table);
+    auto coordinator = std::make_shared<QueryCoordinator>(
+        index->parent_,
+        index->partition_manager_,
+        nullptr,
+        faiss::METRIC_L2
+    );
+    auto search_params = std::make_shared<SearchParams>();
+    search_params->k = 2;
+    search_params->price_threshold = 1;
+    search_params->filtering_type = FilteringType::POST_FILTERING;
+    auto result_worker = coordinator->search(torch::randn({1, dimension_}, torch::kFloat32), search_params);
+    // After post-filtering, only ids 0 and 1 (price <= 1) may remain.
+    vector<int64_t> expected_result = {0, 1};
+    ASSERT_TRUE(result_worker != nullptr);
+    ASSERT_EQ(result_worker->ids.sizes(), (std::vector<int64_t>{1, 2}));
+    ASSERT_EQ(result_worker->distances.sizes(), (std::vector<int64_t>{1, 2}));
+    std::vector<int64_t> result_worker_vector(result_worker->ids.data_ptr<int64_t>(),
+                                              result_worker->ids.data_ptr<int64_t>() + result_worker->ids.numel());
+    sort(result_worker_vector.begin(), result_worker_vector.end());
+    ASSERT_EQ(expected_result, result_worker_vector);
+}
+
 TEST_F(QueryCoordinatorTest, FlatWorkerScan) {
     int num_workers = 4;