From b68d63e01b7104d614f5167dc16431038d826761 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Fri, 21 Jul 2023 12:18:33 -0400 Subject: [PATCH 01/28] Begin refactoring RAFT CMake configuration --- CMakeLists.txt | 19 +++++++++++- cmake/raft.cmake | 81 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 cmake/raft.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index b1661249d..c1b226b18 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,9 @@ -cmake_minimum_required(VERSION 3.10) +option(USE_CUDA "Build Cuda code" On) +if(USE_CUDA) + cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) +else() + cmake_minimum_required(VERSION 3.10) +endif() cmake_policy(SET CMP0077 NEW) set(CMAKE_CXX_STANDARD 20) @@ -24,6 +29,18 @@ include(cmake/san.cmake) # ---------------------------------------------------------------------------------------------- project(VectorSimilarity) +if (USE_CUDA) + # Enable CUDA compilation for this project + enable_language(CUDA) + # Add definition for conditional compilation of CUDA components + add_definitions(-DUSE_CUDA) + # Perform all RAFT-specific CMake setup + include(cmake/raft.cmake) + # Required flags for compiling RAFT + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -std=c++17") +endif() + # Only do these if this is the main project, and not if it is included through add_subdirectory set_property(GLOBAL PROPERTY USE_FOLDERS ON) diff --git a/cmake/raft.cmake b/cmake/raft.cmake new file mode 100644 index 000000000..152689484 --- /dev/null +++ b/cmake/raft.cmake @@ -0,0 +1,81 @@ +if(USE_CUDA) + # Set which version of RAPIDS to use + set(RAPIDS_VERSION 23.08) + # Set which version of RAFT to use (defined separately for testing + # minimal dependency changes if necessary) + set(RAFT_VERSION "${RAPIDS_VERSION}") + set(RAFT_FORK "rapidsai") + set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + + # Download CMake file for bootstrapping RAPIDS-CMake, a utility that + # simplifies handling of complex RAPIDS dependencies + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/_RAPIDS.cmake) + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/_RAPIDS.cmake + ${CMAKE_CURRENT_BINARY_DIR}/_RAPIDS.cmake) + endif() + include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) + + # General tool for orchestrating RAPIDS dependencies + include(rapids-cmake) + # CPM helper functions with dependency tracking + include(rapids-cpm) + rapids_cpm_init() + # Common CMake CUDA logic + include(rapids-cuda) + # Include required dependencies in Project-Config.cmake modules + # include(rapids-export) TODO(wphicks) + # Functions to find system dependencies with dependency tracking + include(rapids-find) + + # Correctly handle supported CUDA architectures + # (From rapids-cuda) + rapids_cuda_init_architectures(VectorSimilarity) + + # Find system CUDA toolkit + rapids_find_package(CUDAToolkit REQUIRED) + + set(RAFT_VERSION "${RAPIDS_VERSION}") + set(RAFT_FORK "rapidsai") + set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + + function(find_and_configure_raft) + set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + set(RAFT_COMPONENTS "") + if(PKG_COMPILE_LIBRARY) + string(APPEND RAFT_COMPONENTS " compiled") + endif() + # Invoke CPM find_package() + # (From rapids-cpm) + rapids_cpm_find(raft ${PKG_VERSION} + GLOBAL_TARGETS raft::raft + BUILD_EXPORT_SET VectorSimilarity-exports + INSTALL_EXPORT_SET VectorSimilarity-exports + COMPONENTS ${RAFT_COMPONENTS} + CPM_ARGS + GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git + GIT_TAG ${PKG_PINNED_TAG} + SOURCE_SUBDIR cpp + OPTIONS + "BUILD_TESTS OFF" + "BUILD_BENCH OFF" + "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" + ) + if(raft_ADDED) + message(VERBOSE "VectorSimilarity: Using RAFT located in ${raft_SOURCE_DIR}") + else() + message(VERBOSE "VectorSimilarity: Using RAFT located in ${raft_DIR}") + endif() + endfunction() + + # Change pinned tag here to test a commit in CI + # To use a different RAFT locally, set the CMake variable + # CPM_raft_SOURCE=/path/to/local/raft + find_and_configure_raft(VERSION ${RAFT_VERSION}.00 + FORK ${RAFT_FORK} + PINNED_TAG ${RAFT_PINNED_TAG} + COMPILE_LIBRARY OFF + ) +endif() From 10093d37cb7eaaafa1920213d95c0865a443e28a Mon Sep 17 00:00:00 2001 From: William Hicks Date: Fri, 21 Jul 2023 17:54:03 -0400 Subject: [PATCH 02/28] Correct RAFT CMake configuration --- CMakeLists.txt | 2 +- check-format.sh | 4 ++-- cmake/raft.cmake | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c1b226b18..0e7588656 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,7 +44,7 @@ endif() # Only do these if this is the main project, and not if it is included through add_subdirectory set_property(GLOBAL PROPERTY USE_FOLDERS ON) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions -fPIC ${CLANG_SAN_FLAGS} ${LLVM_CXX_FLAGS} ${COV_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions -fPIC -pthread ${CLANG_SAN_FLAGS} ${LLVM_CXX_FLAGS} ${COV_CXX_FLAGS}") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_LINKER_FLAGS} ${LLVM_LD_FLAGS}") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_LINKER_FLAGS} ${LLVM_LD_FLAGS}") diff --git a/check-format.sh b/check-format.sh index b474077c4..762fd519e 100755 --- a/check-format.sh +++ b/check-format.sh @@ -1,6 +1,6 @@ #!/bin/bash -CLANG_FMT_SRCS=$(find ./src/ \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hh' -o -name '*.hpp' \)) -CLANG_FMT_TESTS="$(find ./tests/ -type d \( -path ./tests/unit/build \) -prune -false -o \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hh' -o -name '*.hpp' \))" +CLANG_FMT_SRCS=$(find ./src/ \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hh' -o -name '*.hpp' -o -name '*.cuh' -o -name '*.cu' \)) +CLANG_FMT_TESTS="$(find ./tests/ -type d \( -path ./tests/unit/build \) -prune -false -o \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hh' -o -name '*.hpp' -o -name '*.cuh' -o -name '*.cu' \))" E=0 for filename in $CLANG_FMT_SRCS $CLANG_FMT_TESTS; do diff --git a/cmake/raft.cmake b/cmake/raft.cmake index 152689484..7268fb6e4 100644 --- a/cmake/raft.cmake +++ b/cmake/raft.cmake @@ -1,6 +1,6 @@ if(USE_CUDA) # Set which version of RAPIDS to use - set(RAPIDS_VERSION 23.08) + set(RAPIDS_VERSION 23.06) # Set which version of RAFT to use (defined separately for testing # minimal dependency changes if necessary) set(RAFT_VERSION "${RAPIDS_VERSION}") @@ -9,11 +9,11 @@ if(USE_CUDA) # Download CMake file for bootstrapping RAPIDS-CMake, a utility that # simplifies handling of complex RAPIDS dependencies - if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/_RAPIDS.cmake) - file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/_RAPIDS.cmake - ${CMAKE_CURRENT_BINARY_DIR}/_RAPIDS.cmake) + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake + ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) endif() - include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) + include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) # General tool for orchestrating RAPIDS dependencies include(rapids-cmake) From 77a497103384094214bdad2447f4f5881bf38eab Mon Sep 17 00:00:00 2001 From: William Hicks Date: Wed, 26 Jul 2023 16:03:32 -0400 Subject: [PATCH 03/28] Conditionally link to RAFT --- src/VecSim/CMakeLists.txt | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/VecSim/CMakeLists.txt b/src/VecSim/CMakeLists.txt index 4998e17be..c778498a6 100644 --- a/src/VecSim/CMakeLists.txt +++ b/src/VecSim/CMakeLists.txt @@ -32,9 +32,15 @@ add_library(VectorSimilarity ${VECSIM_LIBTYPE} ${HEADER_LIST} ) -target_link_libraries(VectorSimilarity VectorSimilaritySpaces) if(VECSIM_BUILD_TESTS) add_library(VectorSimilaritySerializer utils/serializer.cpp) - target_link_libraries(VectorSimilarity VectorSimilaritySerializer) endif() + +target_link_libraries(VectorSimilarity +PUBLIC + VectorSimilaritySpaces + $<$:VectorSimilaritySerializer> +PRIVATE + $<$:raft::raft> +) From a1aa36688d199d0be52286f51a562363d804f555 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Wed, 26 Jul 2023 17:49:30 -0400 Subject: [PATCH 04/28] Add configuration structs for RAFT indexes --- src/VecSim/vec_sim_common.h | 46 +++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index fd30df1e6..8be956af2 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -42,6 +42,13 @@ typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_TIERED } VecSimAlgo // Distance metric typedef enum { VecSimMetric_L2, VecSimMetric_IP, VecSimMetric_Cosine } VecSimMetric; +// Codebook kind for IVFPQ indexes +typedef enum { IVFPQCodebookKind_PerCluster, IVFPQCodebookKind_PerSubspace } IVFPQCodebookKind; + +// CUDA types supported by GPU-accelerated indexes +typedef enum { CUDAType_R_32F, CUDAType_R_16F, CUDAType_R_8U } CudaType; + + typedef size_t labelType; typedef unsigned int idType; @@ -130,9 +137,42 @@ typedef struct { } specificParams; } TieredIndexParams; +typedef struct { + VecSimType type; // Datatype to index. + size_t dim; // Vector's dimension. + VecSimMetric metric; // Distance metric to use in the index. + bool multi; // Determines if the index should multi-index or not. + size_t nLists; // Number of inverted lists + bool adaptiveCenters; // If index should be updated for new vectors + bool conservativeMemoryAllocation; // Use as little GPU memory as possible + size_t kmeans_nIters; // Iterations for kmeans calculation + float kmeans_trainsetFraction; // Fraction of dataset used for kmeans training + unsigned nProbes; // The number of clusters to search + size_t pqDim; // The dimensionality of an encoded vector after PQ + // compression. If set to 0, IVF flat will be used + // instead of IVFPQ. + // + // ******************* IVFPQ-only parameters ******************* + // The following parameters will be ignored if pqDim is set to 0 + + size_t pqBits; // Bit length of vector element after PQ compression + IVFPQCodebookKind codebookKind; + CudaType lutType; + CudaType internalDistanceType; + double preferredShmemCarvout; // Fraction of GPU's unified memory / L1 + // cache to be used as shared memory + +} IVFParams; + +typedef struct { + IVFParams ivfParams; + TieredIndexParams tieredParams; +} TieredIVFParams; + typedef union { HNSWParams hnswParams; BFParams bfParams; + IVFParams ivfParams; TieredIndexParams tieredParams; } AlgoParams; @@ -232,6 +272,12 @@ typedef struct { char dummy; // For not having this as an empty struct, can be removed after we extend this. } bfInfoStruct; +typedef struct { + size_t nLists; // Number of inverted lists. + size_t pqDim; // Dimensionality of encoded vector after PQ + size_t pqBits; // Bits per encoded vector element after PQ +} ivfInfoStruct; + typedef struct HnswTieredInfo { size_t pendingSwapJobsThreshold; } HnswTieredInfo; From bcef4c15706b6e6d5f516116caa4fef3e3941179 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Mon, 31 Jul 2023 13:10:56 -0400 Subject: [PATCH 05/28] Add IVF index headers --- src/VecSim/algorithms/ivf/ivf.cuh | 342 +++++++++++++++++++++++ src/VecSim/algorithms/ivf/ivf_tiered.cuh | 4 + src/VecSim/tombstone_interface.h | 1 + src/VecSim/vec_sim_common.h | 22 +- 4 files changed, 360 insertions(+), 9 deletions(-) create mode 100644 src/VecSim/algorithms/ivf/ivf.cuh create mode 100644 src/VecSim/algorithms/ivf/ivf_tiered.cuh diff --git a/src/VecSim/algorithms/ivf/ivf.cuh b/src/VecSim/algorithms/ivf/ivf.cuh new file mode 100644 index 000000000..25c53ba60 --- /dev/null +++ b/src/VecSim/algorithms/ivf/ivf.cuh @@ -0,0 +1,342 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include "VecSim/vec_sim.h" +// For VecSimMetric, IVFParams, labelType +#include "VecSim/vec_sim_common.h" +// For VecSimIndexAbstract +#include "VecSim/vec_sim_index.h" +#include "VecSim/query_result_struct.h" +#include "VecSim/memory/vecsim_malloc.h" + +#include +#include +#include +#include +#include +#include +#include + +#pragma once + +inline auto constexpr GetRaftDistanceType(VecSimMetric vsm) { + auto result = raft::distance::DistanceType{}; + switch (vsm) { + case VecSimMetric_L2: + result = raft::distance::DistanceType::L2Expanded; + break; + case VecSimMetric_IP: + result = raft::distance::DistanceType::InnerProduct; + break; + default: + throw raft::exception("Metric not supported"); + } + return result; +} + +inline auto constexpr GetRaftCodebookKind(IVFPQCodebookKind vss_codebook) { + auto result = raft::neighbors::ivf_pq::codebook_gen{}; + switch(vss_codebook) { + case IVFPQCodebookKind_PerCluster: + result = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; + break; + case IVFPQCodebookKind_PerSubspace: + result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; + break; + default: + throw raft::exception("Unexpected IVFPQ codebook kind"); + } + return result; +} + +inline auto constexpr GetCudaType(CudaType vss_type) { + auto result = cudaDataType_t{}; + switch(vss_type) { + case CUDAType_R_32F: + result = CUDA_R_32F; + break; + case CUDAType_R_16F: + result = CUDA_R_16F; + break; + case CUDAType_R_8U: + result = CUDA_R_8U; + break; + default: + throw raft::exception("Unexpected CUDA type"); + } + return result; +} + +template +struct IVFIndex : public VecSimIndexAbstract { + using data_type = DataType; + using dist_type = DistType; + +private: + // Allow either IVF-flat or IVFPQ parameters + using build_params_t = std::variant; + using search_params_t = std::variant; + using internal_idx_t = std::uint32_t; + using ann_index_t = std::variant, raft::neighbors::ivf_pq::index>; + +public: + IVFIndex(const IVFParams *ivfParams, const AbstractIndexInitParams & commonParams) + : VecSimIndexAbstract{commonParams}, + res_{}, //TODO(wphicks): Construct smartly + build_params_{[ivfParams](){ + auto result = ivfParams->pqBits > 0 ? + build_params_t{std::in_place_index<1>} : + build_params_t{std::in_place_index<0>}; + std::visit( + [ivfParams](auto&& inner) { + inner.metric = GetRaftDistanceType(ivfParams->metric); + inner.n_lists = ivfParams->nLists; + inner.kmeans_n_iters = ivfParams->kmeans_nIters; + inner.kmeans_trainset_fraction = ivfParams->kmeans_trainsetFraction; + inner.conservative_memory_allocation = ivfParams->conservativeMemoryAllocation; + if constexpr (std::is_same_v) { + inner.pq_bits = ivfParams->pqBits; + inner.pq_dim = ivfParams->pqDim; + inner.codebook_kind = GetRaftCodebookKind(ivfParams->codebookKind); + } else { + inner.adaptive_centers = ivfParams->adaptiveCenters; + } + }, result + ); + return result; + }()}, + search_params_{[ivfParams](){ + auto result = ivfParams->pqBits > 0 ? + search_params_t{std::in_place_index<1>} : + search_params_t{std::in_place_index<0>}; + std::visit( + [ivfParams](auto&& inner) { + inner.n_probes = ivfParams->nProbes; + if constexpr (std::is_same_v) { + inner.lut_dtype = GetCudaType(ivfParams->lutType); + inner.internal_distance_dtype = GetCudaType(ivfParams->internalDistanceType); + inner.preferred_shmem_carvout = ivfParams->preferredShmemCarveout; + } + }, result + ); + return result; + }()}, + index_{std::nullopt} + {} + auto addVector(const void *vector_data, labelType label, + bool overwrite_allowed = true) override { + return addVectorBatch(vector_data, &label, 1, overwrite_allowed); + } + auto addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, + bool overwrite_allowed = true) { + // Allocate memory on device to hold vectors to be added + auto vector_data_gpu = + raft::make_device_matrix(res_, batch_size, + this->dim); + // Allocate memory on device to hold vector labels + auto label_gpu = raft::make_device_vector( + res_, + batch_size + ); + + // Copy vector data to previously allocated device buffer + raft::copy( + vector_data_gpu.data_handle(), + static_cast(vector_data), + this->dim * batch_size, + res_.get_stream() + ); + // Copy label data to previously allocated device buffer + raft::copy( + label_gpu.data_handle(), + label, + batch_size, + res_.get_stream() + ); + + if (std::holds_alternative(build_params_)) { + if (!index_) { + index_ = raft::neighbors::ivf_flat::build( + res_, + std::get(build_params_), + vector_data_gpu.view() + ); + } + raft::neighbors::ivf_flat::extend( + res_, + vector_data_gpu.view(), + label_gpu, + *index_ + ); + } else { + if (!index_) { + index_ = raft::neighbors::ivf_pq::build( + res_, + std::get(build_params_), + vector_data_gpu.view() + ); + } + raft::neighbors::ivf_pq::extend( + res_, + vector_data_gpu.view(), + label_gpu, + *index_ + ); + } + + // Ensure that above operations have executed on device before + // returning from this function on host + res_.sync_stream(); + return batch_size; + } + auto deleteVector(labelType label) override { + assert(!"deleteVector not implemented"); + return 0; + } + double getDistanceFrom(labelType label, const void *vector_data) const override { + assert(!"getDistanceFrom not implemented"); + return INVALID_SCORE; + } + size_t indexCapacity() const override { + assert(!"indexCapacity not implemented"); + return 0; + } + void increaseCapacity() override { + assert(!"increaseCapacity not implemented"); + } + inline auto indexLabelCount() const override { + return this->indexSize(); // TODO: Return unique counts + } + auto topKQuery(const void *queryBlob, size_t k, + VecSimQueryParams *queryParams) override { + auto result_list = VecSimQueryResult_List{0}; + auto nVectors = this->indexSize(); + if (nVectors == 0) { + result_list.results = array_new(0); + } else { + // Ensure we are not trying to retrieve more vectors than exist in the + // index + k = std::min(k, nVectors); + // Allocate memory on device for search vector + auto vector_data_gpu = raft::make_device_matrix(res_,1, this->dim); + // Allocate memory on device for neighbor results + auto neighbors_gpu = raft::make_device_vector(res_, k); + // Allocate memory on device for distance results + auto distances_gpu = raft::make_device_vector(res_, k); + // Copy query vector to device + raft::copy( + vector_data_gpu.data_handle(), + static_cast(queryBlob), + this->dim, + res_.get_stream() + ); + + // Perform correct search based on index type + if ( + std::holds_alternative(index_) + ) { + raft::neighbors::ivf_flat::search( + res_, + std::get(search_params_), + std::get(*index_), + vector_data_gpu.view(), + neighbors_gpu.view(), + distances_gpu.view() + ) + } else { + raft::neighbors::ivf_pq::search( + res_, + std::get(search_params_), + std::get(*index_), + vector_data_gpu.view(), + neighbors_gpu.view(), + distances_gpu.view() + ) + } + + // Allocate host buffers to hold returned results + auto neighbors = std::unique_ptr( + array_new_len(k, k), + &array_free + ); + auto distances = std::unique_ptr( + array_new_len(k, k), + &array_free + ); + // Copy data back from device to host + raft::copy( + neighbors.get(), + neighbors_gpu.data_handle(), + this->dim, + res_.get_stream() + ); + raft::copy( + distances.get(), + distances_gpu.data_handle(), + this->dim, + res_.get_stream() + ); + + result_list.results = array_new_len(k, k); + + // Ensure search is complete and data have been copied back before + // building query result objects on host + res_.sync_stream(); + for (size_t i = 0; i < k; ++i) { + VecSimQueryResult_SetId(result_list.results[i], neighbors[i]); + VecSimQueryResult_SetScore(result_list.results[i], distances[i]); + } + } + return result_list; + } + + VecSimQueryResult_List rangeQuery(const void *queryBlob, double radius, + VecSimQueryParams *queryParams) override { + assert(!"RangeQuery not implemented"); + } + VecSimInfoIterator *infoIterator() const override { assert(!"infoIterator not implemented"); } + virtual VecSimBatchIterator *newBatchIterator(const void *queryBlob, + VecSimQueryParams *queryParams) const override { + assert(!"newBatchIterator not implemented"); + } + bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) override { + assert(!"preferAdHocSearch not implemented"); + } + + auto& get_resources() const { return res_; } + + auto nLists() { + return std::visit([](auto&& params) { + return params.n_list; + }, build_params_); + } + + auto indexSize() { + auto result = size_t{}; + if (index_) { + result = std::visit([](auto&& index) { + return index.size(); + }, *index_); + } + return result; + } + + private: + // An object used to manage common device resources that may be + // expensive to build but frequently accessed + raft::device_resources res_; + // Store build params to allow for index build on first batch + // insertion + build_params_t build_params_; + // Store search params to use with each search after initializing in + // constructor + search_params_t search_params_; + // Use a std::optional to allow building of the index on first batch + // insertion + std::optional index_; +}; diff --git a/src/VecSim/algorithms/ivf/ivf_tiered.cuh b/src/VecSim/algorithms/ivf/ivf_tiered.cuh new file mode 100644 index 000000000..03e86f37f --- /dev/null +++ b/src/VecSim/algorithms/ivf/ivf_tiered.cuh @@ -0,0 +1,4 @@ +#include "VecSim/vec_sim_tiered_index.h" +template +struct TieredIVFIndex : public VecSimTieredIndex { +}; diff --git a/src/VecSim/tombstone_interface.h b/src/VecSim/tombstone_interface.h index 0d864fa7b..ff4c43c0d 100644 --- a/src/VecSim/tombstone_interface.h +++ b/src/VecSim/tombstone_interface.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include "vec_sim_common.h" /* diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index 8be956af2..032e6d98b 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -143,23 +143,27 @@ typedef struct { VecSimMetric metric; // Distance metric to use in the index. bool multi; // Determines if the index should multi-index or not. size_t nLists; // Number of inverted lists - bool adaptiveCenters; // If index should be updated for new vectors bool conservativeMemoryAllocation; // Use as little GPU memory as possible size_t kmeans_nIters; // Iterations for kmeans calculation - float kmeans_trainsetFraction; // Fraction of dataset used for kmeans training + double kmeans_trainsetFraction; // Fraction of dataset used for kmeans training unsigned nProbes; // The number of clusters to search - size_t pqDim; // The dimensionality of an encoded vector after PQ - // compression. If set to 0, IVF flat will be used - // instead of IVFPQ. - // + size_t pqBits; // Bit length of vector element after PQ compression. If set + // to 0, IVF flat will be used instead of IVFPQ. + // ***************** IVF-Flat-only parameters ****************** + // The following parameters will be ignored if pqBits is set to a + // non-zero value. + bool adaptiveCenters; // If index should be updated for new vectors + // ******************* IVFPQ-only parameters ******************* - // The following parameters will be ignored if pqDim is set to 0 + // The following parameters will be ignored if pqBits is set to 0 + size_t pqDim; // The dimensionality of an encoded vector after PQ + // compression. If set to 0, a heuristic will be used to + // select the dimensionality. - size_t pqBits; // Bit length of vector element after PQ compression IVFPQCodebookKind codebookKind; CudaType lutType; CudaType internalDistanceType; - double preferredShmemCarvout; // Fraction of GPU's unified memory / L1 + double preferredShmemCarveout; // Fraction of GPU's unified memory / L1 // cache to be used as shared memory } IVFParams; From 2ebcf932257e6e1b56ba053ab68f4e1a3277a3a3 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Fri, 4 Aug 2023 17:30:19 -0400 Subject: [PATCH 06/28] Provide initial update of tiered RAFT index --- CMakeLists.txt | 2 +- cmake/raft.cmake | 10 +- .../algorithms/brute_force/brute_force.h | 5 + src/VecSim/algorithms/ivf/ivf.cuh | 14 ++- src/VecSim/algorithms/ivf/ivf_tiered.cuh | 112 ++++++++++++++++++ src/VecSim/vec_sim_common.h | 1 + 6 files changed, 136 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e7588656..0dd92f683 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ option(USE_CUDA "Build Cuda code" On) if(USE_CUDA) - cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) + cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) else() cmake_minimum_required(VERSION 3.10) endif() diff --git a/cmake/raft.cmake b/cmake/raft.cmake index 7268fb6e4..1b48cb34f 100644 --- a/cmake/raft.cmake +++ b/cmake/raft.cmake @@ -1,11 +1,15 @@ if(USE_CUDA) # Set which version of RAPIDS to use - set(RAPIDS_VERSION 23.06) + set(RAPIDS_VERSION 23.10) # Set which version of RAFT to use (defined separately for testing # minimal dependency changes if necessary) set(RAFT_VERSION "${RAPIDS_VERSION}") - set(RAFT_FORK "rapidsai") - set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + # TODO(wphicks): Reset to main fork and branch after + # https://github.com/rapidsai/raft/pull/1716 has been merged + # set(RAFT_FORK "rapidsai") + set(RAFT_FORK "wphicks") + # set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + set(RAFT_PINNED_TAG "fea-resource_manager") # Download CMake file for bootstrapping RAPIDS-CMake, a utility that # simplifies handling of complex RAPIDS dependencies diff --git a/src/VecSim/algorithms/brute_force/brute_force.h b/src/VecSim/algorithms/brute_force/brute_force.h index 5889bc6f9..5c3747432 100644 --- a/src/VecSim/algorithms/brute_force/brute_force.h +++ b/src/VecSim/algorithms/brute_force/brute_force.h @@ -35,6 +35,11 @@ class BruteForceIndex : public VecSimIndexAbstract { public: BruteForceIndex(const BFParams *params, const AbstractIndexInitParams &abstractInitParams); + void clear() { + idToLabelMapping.clear(); + vectorBlocks.clear(); + count = idType{}; + } size_t indexSize() const override; size_t indexCapacity() const override; vecsim_stl::vector computeBlockScores(const DataBlock &block, const void *queryBlob, diff --git a/src/VecSim/algorithms/ivf/ivf.cuh b/src/VecSim/algorithms/ivf/ivf.cuh index 25c53ba60..cc337e4a9 100644 --- a/src/VecSim/algorithms/ivf/ivf.cuh +++ b/src/VecSim/algorithms/ivf/ivf.cuh @@ -87,7 +87,7 @@ private: public: IVFIndex(const IVFParams *ivfParams, const AbstractIndexInitParams & commonParams) : VecSimIndexAbstract{commonParams}, - res_{}, //TODO(wphicks): Construct smartly + res_{raft::resource_manager::get_device_resources()}, build_params_{[ivfParams](){ auto result = ivfParams->pqBits > 0 ? build_params_t{std::in_place_index<1>} : @@ -132,7 +132,7 @@ public: bool overwrite_allowed = true) override { return addVectorBatch(vector_data, &label, 1, overwrite_allowed); } - auto addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, + auto addVectorBatchAsync(const void *vector_data, labelType *label, size_t batch_size, bool overwrite_allowed = true) { // Allocate memory on device to hold vectors to be added auto vector_data_gpu = @@ -189,10 +189,16 @@ public: ); } - // Ensure that above operations have executed on device before + return batch_size; + } + auto addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, + bool overwrite_allowed = true) { + auto result = addVectorBatchAsync(vector_data, label, batch_size, + overwrite_allowed); + // Ensure that above operation has executed on device before // returning from this function on host res_.sync_stream(); - return batch_size; + return result; } auto deleteVector(labelType label) override { assert(!"deleteVector not implemented"); diff --git a/src/VecSim/algorithms/ivf/ivf_tiered.cuh b/src/VecSim/algorithms/ivf/ivf_tiered.cuh index 03e86f37f..b90c5aa6d 100644 --- a/src/VecSim/algorithms/ivf/ivf_tiered.cuh +++ b/src/VecSim/algorithms/ivf/ivf_tiered.cuh @@ -1,4 +1,116 @@ +#include +#include "VecSim/algorithms/ivf/ivf.cuh" #include "VecSim/vec_sim_tiered_index.h" + +struct RAFTTransferJob : public AsyncJob { + bool overwrite_allowed{true}; + RAFTTransferJob( + std::shared_ptr allocator, bool overwrite_allowed_, + JobCallback insertCb, + VecSimIndex *index_ + ) : AsyncJob{allocator, RAFT_TRANSFER_JOB, insertCb, index_}, + overwrite_allowed{overwrite_allowed_} {} +}; + template struct TieredIVFIndex : public VecSimTieredIndex { + auto addVector( + const void* blob, labelType label, bool overwrite_allowed + ) { + auto frontend_lock = std::scoped_lock(this->flatIndexGuard); + auto result = this->frontendIndex->addVector( + blob, label, overwrite_allowed + ); + if (this->frontendIndex->indexSize() >= this->flatBufferLimit) { + transferToBackend(overwrite_allowed); + } + return result; + } + + auto deleteVector(labelType label) { + // TODO(wphicks) + // If in flatIndex, delete + // If being transferred to backend, wait for transfer + // If in backendIndex, delete + } + + auto indexSize() { + auto frontend_lock = std::scoped_lock(this->flatIndexGuard); + auto backend_lock = std::scoped_lock(this->mainIndexGuard); + return ( + getBackendIndex().indexSize() + this->frontendIndex.indexSize() + ); + } + + auto indexLabelCount() const override { + // TODO(wphicks) Count unique labels between both indexes + } + + auto indexCapacity() const override { + return ( + getBackendIndex().indexCapacity() + + this->flatBufferLimit + ); + } + + void increaseCapacity() override { + getBackendIndex().increaseCapacity(); + } + + auto getDistanceFrom(labelType label, const void* blob) { + auto frontend_lock = std::unique_lock(this->flatIndexGuard); + auto flat_dist = this->frontendIndex->getDistanceFrom(label, blob); + frontend_lock.unlock(); + auto backend_lock = std::scoped_lock(this->mainIndexGuard); + auto raft_dist = getBackendIndex().getDistanceFrom(label, blob); + return std::fmin(flat_dist, raft_dist); + } + + void executeTransferJob(RAFTTransferJob* job) { + transferToBackend(job->overwrite_allowed); + } + + private: + vecsim_stl::unordered_map> labelToTransferJobs; + auto& getBackendIndex() { + return *dynamic_cast*>( + this->backendIndex + ); + } + + void transferToBackend(overwrite_allowed=true) { + auto dim = this->index->getDim(); + auto frontend_lock = std::unique_lock(this->flatIndexGuard); + auto nVectors = this->flatBuffer->indexSize(); + const auto &vectorBlocks = this->flatBuffer->getVectorBlocks(); + auto vectorData = raft::make_host_matrix( + getBackendIndex().get_resources(), + nVectors, + dim + ); + + auto* out = vectorData.data_handle(); + for (auto block_id = 0; block_id < vectorBlocks.size(); ++block_id) { + auto* in_begin = reinterpret_cast( + vectorBlocks[block_id].getElement(0) + ); + auto length = vectorBlocks[block_id].getLength(); + std::copy( + in_begin, + in_begin + length, + out + ); + out += length; + } + + auto backend_lock = std::scoped_lock(this->mainIndexGuard); + this->flatBuffer->clear(); + frontend_lock.unlock(); + getBackendIndex().addVectorBatch( + vectorData.data_handle(), + this->flatBuffer->getLabels(), + nVectors, + overwrite_allowed + ); + } }; diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index 032e6d98b..eab95bacb 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -194,6 +194,7 @@ typedef enum { HNSW_REPAIR_NODE_CONNECTIONS_JOB, HNSW_SEARCH_JOB, HNSW_SWAP_JOB, + RAFT_TRANSFER_JOB, INVALID_JOB // to indicate that finding a JobType >= INVALID_JOB is an error } JobType; From 44d6b15f41c3c204df30306116481fa417b9731e Mon Sep 17 00:00:00 2001 From: William Hicks Date: Fri, 4 Aug 2023 17:37:26 -0400 Subject: [PATCH 07/28] Update style --- .../algorithms/brute_force/brute_force.h | 6 +- src/VecSim/algorithms/ivf/ivf.cuh | 382 ++++++++---------- src/VecSim/algorithms/ivf/ivf_tiered.cuh | 157 +++---- src/VecSim/vec_sim_common.h | 31 +- 4 files changed, 242 insertions(+), 334 deletions(-) diff --git a/src/VecSim/algorithms/brute_force/brute_force.h b/src/VecSim/algorithms/brute_force/brute_force.h index 5c3747432..5c602777d 100644 --- a/src/VecSim/algorithms/brute_force/brute_force.h +++ b/src/VecSim/algorithms/brute_force/brute_force.h @@ -36,9 +36,9 @@ class BruteForceIndex : public VecSimIndexAbstract { BruteForceIndex(const BFParams *params, const AbstractIndexInitParams &abstractInitParams); void clear() { - idToLabelMapping.clear(); - vectorBlocks.clear(); - count = idType{}; + idToLabelMapping.clear(); + vectorBlocks.clear(); + count = idType{}; } size_t indexSize() const override; size_t indexCapacity() const override; diff --git a/src/VecSim/algorithms/ivf/ivf.cuh b/src/VecSim/algorithms/ivf/ivf.cuh index cc337e4a9..31d87c024 100644 --- a/src/VecSim/algorithms/ivf/ivf.cuh +++ b/src/VecSim/algorithms/ivf/ivf.cuh @@ -27,178 +27,153 @@ inline auto constexpr GetRaftDistanceType(VecSimMetric vsm) { auto result = raft::distance::DistanceType{}; switch (vsm) { - case VecSimMetric_L2: + case VecSimMetric_L2: result = raft::distance::DistanceType::L2Expanded; break; - case VecSimMetric_IP: + case VecSimMetric_IP: result = raft::distance::DistanceType::InnerProduct; break; - default: + default: throw raft::exception("Metric not supported"); } return result; } inline auto constexpr GetRaftCodebookKind(IVFPQCodebookKind vss_codebook) { - auto result = raft::neighbors::ivf_pq::codebook_gen{}; - switch(vss_codebook) { - case IVFPQCodebookKind_PerCluster: - result = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; - break; + auto result = raft::neighbors::ivf_pq::codebook_gen{}; + switch (vss_codebook) { + case IVFPQCodebookKind_PerCluster: + result = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; + break; case IVFPQCodebookKind_PerSubspace: - result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; - break; + result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; + break; default: - throw raft::exception("Unexpected IVFPQ codebook kind"); - } - return result; + throw raft::exception("Unexpected IVFPQ codebook kind"); + } + return result; } inline auto constexpr GetCudaType(CudaType vss_type) { - auto result = cudaDataType_t{}; - switch(vss_type) { - case CUDAType_R_32F: - result = CUDA_R_32F; - break; + auto result = cudaDataType_t{}; + switch (vss_type) { + case CUDAType_R_32F: + result = CUDA_R_32F; + break; case CUDAType_R_16F: - result = CUDA_R_16F; - break; + result = CUDA_R_16F; + break; case CUDAType_R_8U: - result = CUDA_R_8U; - break; + result = CUDA_R_8U; + break; default: - throw raft::exception("Unexpected CUDA type"); - } - return result; + throw raft::exception("Unexpected CUDA type"); + } + return result; } -template +template struct IVFIndex : public VecSimIndexAbstract { using data_type = DataType; using dist_type = DistType; private: // Allow either IVF-flat or IVFPQ parameters - using build_params_t = std::variant; - using search_params_t = std::variant; + using build_params_t = std::variant; + using search_params_t = std::variant; using internal_idx_t = std::uint32_t; - using ann_index_t = std::variant, raft::neighbors::ivf_pq::index>; + using ann_index_t = std::variant, + raft::neighbors::ivf_pq::index>; public: - IVFIndex(const IVFParams *ivfParams, const AbstractIndexInitParams & commonParams) + IVFIndex(const IVFParams *ivfParams, const AbstractIndexInitParams &commonParams) : VecSimIndexAbstract{commonParams}, - res_{raft::resource_manager::get_device_resources()}, - build_params_{[ivfParams](){ - auto result = ivfParams->pqBits > 0 ? - build_params_t{std::in_place_index<1>} : - build_params_t{std::in_place_index<0>}; - std::visit( - [ivfParams](auto&& inner) { - inner.metric = GetRaftDistanceType(ivfParams->metric); - inner.n_lists = ivfParams->nLists; - inner.kmeans_n_iters = ivfParams->kmeans_nIters; - inner.kmeans_trainset_fraction = ivfParams->kmeans_trainsetFraction; - inner.conservative_memory_allocation = ivfParams->conservativeMemoryAllocation; - if constexpr (std::is_same_v) { - inner.pq_bits = ivfParams->pqBits; - inner.pq_dim = ivfParams->pqDim; - inner.codebook_kind = GetRaftCodebookKind(ivfParams->codebookKind); - } else { - inner.adaptive_centers = ivfParams->adaptiveCenters; - } - }, result - ); - return result; + res_{raft::resource_manager::get_device_resources()}, build_params_{[ivfParams]() { + auto result = ivfParams->pqBits > 0 ? build_params_t{std::in_place_index<1>} + : build_params_t{std::in_place_index<0>}; + std::visit( + [ivfParams](auto &&inner) { + inner.metric = GetRaftDistanceType(ivfParams->metric); + inner.n_lists = ivfParams->nLists; + inner.kmeans_n_iters = ivfParams->kmeans_nIters; + inner.kmeans_trainset_fraction = ivfParams->kmeans_trainsetFraction; + inner.conservative_memory_allocation = + ivfParams->conservativeMemoryAllocation; + if constexpr (std::is_same_v) { + inner.pq_bits = ivfParams->pqBits; + inner.pq_dim = ivfParams->pqDim; + inner.codebook_kind = GetRaftCodebookKind(ivfParams->codebookKind); + } else { + inner.adaptive_centers = ivfParams->adaptiveCenters; + } + }, + result); + return result; }()}, - search_params_{[ivfParams](){ - auto result = ivfParams->pqBits > 0 ? - search_params_t{std::in_place_index<1>} : - search_params_t{std::in_place_index<0>}; - std::visit( - [ivfParams](auto&& inner) { - inner.n_probes = ivfParams->nProbes; - if constexpr (std::is_same_v) { - inner.lut_dtype = GetCudaType(ivfParams->lutType); - inner.internal_distance_dtype = GetCudaType(ivfParams->internalDistanceType); - inner.preferred_shmem_carvout = ivfParams->preferredShmemCarveout; - } - }, result - ); - return result; + search_params_{[ivfParams]() { + auto result = ivfParams->pqBits > 0 ? search_params_t{std::in_place_index<1>} + : search_params_t{std::in_place_index<0>}; + std::visit( + [ivfParams](auto &&inner) { + inner.n_probes = ivfParams->nProbes; + if constexpr (std::is_same_v) { + inner.lut_dtype = GetCudaType(ivfParams->lutType); + inner.internal_distance_dtype = + GetCudaType(ivfParams->internalDistanceType); + inner.preferred_shmem_carvout = ivfParams->preferredShmemCarveout; + } + }, + result); + return result; }()}, - index_{std::nullopt} - {} + index_{std::nullopt} {} auto addVector(const void *vector_data, labelType label, - bool overwrite_allowed = true) override { + bool overwrite_allowed = true) override { return addVectorBatch(vector_data, &label, 1, overwrite_allowed); } auto addVectorBatchAsync(const void *vector_data, labelType *label, size_t batch_size, - bool overwrite_allowed = true) { - // Allocate memory on device to hold vectors to be added - auto vector_data_gpu = - raft::make_device_matrix(res_, batch_size, - this->dim); - // Allocate memory on device to hold vector labels - auto label_gpu = raft::make_device_vector( - res_, - batch_size - ); + bool overwrite_allowed = true) { + // Allocate memory on device to hold vectors to be added + auto vector_data_gpu = + raft::make_device_matrix(res_, batch_size, this->dim); + // Allocate memory on device to hold vector labels + auto label_gpu = raft::make_device_vector(res_, batch_size); - // Copy vector data to previously allocated device buffer - raft::copy( - vector_data_gpu.data_handle(), - static_cast(vector_data), - this->dim * batch_size, - res_.get_stream() - ); - // Copy label data to previously allocated device buffer - raft::copy( - label_gpu.data_handle(), - label, - batch_size, - res_.get_stream() - ); + // Copy vector data to previously allocated device buffer + raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), + this->dim * batch_size, res_.get_stream()); + // Copy label data to previously allocated device buffer + raft::copy(label_gpu.data_handle(), label, batch_size, res_.get_stream()); - if (std::holds_alternative(build_params_)) { - if (!index_) { - index_ = raft::neighbors::ivf_flat::build( - res_, - std::get(build_params_), - vector_data_gpu.view() - ); - } - raft::neighbors::ivf_flat::extend( - res_, - vector_data_gpu.view(), - label_gpu, - *index_ - ); - } else { - if (!index_) { - index_ = raft::neighbors::ivf_pq::build( - res_, - std::get(build_params_), - vector_data_gpu.view() - ); + if (std::holds_alternative(build_params_)) { + if (!index_) { + index_ = raft::neighbors::ivf_flat::build( + res_, std::get(build_params_), + vector_data_gpu.view()); + } + raft::neighbors::ivf_flat::extend(res_, vector_data_gpu.view(), label_gpu, *index_); + } else { + if (!index_) { + index_ = raft::neighbors::ivf_pq::build( + res_, std::get(build_params_), + vector_data_gpu.view()); + } + raft::neighbors::ivf_pq::extend(res_, vector_data_gpu.view(), label_gpu, *index_); } - raft::neighbors::ivf_pq::extend( - res_, - vector_data_gpu.view(), - label_gpu, - *index_ - ); - } - return batch_size; + return batch_size; } auto addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, - bool overwrite_allowed = true) { - auto result = addVectorBatchAsync(vector_data, label, batch_size, - overwrite_allowed); - // Ensure that above operation has executed on device before - // returning from this function on host - res_.sync_stream(); - return result; + bool overwrite_allowed = true) { + auto result = addVectorBatchAsync(vector_data, label, batch_size, overwrite_allowed); + // Ensure that above operation has executed on device before + // returning from this function on host + res_.sync_stream(); + return result; } auto deleteVector(labelType label) override { assert(!"deleteVector not implemented"); @@ -212,93 +187,62 @@ public: assert(!"indexCapacity not implemented"); return 0; } - void increaseCapacity() override { - assert(!"increaseCapacity not implemented"); - } + void increaseCapacity() override { assert(!"increaseCapacity not implemented"); } inline auto indexLabelCount() const override { return this->indexSize(); // TODO: Return unique counts } - auto topKQuery(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) override { - auto result_list = VecSimQueryResult_List{0}; - auto nVectors = this->indexSize(); - if (nVectors == 0) { - result_list.results = array_new(0); - } else { - // Ensure we are not trying to retrieve more vectors than exist in the - // index - k = std::min(k, nVectors); - // Allocate memory on device for search vector - auto vector_data_gpu = raft::make_device_matrix(res_,1, this->dim); - // Allocate memory on device for neighbor results - auto neighbors_gpu = raft::make_device_vector(res_, k); - // Allocate memory on device for distance results - auto distances_gpu = raft::make_device_vector(res_, k); - // Copy query vector to device - raft::copy( - vector_data_gpu.data_handle(), - static_cast(queryBlob), - this->dim, - res_.get_stream() - ); - - // Perform correct search based on index type - if ( - std::holds_alternative(index_) - ) { - raft::neighbors::ivf_flat::search( - res_, - std::get(search_params_), - std::get(*index_), - vector_data_gpu.view(), - neighbors_gpu.view(), - distances_gpu.view() - ) + auto topKQuery(const void *queryBlob, size_t k, VecSimQueryParams *queryParams) override { + auto result_list = VecSimQueryResult_List{0}; + auto nVectors = this->indexSize(); + if (nVectors == 0) { + result_list.results = array_new(0); } else { - raft::neighbors::ivf_pq::search( - res_, - std::get(search_params_), - std::get(*index_), - vector_data_gpu.view(), - neighbors_gpu.view(), - distances_gpu.view() - ) - } + // Ensure we are not trying to retrieve more vectors than exist in the + // index + k = std::min(k, nVectors); + // Allocate memory on device for search vector + auto vector_data_gpu = raft::make_device_matrix(res_, 1, this->dim); + // Allocate memory on device for neighbor results + auto neighbors_gpu = raft::make_device_vector(res_, k); + // Allocate memory on device for distance results + auto distances_gpu = raft::make_device_vector(res_, k); + // Copy query vector to device + raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), this->dim, + res_.get_stream()); - // Allocate host buffers to hold returned results - auto neighbors = std::unique_ptr( - array_new_len(k, k), - &array_free - ); - auto distances = std::unique_ptr( - array_new_len(k, k), - &array_free - ); - // Copy data back from device to host - raft::copy( - neighbors.get(), - neighbors_gpu.data_handle(), - this->dim, - res_.get_stream() - ); - raft::copy( - distances.get(), - distances_gpu.data_handle(), - this->dim, - res_.get_stream() - ); + // Perform correct search based on index type + if (std::holds_alternative(index_)) { + raft::neighbors::ivf_flat::search( + res_, std::get(search_params_), + std::get(*index_), vector_data_gpu.view(), + neighbors_gpu.view(), distances_gpu.view()) + } else { + raft::neighbors::ivf_pq::search( + res_, std::get(search_params_), + std::get(*index_), vector_data_gpu.view(), + neighbors_gpu.view(), distances_gpu.view()) + } - result_list.results = array_new_len(k, k); + // Allocate host buffers to hold returned results + auto neighbors = + std::unique_ptr(array_new_len(k, k), &array_free); + auto distances = + std::unique_ptr(array_new_len(k, k), &array_free); + // Copy data back from device to host + raft::copy(neighbors.get(), neighbors_gpu.data_handle(), this->dim, res_.get_stream()); + raft::copy(distances.get(), distances_gpu.data_handle(), this->dim, res_.get_stream()); - // Ensure search is complete and data have been copied back before - // building query result objects on host - res_.sync_stream(); - for (size_t i = 0; i < k; ++i) { - VecSimQueryResult_SetId(result_list.results[i], neighbors[i]); - VecSimQueryResult_SetScore(result_list.results[i], distances[i]); + result_list.results = array_new_len(k, k); + + // Ensure search is complete and data have been copied back before + // building query result objects on host + res_.sync_stream(); + for (size_t i = 0; i < k; ++i) { + VecSimQueryResult_SetId(result_list.results[i], neighbors[i]); + VecSimQueryResult_SetScore(result_list.results[i], distances[i]); + } } - } - return result_list; + return result_list; } VecSimQueryResult_List rangeQuery(const void *queryBlob, double radius, @@ -314,25 +258,21 @@ public: assert(!"preferAdHocSearch not implemented"); } - auto& get_resources() const { return res_; } + auto &get_resources() const { return res_; } auto nLists() { - return std::visit([](auto&& params) { - return params.n_list; - }, build_params_); + return std::visit([](auto &¶ms) { return params.n_list; }, build_params_); } auto indexSize() { - auto result = size_t{}; - if (index_) { - result = std::visit([](auto&& index) { - return index.size(); - }, *index_); - } - return result; + auto result = size_t{}; + if (index_) { + result = std::visit([](auto &&index) { return index.size(); }, *index_); + } + return result; } - private: +private: // An object used to manage common device resources that may be // expensive to build but frequently accessed raft::device_resources res_; diff --git a/src/VecSim/algorithms/ivf/ivf_tiered.cuh b/src/VecSim/algorithms/ivf/ivf_tiered.cuh index b90c5aa6d..d4f5b1641 100644 --- a/src/VecSim/algorithms/ivf/ivf_tiered.cuh +++ b/src/VecSim/algorithms/ivf/ivf_tiered.cuh @@ -3,114 +3,83 @@ #include "VecSim/vec_sim_tiered_index.h" struct RAFTTransferJob : public AsyncJob { - bool overwrite_allowed{true}; - RAFTTransferJob( - std::shared_ptr allocator, bool overwrite_allowed_, - JobCallback insertCb, - VecSimIndex *index_ - ) : AsyncJob{allocator, RAFT_TRANSFER_JOB, insertCb, index_}, - overwrite_allowed{overwrite_allowed_} {} + bool overwrite_allowed{true}; + RAFTTransferJob(std::shared_ptr allocator, bool overwrite_allowed_, + JobCallback insertCb, VecSimIndex *index_) + : AsyncJob{allocator, RAFT_TRANSFER_JOB, insertCb, index_}, + overwrite_allowed{overwrite_allowed_} {} }; template struct TieredIVFIndex : public VecSimTieredIndex { - auto addVector( - const void* blob, labelType label, bool overwrite_allowed - ) { - auto frontend_lock = std::scoped_lock(this->flatIndexGuard); - auto result = this->frontendIndex->addVector( - blob, label, overwrite_allowed - ); - if (this->frontendIndex->indexSize() >= this->flatBufferLimit) { - transferToBackend(overwrite_allowed); + auto addVector(const void *blob, labelType label, bool overwrite_allowed) { + auto frontend_lock = std::scoped_lock(this->flatIndexGuard); + auto result = this->frontendIndex->addVector(blob, label, overwrite_allowed); + if (this->frontendIndex->indexSize() >= this->flatBufferLimit) { + transferToBackend(overwrite_allowed); + } + return result; } - return result; - } - auto deleteVector(labelType label) { - // TODO(wphicks) - // If in flatIndex, delete - // If being transferred to backend, wait for transfer - // If in backendIndex, delete - } + auto deleteVector(labelType label) { + // TODO(wphicks) + // If in flatIndex, delete + // If being transferred to backend, wait for transfer + // If in backendIndex, delete + } + + auto indexSize() { + auto frontend_lock = std::scoped_lock(this->flatIndexGuard); + auto backend_lock = std::scoped_lock(this->mainIndexGuard); + return (getBackendIndex().indexSize() + this->frontendIndex.indexSize()); + } - auto indexSize() { - auto frontend_lock = std::scoped_lock(this->flatIndexGuard); - auto backend_lock = std::scoped_lock(this->mainIndexGuard); - return ( - getBackendIndex().indexSize() + this->frontendIndex.indexSize() - ); - } + auto indexLabelCount() const override { + // TODO(wphicks) Count unique labels between both indexes + } - auto indexLabelCount() const override { - // TODO(wphicks) Count unique labels between both indexes - } + auto indexCapacity() const override { + return (getBackendIndex().indexCapacity() + this->flatBufferLimit); + } - auto indexCapacity() const override { - return ( - getBackendIndex().indexCapacity() + - this->flatBufferLimit - ); - } + void increaseCapacity() override { getBackendIndex().increaseCapacity(); } - void increaseCapacity() override { - getBackendIndex().increaseCapacity(); - } + auto getDistanceFrom(labelType label, const void *blob) { + auto frontend_lock = std::unique_lock(this->flatIndexGuard); + auto flat_dist = this->frontendIndex->getDistanceFrom(label, blob); + frontend_lock.unlock(); + auto backend_lock = std::scoped_lock(this->mainIndexGuard); + auto raft_dist = getBackendIndex().getDistanceFrom(label, blob); + return std::fmin(flat_dist, raft_dist); + } - auto getDistanceFrom(labelType label, const void* blob) { - auto frontend_lock = std::unique_lock(this->flatIndexGuard); - auto flat_dist = this->frontendIndex->getDistanceFrom(label, blob); - frontend_lock.unlock(); - auto backend_lock = std::scoped_lock(this->mainIndexGuard); - auto raft_dist = getBackendIndex().getDistanceFrom(label, blob); - return std::fmin(flat_dist, raft_dist); - } + void executeTransferJob(RAFTTransferJob *job) { transferToBackend(job->overwrite_allowed); } - void executeTransferJob(RAFTTransferJob* job) { - transferToBackend(job->overwrite_allowed); - } +private: + vecsim_stl::unordered_map> labelToTransferJobs; + auto &getBackendIndex() { + return *dynamic_cast *>(this->backendIndex); + } - private: - vecsim_stl::unordered_map> labelToTransferJobs; - auto& getBackendIndex() { - return *dynamic_cast*>( - this->backendIndex - ); - } + void transferToBackend(overwrite_allowed = true) { + auto dim = this->index->getDim(); + auto frontend_lock = std::unique_lock(this->flatIndexGuard); + auto nVectors = this->flatBuffer->indexSize(); + const auto &vectorBlocks = this->flatBuffer->getVectorBlocks(); + auto vectorData = raft::make_host_matrix(getBackendIndex().get_resources(), nVectors, dim); - void transferToBackend(overwrite_allowed=true) { - auto dim = this->index->getDim(); - auto frontend_lock = std::unique_lock(this->flatIndexGuard); - auto nVectors = this->flatBuffer->indexSize(); - const auto &vectorBlocks = this->flatBuffer->getVectorBlocks(); - auto vectorData = raft::make_host_matrix( - getBackendIndex().get_resources(), - nVectors, - dim - ); + auto *out = vectorData.data_handle(); + for (auto block_id = 0; block_id < vectorBlocks.size(); ++block_id) { + auto *in_begin = reinterpret_cast(vectorBlocks[block_id].getElement(0)); + auto length = vectorBlocks[block_id].getLength(); + std::copy(in_begin, in_begin + length, out); + out += length; + } - auto* out = vectorData.data_handle(); - for (auto block_id = 0; block_id < vectorBlocks.size(); ++block_id) { - auto* in_begin = reinterpret_cast( - vectorBlocks[block_id].getElement(0) - ); - auto length = vectorBlocks[block_id].getLength(); - std::copy( - in_begin, - in_begin + length, - out - ); - out += length; + auto backend_lock = std::scoped_lock(this->mainIndexGuard); + this->flatBuffer->clear(); + frontend_lock.unlock(); + getBackendIndex().addVectorBatch(vectorData.data_handle(), this->flatBuffer->getLabels(), + nVectors, overwrite_allowed); } - - auto backend_lock = std::scoped_lock(this->mainIndexGuard); - this->flatBuffer->clear(); - frontend_lock.unlock(); - getBackendIndex().addVectorBatch( - vectorData.data_handle(), - this->flatBuffer->getLabels(), - nVectors, - overwrite_allowed - ); - } }; diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index eab95bacb..1f346e250 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -48,7 +48,6 @@ typedef enum { IVFPQCodebookKind_PerCluster, IVFPQCodebookKind_PerSubspace } IVF // CUDA types supported by GPU-accelerated indexes typedef enum { CUDAType_R_32F, CUDAType_R_16F, CUDAType_R_8U } CudaType; - typedef size_t labelType; typedef unsigned int idType; @@ -138,17 +137,17 @@ typedef struct { } TieredIndexParams; typedef struct { - VecSimType type; // Datatype to index. - size_t dim; // Vector's dimension. - VecSimMetric metric; // Distance metric to use in the index. - bool multi; // Determines if the index should multi-index or not. - size_t nLists; // Number of inverted lists - bool conservativeMemoryAllocation; // Use as little GPU memory as possible - size_t kmeans_nIters; // Iterations for kmeans calculation - double kmeans_trainsetFraction; // Fraction of dataset used for kmeans training - unsigned nProbes; // The number of clusters to search - size_t pqBits; // Bit length of vector element after PQ compression. If set - // to 0, IVF flat will be used instead of IVFPQ. + VecSimType type; // Datatype to index. + size_t dim; // Vector's dimension. + VecSimMetric metric; // Distance metric to use in the index. + bool multi; // Determines if the index should multi-index or not. + size_t nLists; // Number of inverted lists + bool conservativeMemoryAllocation; // Use as little GPU memory as possible + size_t kmeans_nIters; // Iterations for kmeans calculation + double kmeans_trainsetFraction; // Fraction of dataset used for kmeans training + unsigned nProbes; // The number of clusters to search + size_t pqBits; // Bit length of vector element after PQ compression. If set + // to 0, IVF flat will be used instead of IVFPQ. // ***************** IVF-Flat-only parameters ****************** // The following parameters will be ignored if pqBits is set to a // non-zero value. @@ -164,13 +163,13 @@ typedef struct { CudaType lutType; CudaType internalDistanceType; double preferredShmemCarveout; // Fraction of GPU's unified memory / L1 - // cache to be used as shared memory + // cache to be used as shared memory } IVFParams; typedef struct { - IVFParams ivfParams; - TieredIndexParams tieredParams; + IVFParams ivfParams; + TieredIndexParams tieredParams; } TieredIVFParams; typedef union { @@ -279,7 +278,7 @@ typedef struct { typedef struct { size_t nLists; // Number of inverted lists. - size_t pqDim; // Dimensionality of encoded vector after PQ + size_t pqDim; // Dimensionality of encoded vector after PQ size_t pqBits; // Bits per encoded vector element after PQ } ivfInfoStruct; From 4d66fe85e956f9361e5e076f960fa7f484a91566 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 10 Oct 2023 23:01:50 +0200 Subject: [PATCH 08/28] Rename ivf, add factory --- cmake/raft.cmake | 8 +-- .../algorithms/{ivf => raft_ivf}/ivf.cuh | 49 ++++++------- .../{ivf => raft_ivf}/ivf_tiered.cuh | 2 + .../index_factories/raft_ivf_factory.cu | 71 +++++++++++++++++++ src/VecSim/index_factories/raft_ivf_factory.h | 18 +++++ src/VecSim/vec_sim_common.h | 26 +++---- 6 files changed, 128 insertions(+), 46 deletions(-) rename src/VecSim/algorithms/{ivf => raft_ivf}/ivf.cuh (87%) rename src/VecSim/algorithms/{ivf => raft_ivf}/ivf_tiered.cuh (99%) create mode 100644 src/VecSim/index_factories/raft_ivf_factory.cu create mode 100644 src/VecSim/index_factories/raft_ivf_factory.h diff --git a/cmake/raft.cmake b/cmake/raft.cmake index 1b48cb34f..46181397f 100644 --- a/cmake/raft.cmake +++ b/cmake/raft.cmake @@ -4,12 +4,8 @@ if(USE_CUDA) # Set which version of RAFT to use (defined separately for testing # minimal dependency changes if necessary) set(RAFT_VERSION "${RAPIDS_VERSION}") - # TODO(wphicks): Reset to main fork and branch after - # https://github.com/rapidsai/raft/pull/1716 has been merged - # set(RAFT_FORK "rapidsai") - set(RAFT_FORK "wphicks") - # set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") - set(RAFT_PINNED_TAG "fea-resource_manager") + set(RAFT_FORK "rapidsai") + set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") # Download CMake file for bootstrapping RAPIDS-CMake, a utility that # simplifies handling of complex RAPIDS dependencies diff --git a/src/VecSim/algorithms/ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh similarity index 87% rename from src/VecSim/algorithms/ivf/ivf.cuh rename to src/VecSim/algorithms/raft_ivf/ivf.cuh index 31d87c024..289997d64 100644 --- a/src/VecSim/algorithms/ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -1,3 +1,5 @@ +#pragma once + #include #include #include @@ -7,7 +9,7 @@ #include #include #include "VecSim/vec_sim.h" -// For VecSimMetric, IVFParams, labelType +// For VecSimMetric, RaftIvfParams, labelType #include "VecSim/vec_sim_common.h" // For VecSimIndexAbstract #include "VecSim/vec_sim_index.h" @@ -22,7 +24,6 @@ #include #include -#pragma once inline auto constexpr GetRaftDistanceType(VecSimMetric vsm) { auto result = raft::distance::DistanceType{}; @@ -88,43 +89,43 @@ private: raft::neighbors::ivf_pq::index>; public: - IVFIndex(const IVFParams *ivfParams, const AbstractIndexInitParams &commonParams) + IVFIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &commonParams) : VecSimIndexAbstract{commonParams}, - res_{raft::resource_manager::get_device_resources()}, build_params_{[ivfParams]() { - auto result = ivfParams->pqBits > 0 ? build_params_t{std::in_place_index<1>} - : build_params_t{std::in_place_index<0>}; + res_{raft::resource_manager::get_device_resources()}, build_params_{[raftIvfParams]() { + auto result = raftIvfParams->usePQ ? build_params_t{std::in_place_index<1>} + : build_params_t{std::in_place_index<0>}; std::visit( - [ivfParams](auto &&inner) { - inner.metric = GetRaftDistanceType(ivfParams->metric); - inner.n_lists = ivfParams->nLists; - inner.kmeans_n_iters = ivfParams->kmeans_nIters; - inner.kmeans_trainset_fraction = ivfParams->kmeans_trainsetFraction; + [raftIvfParams](auto &&inner) { + inner.metric = GetRaftDistanceType(raftIvfParams->metric); + inner.n_lists = raftIvfParams->nLists; + inner.kmeans_n_iters = raftIvfParams->kmeans_nIters; + inner.kmeans_trainset_fraction = raftIvfParams->kmeans_trainsetFraction; inner.conservative_memory_allocation = - ivfParams->conservativeMemoryAllocation; + raftIvfParams->conservativeMemoryAllocation; if constexpr (std::is_same_v) { - inner.pq_bits = ivfParams->pqBits; - inner.pq_dim = ivfParams->pqDim; - inner.codebook_kind = GetRaftCodebookKind(ivfParams->codebookKind); + inner.pq_bits = raftIvfParams->pqBits; + inner.pq_dim = raftIvfParams->pqDim; + inner.codebook_kind = GetRaftCodebookKind(raftIvfParams->codebookKind); } else { - inner.adaptive_centers = ivfParams->adaptiveCenters; + inner.adaptive_centers = raftIvfParams->adaptiveCenters; } }, result); return result; }()}, - search_params_{[ivfParams]() { - auto result = ivfParams->pqBits > 0 ? search_params_t{std::in_place_index<1>} - : search_params_t{std::in_place_index<0>}; + search_params_{[raftIvfParams]() { + auto result = raftIvfParams->usePQ ? search_params_t{std::in_place_index<1>} + : search_params_t{std::in_place_index<0>}; std::visit( - [ivfParams](auto &&inner) { - inner.n_probes = ivfParams->nProbes; + [raftIvfParams](auto &&inner) { + inner.n_probes = raftIvfParams->nProbes; if constexpr (std::is_same_v) { - inner.lut_dtype = GetCudaType(ivfParams->lutType); + inner.lut_dtype = GetCudaType(raftIvfParams->lutType); inner.internal_distance_dtype = - GetCudaType(ivfParams->internalDistanceType); - inner.preferred_shmem_carvout = ivfParams->preferredShmemCarveout; + GetCudaType(raftIvfParams->internalDistanceType); + inner.preferred_shmem_carvout = raftIvfParams->preferredShmemCarveout; } }, result); diff --git a/src/VecSim/algorithms/ivf/ivf_tiered.cuh b/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh similarity index 99% rename from src/VecSim/algorithms/ivf/ivf_tiered.cuh rename to src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh index d4f5b1641..fb3fee39d 100644 --- a/src/VecSim/algorithms/ivf/ivf_tiered.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh @@ -1,3 +1,5 @@ +#pragma once + #include #include "VecSim/algorithms/ivf/ivf.cuh" #include "VecSim/vec_sim_tiered_index.h" diff --git a/src/VecSim/index_factories/raft_ivf_factory.cu b/src/VecSim/index_factories/raft_ivf_factory.cu new file mode 100644 index 000000000..83c842fe3 --- /dev/null +++ b/src/VecSim/index_factories/raft_ivf_factory.cu @@ -0,0 +1,71 @@ +#include "VecSim/index_factories/brute_force_factory.h" +#include "VecSim/algorithms/raft_ivf/ivf.cuh" + +namespace RaftIVFFactory { + +static AbstractIndexInitParams NewAbstractInitParams(const VecSimParams *params) { + + const RaftIvfParams *raftIvfParams = ¶ms->algoParams.raftIvfParams; + AbstractIndexInitParams abstractInitParams = {.allocator = + VecSimAllocator::newVecsimAllocator(), + .dim = bfParams->dim, + .vecType = bfParams->type, + .metric = bfParams->metric, + .blockSize = bfParams->blockSize, + .multi = bfParams->multi, + .logCtx = params->logCtx}; + return abstractInitParams; +} + +VecSimIndex *NewIndex(const VecSimParams *params) { + const RaftIvfParams *raftIvfParams = ¶ms->algoParams.raftIvfParams; + AbstractIndexInitParams abstractInitParams = NewAbstractInitParams(params); + return NewIndex(raftIvfParams, NewAbstractInitParams(params)); +} + +VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &abstractInitParams) { + if (raftIvfParams->type == VecSimType_FLOAT32) { + return new (abstractInitParams.allocator) + RaftIVFIndex(raftIvfParams, abstractInitParams); + } else if (raftIvfParams->type == VecSimType_FLOAT64) { + return new (abstractInitParams.allocator) + RaftIVFIndex(raftIvfParams, abstractInitParams); + } + + // If we got here something is wrong. + return NULL; +} + +VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams) { + VecSimParams params = {.algoParams{.raftIvfParams = RaftIvfParams{*raftIvfParams}}}; + return NewIndex(¶ms); +} + +size_t EstimateInitialSize(const RaftIvfParams *params) { + + size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); + + // Constant part (not effected by parameters). + size_t est = sizeof(VecSimAllocator) + allocations_overhead; + est += sizeof(RaftIVFIndex); // Object size + if (!params.usePQ) { + // Size of each cluster data + est += params->nLists * sizeof(raft::neighbors::ivf_flat::list_data); + // Vector of shared ptr to cluster + est += params->nLists * sizeof(std::shared_ptr>); + } else { + // Size of each cluster data + est += params->nLists * sizeof(raft::neighbors::ivf_pq::list_data); + // accum_sorted_sizes_ Array + est += params->nLists * sizeof(std::int64_t); + // vector of shared ptr to cluster + est += params->nLists * sizeof(std::shared_ptr>); + } + return est; +} + +size_t EstimateElementSize(const BFParams *params) { + // Vectors are stored on the GPU. + return 0; +} +}; // namespace RaftIVFFactory diff --git a/src/VecSim/index_factories/raft_ivf_factory.h b/src/VecSim/index_factories/raft_ivf_factory.h new file mode 100644 index 000000000..d2119e699 --- /dev/null +++ b/src/VecSim/index_factories/raft_ivf_factory.h @@ -0,0 +1,18 @@ +#pragma once + +#include // size_t +#include // std::shared_ptr + +#include "VecSim/vec_sim.h" //typedef VecSimIndex +#include "VecSim/vec_sim_common.h" // RaftIvfParams +#include "VecSim/memory/vecsim_malloc.h" // VecSimAllocator +#include "VecSim/vec_sim_index.h" + +namespace RaftIVFFactory { + +VecSimIndex *NewIndex(const VecSimParams *params); +VecSimIndex *NewIndex(const RaftIvfParams *params); +size_t EstimateInitialSize(const RaftIvfParams *params); +size_t EstimateElementSize(const RaftIvfParams *params); + +}; // namespace RaftIVFFactory diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index 0a7f91971..ddd49b198 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -147,18 +147,17 @@ typedef struct { size_t kmeans_nIters; // Iterations for kmeans calculation double kmeans_trainsetFraction; // Fraction of dataset used for kmeans training unsigned nProbes; // The number of clusters to search - size_t pqBits; // Bit length of vector element after PQ compression. If set - // to 0, IVF flat will be used instead of IVFPQ. + bool usePQ; // If false IVF-Flat is used. If true IVF-PQ is used. // ***************** IVF-Flat-only parameters ****************** - // The following parameters will be ignored if pqBits is set to a - // non-zero value. + // The following parameters will be ignored if usePQ is set to true. bool adaptiveCenters; // If index should be updated for new vectors - // ******************* IVFPQ-only parameters ******************* - // The following parameters will be ignored if pqBits is set to 0 - size_t pqDim; // The dimensionality of an encoded vector after PQ - // compression. If set to 0, a heuristic will be used to - // select the dimensionality. + // ******************* IVF-PQ-only parameters ******************* + // The following parameters will be ignored if usePQ is set to false. + size_t pqBits; // Bit length of vector element after PQ compression. + size_t pqDim; // The dimensionality of an encoded vector after PQ + // compression. If set to 0, a heuristic will be used to + // select the dimensionality. IVFPQCodebookKind codebookKind; CudaType lutType; @@ -166,17 +165,12 @@ typedef struct { double preferredShmemCarveout; // Fraction of GPU's unified memory / L1 // cache to be used as shared memory -} IVFParams; - -typedef struct { - IVFParams ivfParams; - TieredIndexParams tieredParams; -} TieredIVFParams; +} RaftIvfParams; typedef union { HNSWParams hnswParams; BFParams bfParams; - IVFParams ivfParams; + RaftIvfParams raftIvfParams; TieredIndexParams tieredParams; } AlgoParams; From dbe8bea9eabbae094c15271b44badb9b2b900846 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 13 Oct 2023 13:36:50 +0200 Subject: [PATCH 09/28] Fix add and searches --- src/VecSim/CMakeLists.txt | 1 + src/VecSim/algorithms/raft_ivf/ivf.cuh | 253 ++++++++++-------- src/VecSim/index_factories/index_factory.cpp | 9 + .../index_factories/raft_ivf_factory.cu | 46 ++-- src/VecSim/index_factories/raft_ivf_factory.h | 4 +- src/VecSim/utils/vec_utils.cpp | 3 + src/VecSim/utils/vec_utils.h | 1 + src/VecSim/vec_sim_common.h | 5 +- src/VecSim/vec_sim_tiered_index.h | 1 + 9 files changed, 182 insertions(+), 141 deletions(-) diff --git a/src/VecSim/CMakeLists.txt b/src/VecSim/CMakeLists.txt index 551bd446e..6bd0014d1 100644 --- a/src/VecSim/CMakeLists.txt +++ b/src/VecSim/CMakeLists.txt @@ -19,6 +19,7 @@ add_library(VectorSimilarity ${VECSIM_LIBTYPE} index_factories/hnsw_factory.cpp index_factories/tiered_factory.cpp index_factories/index_factory.cpp + $<$:index_factories/raft_ivf_factory.cu> algorithms/hnsw/visited_nodes_handler.cpp vec_sim.cpp vec_sim_interface.cpp diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index 289997d64..a3b93be14 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -13,9 +13,10 @@ #include "VecSim/vec_sim_common.h" // For VecSimIndexAbstract #include "VecSim/vec_sim_index.h" -#include "VecSim/query_result_struct.h" +#include "VecSim/query_result_definitions.h" // VecSimQueryResult VecSimQueryReply #include "VecSim/memory/vecsim_malloc.h" +#include #include #include #include @@ -24,7 +25,6 @@ #include #include - inline auto constexpr GetRaftDistanceType(VecSimMetric vsm) { auto result = raft::distance::DistanceType{}; switch (vsm) { @@ -74,7 +74,7 @@ inline auto constexpr GetCudaType(CudaType vss_type) { } template -struct IVFIndex : public VecSimIndexAbstract { +struct RaftIVFIndex : public VecSimIndexAbstract { using data_type = DataType; using dist_type = DistType; @@ -85,102 +85,108 @@ private: using search_params_t = std::variant; using internal_idx_t = std::uint32_t; - using ann_index_t = std::variant, - raft::neighbors::ivf_pq::index>; + using index_flat_t = raft::neighbors::ivf_flat::index; + using index_pq_t = raft::neighbors::ivf_pq::index; + using ann_index_t = std::variant; public: - IVFIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &commonParams) + RaftIVFIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &commonParams) : VecSimIndexAbstract{commonParams}, - res_{raft::resource_manager::get_device_resources()}, build_params_{[raftIvfParams]() { - auto result = raftIvfParams->usePQ ? build_params_t{std::in_place_index<1>} - : build_params_t{std::in_place_index<0>}; - std::visit( - [raftIvfParams](auto &&inner) { - inner.metric = GetRaftDistanceType(raftIvfParams->metric); - inner.n_lists = raftIvfParams->nLists; - inner.kmeans_n_iters = raftIvfParams->kmeans_nIters; - inner.kmeans_trainset_fraction = raftIvfParams->kmeans_trainsetFraction; - inner.conservative_memory_allocation = - raftIvfParams->conservativeMemoryAllocation; - if constexpr (std::is_same_v) { - inner.pq_bits = raftIvfParams->pqBits; - inner.pq_dim = raftIvfParams->pqDim; - inner.codebook_kind = GetRaftCodebookKind(raftIvfParams->codebookKind); - } else { - inner.adaptive_centers = raftIvfParams->adaptiveCenters; - } - }, - result); - return result; - }()}, - search_params_{[raftIvfParams]() { - auto result = raftIvfParams->usePQ ? search_params_t{std::in_place_index<1>} - : search_params_t{std::in_place_index<0>}; - std::visit( - [raftIvfParams](auto &&inner) { - inner.n_probes = raftIvfParams->nProbes; - if constexpr (std::is_same_v) { - inner.lut_dtype = GetCudaType(raftIvfParams->lutType); - inner.internal_distance_dtype = - GetCudaType(raftIvfParams->internalDistanceType); - inner.preferred_shmem_carvout = raftIvfParams->preferredShmemCarveout; - } - }, - result); - return result; - }()}, - index_{std::nullopt} {} - auto addVector(const void *vector_data, labelType label, - bool overwrite_allowed = true) override { - return addVectorBatch(vector_data, &label, 1, overwrite_allowed); + res_{raft::device_resources_manager::get_device_resources()}, + build_params_{raftIvfParams->usePQ ? build_params_t{std::in_place_index<1>} + : build_params_t{std::in_place_index<0>}}, + search_params_{raftIvfParams->usePQ ? search_params_t{std::in_place_index<1>} + : search_params_t{std::in_place_index<0>}}, + index_{std::nullopt} { + std::visit( + [raftIvfParams](auto &&inner) { + inner.metric = GetRaftDistanceType(raftIvfParams->metric); + inner.n_lists = raftIvfParams->nLists; + inner.kmeans_n_iters = raftIvfParams->kmeans_nIters; + inner.kmeans_trainset_fraction = raftIvfParams->kmeans_trainsetFraction; + inner.conservative_memory_allocation = raftIvfParams->conservativeMemoryAllocation; + if constexpr (std::is_same_v) { + inner.adaptive_centers = raftIvfParams->adaptiveCenters; + } else if constexpr (std::is_same_v) { + inner.pq_bits = raftIvfParams->pqBits; + inner.pq_dim = raftIvfParams->pqDim; + inner.codebook_kind = GetRaftCodebookKind(raftIvfParams->codebookKind); + } + }, + build_params_); + std::visit( + [raftIvfParams](auto &&inner) { + inner.n_probes = raftIvfParams->nProbes; + if constexpr (std::is_same_v) { + inner.lut_dtype = GetCudaType(raftIvfParams->lutType); + inner.internal_distance_dtype = + GetCudaType(raftIvfParams->internalDistanceType); + inner.preferred_shmem_carvout = raftIvfParams->preferredShmemCarveout; + } + }, + search_params_); + } + int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override { + return addVectorBatch(vector_data, &label, 1, auxiliaryCtx); } - auto addVectorBatchAsync(const void *vector_data, labelType *label, size_t batch_size, - bool overwrite_allowed = true) { + int addVectorBatchAsync(const void *vector_data, labelType *label, size_t batch_size, + void *auxiliaryCtx = nullptr) { + // Convert labels to internal data type + auto label_original = std::vector(label, label + batch_size); + auto label_converted = + std::vector(label_original.begin(), label_original.end()); // Allocate memory on device to hold vectors to be added auto vector_data_gpu = raft::make_device_matrix(res_, batch_size, this->dim); // Allocate memory on device to hold vector labels - auto label_gpu = raft::make_device_vector(res_, batch_size); + auto label_gpu = raft::make_device_vector(res_, batch_size); // Copy vector data to previously allocated device buffer raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), this->dim * batch_size, res_.get_stream()); // Copy label data to previously allocated device buffer - raft::copy(label_gpu.data_handle(), label, batch_size, res_.get_stream()); + raft::copy(label_gpu.data_handle(), label_converted.data(), batch_size, res_.get_stream()); if (std::holds_alternative(build_params_)) { if (!index_) { index_ = raft::neighbors::ivf_flat::build( res_, std::get(build_params_), - vector_data_gpu.view()); + raft::make_const_mdspan(vector_data_gpu.view())); } - raft::neighbors::ivf_flat::extend(res_, vector_data_gpu.view(), label_gpu, *index_); + raft::neighbors::ivf_flat::extend( + res_, raft::make_const_mdspan(vector_data_gpu.view()), + std::make_optional(raft::make_const_mdspan(label_gpu.view())), + std::get(*index_)); } else { if (!index_) { index_ = raft::neighbors::ivf_pq::build( res_, std::get(build_params_), - vector_data_gpu.view()); + raft::make_const_mdspan(vector_data_gpu.view())); } - raft::neighbors::ivf_pq::extend(res_, vector_data_gpu.view(), label_gpu, *index_); + raft::neighbors::ivf_pq::extend( + res_, raft::make_const_mdspan(vector_data_gpu.view()), + std::make_optional(raft::make_const_mdspan(label_gpu.view())), + std::get(*index_)); } return batch_size; } - auto addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, - bool overwrite_allowed = true) { - auto result = addVectorBatchAsync(vector_data, label, batch_size, overwrite_allowed); + int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, + void *auxiliaryCtx = nullptr) { + auto result = addVectorBatchAsync(vector_data, label, batch_size, auxiliaryCtx); // Ensure that above operation has executed on device before // returning from this function on host res_.sync_stream(); return result; } - auto deleteVector(labelType label) override { + int deleteVector(labelType label) override { assert(!"deleteVector not implemented"); return 0; } - double getDistanceFrom(labelType label, const void *vector_data) const override { + double getDistanceFrom_Unsafe(labelType label, const void *vector_data) const override { assert(!"getDistanceFrom not implemented"); return INVALID_SCORE; } @@ -188,66 +194,67 @@ public: assert(!"indexCapacity not implemented"); return 0; } - void increaseCapacity() override { assert(!"increaseCapacity not implemented"); } - inline auto indexLabelCount() const override { + // void increaseCapacity() override { assert(!"increaseCapacity not implemented"); } + inline size_t indexLabelCount() const override { return this->indexSize(); // TODO: Return unique counts } - auto topKQuery(const void *queryBlob, size_t k, VecSimQueryParams *queryParams) override { - auto result_list = VecSimQueryResult_List{0}; + VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, + VecSimQueryParams *queryParams) const override { + auto result_list = new VecSimQueryReply(this->allocator); auto nVectors = this->indexSize(); - if (nVectors == 0) { - result_list.results = array_new(0); - } else { - // Ensure we are not trying to retrieve more vectors than exist in the - // index - k = std::min(k, nVectors); - // Allocate memory on device for search vector - auto vector_data_gpu = raft::make_device_matrix(res_, 1, this->dim); - // Allocate memory on device for neighbor results - auto neighbors_gpu = raft::make_device_vector(res_, k); - // Allocate memory on device for distance results - auto distances_gpu = raft::make_device_vector(res_, k); - // Copy query vector to device - raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), this->dim, - res_.get_stream()); + if (nVectors == 0 || k == 0 || !index_.has_value()) { + return result_list; + } + // Ensure we are not trying to retrieve more vectors than exist in the + // index + k = std::min(k, nVectors); + // Allocate memory on device for search vector + auto vector_data_gpu = + raft::make_device_matrix(res_, 1, this->dim); + // Allocate memory on device for neighbor and distance results + auto neighbors_gpu = raft::make_device_matrix(res_, 1, k); + auto distances_gpu = raft::make_device_matrix(res_, 1, k); + // Copy query vector to device + raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), this->dim, + res_.get_stream()); - // Perform correct search based on index type - if (std::holds_alternative(index_)) { - raft::neighbors::ivf_flat::search( - res_, std::get(search_params_), - std::get(*index_), vector_data_gpu.view(), - neighbors_gpu.view(), distances_gpu.view()) - } else { - raft::neighbors::ivf_pq::search( - res_, std::get(search_params_), - std::get(*index_), vector_data_gpu.view(), - neighbors_gpu.view(), distances_gpu.view()) - } + // Perform correct search based on index type + if (std::holds_alternative(*index_)) { + raft::neighbors::ivf_flat::search( + res_, std::get(search_params_), + std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), + neighbors_gpu.view(), distances_gpu.view()); + // TODO ADD STREAM MANAGER + } else { + raft::neighbors::ivf_pq::search( + res_, std::get(search_params_), + std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), + neighbors_gpu.view(), distances_gpu.view()); + // TODO ADD STREAM MANAGER + } - // Allocate host buffers to hold returned results - auto neighbors = - std::unique_ptr(array_new_len(k, k), &array_free); - auto distances = - std::unique_ptr(array_new_len(k, k), &array_free); - // Copy data back from device to host - raft::copy(neighbors.get(), neighbors_gpu.data_handle(), this->dim, res_.get_stream()); - raft::copy(distances.get(), distances_gpu.data_handle(), this->dim, res_.get_stream()); + // Allocate host buffers to hold returned results + auto neighbors = vecsim_stl::vector(k, this->allocator); + auto distances = vecsim_stl::vector(k, this->allocator); + // Copy data back from device to host + raft::copy(neighbors.data(), neighbors_gpu.data_handle(), this->dim, res_.get_stream()); + raft::copy(distances.data(), distances_gpu.data_handle(), this->dim, res_.get_stream()); - result_list.results = array_new_len(k, k); + // Ensure search is complete and data have been copied back before + // building query result objects on host + res_.sync_stream(); - // Ensure search is complete and data have been copied back before - // building query result objects on host - res_.sync_stream(); - for (size_t i = 0; i < k; ++i) { - VecSimQueryResult_SetId(result_list.results[i], neighbors[i]); - VecSimQueryResult_SetScore(result_list.results[i], distances[i]); - } + result_list->results.resize(k); + for (auto i = 0; i < k; ++i) { + result_list->results[i].id = labelType{neighbors[i]}; + result_list->results[i].score = distances[i]; } + return result_list; } - VecSimQueryResult_List rangeQuery(const void *queryBlob, double radius, - VecSimQueryParams *queryParams) override { + VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, + VecSimQueryParams *queryParams) const override { assert(!"RangeQuery not implemented"); } VecSimInfoIterator *infoIterator() const override { assert(!"infoIterator not implemented"); } @@ -255,7 +262,7 @@ public: VecSimQueryParams *queryParams) const override { assert(!"newBatchIterator not implemented"); } - bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) override { + bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const override { assert(!"preferAdHocSearch not implemented"); } @@ -265,13 +272,31 @@ public: return std::visit([](auto &¶ms) { return params.n_list; }, build_params_); } - auto indexSize() { + size_t indexSize() const override { auto result = size_t{}; if (index_) { result = std::visit([](auto &&index) { return index.size(); }, *index_); } return result; } + VecSimIndexBasicInfo basicInfo() const override { + VecSimIndexBasicInfo info = this->getBasicInfo(); + info.algo = VecSimAlgo_RaftIVF; + info.isTiered = false; + return info; + } + VecSimIndexInfo info() const override { + VecSimIndexInfo info; + info.commonInfo = this->getCommonInfo(); + info.raftIvfInfo.nLists = nLists(); + if (std::holds_alternative(build_params_)) { + const auto build_params_pq = + std::get(build_params_); + info.raftIvfInfo.pqBits = build_params_pq.pq_bits; + info.raftIvfInfo.pqDim = build_params_pq.pq_dim; + } + return info; + } private: // An object used to manage common device resources that may be diff --git a/src/VecSim/index_factories/index_factory.cpp b/src/VecSim/index_factories/index_factory.cpp index e71affa4f..f2d419c63 100644 --- a/src/VecSim/index_factories/index_factory.cpp +++ b/src/VecSim/index_factories/index_factory.cpp @@ -8,6 +8,7 @@ #include "hnsw_factory.h" #include "brute_force_factory.h" #include "tiered_factory.h" +#include "raft_ivf_factory.h" #include "VecSim/vec_sim_index.h" namespace VecSimFactory { @@ -25,6 +26,10 @@ VecSimIndex *NewIndex(const VecSimParams *params) { index = BruteForceFactory::NewIndex(params); break; } + case VecSimAlgo_RaftIVF: { + index = RaftIvfFactory::NewIndex(¶ms->algoParams.raftIvfParams); + break; + } case VecSimAlgo_TIERED: { index = TieredFactory::NewIndex(¶ms->algoParams.tieredParams); break; @@ -42,6 +47,8 @@ size_t EstimateInitialSize(const VecSimParams *params) { return HNSWFactory::EstimateInitialSize(¶ms->algoParams.hnswParams); case VecSimAlgo_BF: return BruteForceFactory::EstimateInitialSize(¶ms->algoParams.bfParams); + case VecSimAlgo_RaftIVF: + return RaftIvfFactory::EstimateInitialSize(¶ms->algoParams.raftIvfParams); case VecSimAlgo_TIERED: return TieredFactory::EstimateInitialSize(¶ms->algoParams.tieredParams); } @@ -54,6 +61,8 @@ size_t EstimateElementSize(const VecSimParams *params) { return HNSWFactory::EstimateElementSize(¶ms->algoParams.hnswParams); case VecSimAlgo_BF: return BruteForceFactory::EstimateElementSize(¶ms->algoParams.bfParams); + case VecSimAlgo_RaftIVF: + return RaftIvfFactory::EstimateElementSize(¶ms->algoParams.raftIvfParams); case VecSimAlgo_TIERED: return TieredFactory::EstimateElementSize(¶ms->algoParams.tieredParams); } diff --git a/src/VecSim/index_factories/raft_ivf_factory.cu b/src/VecSim/index_factories/raft_ivf_factory.cu index 83c842fe3..6a7e64271 100644 --- a/src/VecSim/index_factories/raft_ivf_factory.cu +++ b/src/VecSim/index_factories/raft_ivf_factory.cu @@ -1,28 +1,22 @@ #include "VecSim/index_factories/brute_force_factory.h" #include "VecSim/algorithms/raft_ivf/ivf.cuh" -namespace RaftIVFFactory { +namespace RaftIvfFactory { static AbstractIndexInitParams NewAbstractInitParams(const VecSimParams *params) { const RaftIvfParams *raftIvfParams = ¶ms->algoParams.raftIvfParams; AbstractIndexInitParams abstractInitParams = {.allocator = VecSimAllocator::newVecsimAllocator(), - .dim = bfParams->dim, - .vecType = bfParams->type, - .metric = bfParams->metric, - .blockSize = bfParams->blockSize, - .multi = bfParams->multi, - .logCtx = params->logCtx}; + .dim = raftIvfParams->dim, + .vecType = raftIvfParams->type, + .metric = raftIvfParams->metric, + //.multi = raftIvfParams->multi, + //.logCtx = params->logCtx + }; return abstractInitParams; } -VecSimIndex *NewIndex(const VecSimParams *params) { - const RaftIvfParams *raftIvfParams = ¶ms->algoParams.raftIvfParams; - AbstractIndexInitParams abstractInitParams = NewAbstractInitParams(params); - return NewIndex(raftIvfParams, NewAbstractInitParams(params)); -} - VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &abstractInitParams) { if (raftIvfParams->type == VecSimType_FLOAT32) { return new (abstractInitParams.allocator) @@ -36,36 +30,42 @@ VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexIni return NULL; } +VecSimIndex *NewIndex(const VecSimParams *params) { + const RaftIvfParams *raftIvfParams = ¶ms->algoParams.raftIvfParams; + AbstractIndexInitParams abstractInitParams = NewAbstractInitParams(params); + return NewIndex(raftIvfParams, NewAbstractInitParams(params)); +} + VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams) { VecSimParams params = {.algoParams{.raftIvfParams = RaftIvfParams{*raftIvfParams}}}; return NewIndex(¶ms); } -size_t EstimateInitialSize(const RaftIvfParams *params) { +size_t EstimateInitialSize(const RaftIvfParams *raftIvfParams) { size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); // Constant part (not effected by parameters). size_t est = sizeof(VecSimAllocator) + allocations_overhead; - est += sizeof(RaftIVFIndex); // Object size - if (!params.usePQ) { + est += sizeof(RaftIVFIndex); // Object size + if (!raftIvfParams->usePQ) { // Size of each cluster data - est += params->nLists * sizeof(raft::neighbors::ivf_flat::list_data); + est += raftIvfParams->nLists * sizeof(raft::neighbors::ivf_flat::list_data); // Vector of shared ptr to cluster - est += params->nLists * sizeof(std::shared_ptr>); + est += raftIvfParams->nLists * sizeof(std::shared_ptr>); } else { // Size of each cluster data - est += params->nLists * sizeof(raft::neighbors::ivf_pq::list_data); + est += raftIvfParams->nLists * sizeof(raft::neighbors::ivf_pq::list_data); // accum_sorted_sizes_ Array - est += params->nLists * sizeof(std::int64_t); + est += raftIvfParams->nLists * sizeof(std::int64_t); // vector of shared ptr to cluster - est += params->nLists * sizeof(std::shared_ptr>); + est += raftIvfParams->nLists * sizeof(std::shared_ptr>); } return est; } -size_t EstimateElementSize(const BFParams *params) { +size_t EstimateElementSize(const RaftIvfParams *raftIvfParams) { // Vectors are stored on the GPU. return 0; } -}; // namespace RaftIVFFactory +}; // namespace RaftIvfFactory diff --git a/src/VecSim/index_factories/raft_ivf_factory.h b/src/VecSim/index_factories/raft_ivf_factory.h index d2119e699..040c1c9d9 100644 --- a/src/VecSim/index_factories/raft_ivf_factory.h +++ b/src/VecSim/index_factories/raft_ivf_factory.h @@ -8,11 +8,11 @@ #include "VecSim/memory/vecsim_malloc.h" // VecSimAllocator #include "VecSim/vec_sim_index.h" -namespace RaftIVFFactory { +namespace RaftIvfFactory { VecSimIndex *NewIndex(const VecSimParams *params); VecSimIndex *NewIndex(const RaftIvfParams *params); size_t EstimateInitialSize(const RaftIvfParams *params); size_t EstimateElementSize(const RaftIvfParams *params); -}; // namespace RaftIVFFactory +}; // namespace RaftIvfFactory diff --git a/src/VecSim/utils/vec_utils.cpp b/src/VecSim/utils/vec_utils.cpp index b061bddcf..27cae0e06 100644 --- a/src/VecSim/utils/vec_utils.cpp +++ b/src/VecSim/utils/vec_utils.cpp @@ -15,6 +15,7 @@ const char *VecSimCommonStrings::ALGORITHM_STRING = "ALGORITHM"; const char *VecSimCommonStrings::FLAT_STRING = "FLAT"; const char *VecSimCommonStrings::HNSW_STRING = "HNSW"; +const char *VecSimCommonStrings::RAFTIVF_STRING = "RAFT_IVF"; const char *VecSimCommonStrings::TIERED_STRING = "TIERED"; const char *VecSimCommonStrings::TYPE_STRING = "TYPE"; @@ -125,6 +126,8 @@ const char *VecSimAlgo_ToString(VecSimAlgo vecsimAlgo) { return VecSimCommonStrings::FLAT_STRING; case VecSimAlgo_HNSWLIB: return VecSimCommonStrings::HNSW_STRING; + case VecSimAlgo_RaftIVF: + return VecSimCommonStrings::RAFTIVF_STRING; case VecSimAlgo_TIERED: return VecSimCommonStrings::TIERED_STRING; } diff --git a/src/VecSim/utils/vec_utils.h b/src/VecSim/utils/vec_utils.h index 79c8011e7..1916444b2 100644 --- a/src/VecSim/utils/vec_utils.h +++ b/src/VecSim/utils/vec_utils.h @@ -18,6 +18,7 @@ struct VecSimCommonStrings { static const char *ALGORITHM_STRING; static const char *FLAT_STRING; static const char *HNSW_STRING; + static const char *RAFTIVF_STRING; static const char *TIERED_STRING; static const char *TYPE_STRING; diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index ddd49b198..db4a00779 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -38,7 +38,7 @@ typedef enum { } VecSimType; // Algorithm type/library. -typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_TIERED } VecSimAlgo; +typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_RaftIVF, VecSimAlgo_TIERED } VecSimAlgo; // Distance metric typedef enum { VecSimMetric_L2, VecSimMetric_IP, VecSimMetric_Cosine } VecSimMetric; @@ -275,7 +275,7 @@ typedef struct { size_t nLists; // Number of inverted lists. size_t pqDim; // Dimensionality of encoded vector after PQ size_t pqBits; // Bits per encoded vector element after PQ -} ivfInfoStruct; +} raftIvfInfoStruct; typedef struct HnswTieredInfo { size_t pendingSwapJobsThreshold; @@ -309,6 +309,7 @@ typedef struct { union { bfInfoStruct bfInfo; hnswInfoStruct hnswInfo; + raftIvfInfoStruct raftInfo; tieredInfoStruct tieredInfo; }; } VecSimIndexInfo; diff --git a/src/VecSim/vec_sim_tiered_index.h b/src/VecSim/vec_sim_tiered_index.h index bc5f53d71..768bf02b8 100644 --- a/src/VecSim/vec_sim_tiered_index.h +++ b/src/VecSim/vec_sim_tiered_index.h @@ -284,6 +284,7 @@ VecSimIndexInfo VecSimTieredIndex::info() const { case VecSimAlgo_HNSWLIB: info.tieredInfo.backendInfo.hnswInfo = backendInfo.hnswInfo; break; + case VecSimAlgo_RaftIVF: // TODO Add RaftIVF info case VecSimAlgo_BF: case VecSimAlgo_TIERED: assert(false && "Invalid backend algorithm"); From 1d6282479f46e2c105af52b58127743a0387df68 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 18 Oct 2023 16:21:28 +0200 Subject: [PATCH 10/28] Add size computation of tiered index --- src/VecSim/index_factories/tiered_factory.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/VecSim/index_factories/tiered_factory.cpp b/src/VecSim/index_factories/tiered_factory.cpp index fbf89c481..3e9a56253 100644 --- a/src/VecSim/index_factories/tiered_factory.cpp +++ b/src/VecSim/index_factories/tiered_factory.cpp @@ -89,6 +89,13 @@ VecSimIndex *NewIndex(const TieredIndexParams *params) { } else if (type == VecSimType_FLOAT64) { return TieredHNSWFactory::NewIndex(params); } + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { + VecSimType type = params->primaryIndexParams->algoParams.raftIvfParams.type; + if (type == VecSimType_FLOAT32) { + return TieredRaftIvfFactory::NewIndex(params); + } else if (type == VecSimType_FLOAT64) { + return TieredRaftIvfFactory::NewIndex(params); + } } return nullptr; // Invalid algorithm or type. } @@ -99,6 +106,8 @@ size_t EstimateInitialSize(const TieredIndexParams *params) { BFParams bf_params{}; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est += TieredHNSWFactory::EstimateInitialSize(params, bf_params); + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { + est += TieredRaftIvfFactory::EstimateInitialSize(params, bf_params); } est += BruteForceFactory::EstimateInitialSize(&bf_params); @@ -109,6 +118,8 @@ size_t EstimateElementSize(const TieredIndexParams *params) { size_t est = 0; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est = HNSWFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.hnswParams); + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { + est = RaftIvfFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.raftIvfParams); } return est; } From 32ce6cc58b84dc6778bb2a6c4528ec24f5b1a2ec Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 19 Oct 2023 19:55:47 +0200 Subject: [PATCH 11/28] Update tiered index --- src/VecSim/CMakeLists.txt | 1 + .../algorithms/brute_force/brute_force.h | 6 +- .../brute_force/brute_force_multi.h | 6 ++ .../brute_force/brute_force_single.h | 6 ++ src/VecSim/algorithms/raft_ivf/ivf.cuh | 4 +- src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh | 45 +++++++------ .../index_factories/raft_ivf_factory.cu | 17 ++++- .../raft_ivf_tiered_factory.cu | 64 +++++++++++++++++++ .../index_factories/raft_ivf_tiered_factory.h | 19 ++++++ src/VecSim/index_factories/tiered_factory.cpp | 10 +-- 10 files changed, 144 insertions(+), 34 deletions(-) create mode 100644 src/VecSim/index_factories/raft_ivf_tiered_factory.cu create mode 100644 src/VecSim/index_factories/raft_ivf_tiered_factory.h diff --git a/src/VecSim/CMakeLists.txt b/src/VecSim/CMakeLists.txt index 6bd0014d1..323121d2d 100644 --- a/src/VecSim/CMakeLists.txt +++ b/src/VecSim/CMakeLists.txt @@ -20,6 +20,7 @@ add_library(VectorSimilarity ${VECSIM_LIBTYPE} index_factories/tiered_factory.cpp index_factories/index_factory.cpp $<$:index_factories/raft_ivf_factory.cu> + $<$:index_factories/raft_ivf_tiered_factory.cu> algorithms/hnsw/visited_nodes_handler.cpp vec_sim.cpp vec_sim_interface.cpp diff --git a/src/VecSim/algorithms/brute_force/brute_force.h b/src/VecSim/algorithms/brute_force/brute_force.h index 70758843e..c75546e8c 100644 --- a/src/VecSim/algorithms/brute_force/brute_force.h +++ b/src/VecSim/algorithms/brute_force/brute_force.h @@ -35,11 +35,7 @@ class BruteForceIndex : public VecSimIndexAbstract { public: BruteForceIndex(const BFParams *params, const AbstractIndexInitParams &abstractInitParams); - void clear() { - idToLabelMapping.clear(); - vectorBlocks.clear(); - count = idType{}; - } + virtual void clear() = 0; size_t indexSize() const override; size_t indexCapacity() const override; vecsim_stl::vector computeBlockScores(const DataBlock &block, const void *queryBlob, diff --git a/src/VecSim/algorithms/brute_force/brute_force_multi.h b/src/VecSim/algorithms/brute_force/brute_force_multi.h index 086adc13e..c93776096 100644 --- a/src/VecSim/algorithms/brute_force/brute_force_multi.h +++ b/src/VecSim/algorithms/brute_force/brute_force_multi.h @@ -23,6 +23,12 @@ class BruteForceIndex_Multi : public BruteForceIndex { ~BruteForceIndex_Multi() {} + void clear() override { + this->labelToIdsLookup.clear(); + this->idToLabelMapping.clear(); + this->vectorBlocks.clear(); + this->count = idType{}; + } int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override; int deleteVector(labelType labelType) override; int deleteVectorById(labelType label, idType id) override; diff --git a/src/VecSim/algorithms/brute_force/brute_force_single.h b/src/VecSim/algorithms/brute_force/brute_force_single.h index dba740c89..88237c338 100644 --- a/src/VecSim/algorithms/brute_force/brute_force_single.h +++ b/src/VecSim/algorithms/brute_force/brute_force_single.h @@ -21,6 +21,12 @@ class BruteForceIndex_Single : public BruteForceIndex { const AbstractIndexInitParams &abstractInitParams); ~BruteForceIndex_Single(); + void clear() override { + this->labelToIdLookup.clear(); + this->idToLabelMapping.clear(); + this->vectorBlocks.clear(); + this->count = idType{}; + } int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override; int deleteVector(labelType label) override; int deleteVectorById(labelType label, idType id) override; diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index a3b93be14..e8cf4658f 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -84,7 +84,7 @@ private: raft::neighbors::ivf_pq::index_params>; using search_params_t = std::variant; - using internal_idx_t = std::uint32_t; + using internal_idx_t = std::int64_t; using index_flat_t = raft::neighbors::ivf_flat::index; using index_pq_t = raft::neighbors::ivf_pq::index; using ann_index_t = std::variant; @@ -253,7 +253,7 @@ public: return result_list; } - VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, + virtual VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, VecSimQueryParams *queryParams) const override { assert(!"RangeQuery not implemented"); } diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh b/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh index fb3fee39d..12a5e54c1 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh @@ -1,61 +1,62 @@ #pragma once #include -#include "VecSim/algorithms/ivf/ivf.cuh" +#include "VecSim/algorithms/raft_ivf/ivf.cuh" #include "VecSim/vec_sim_tiered_index.h" struct RAFTTransferJob : public AsyncJob { - bool overwrite_allowed{true}; - RAFTTransferJob(std::shared_ptr allocator, bool overwrite_allowed_, + RAFTTransferJob(std::shared_ptr allocator, JobCallback insertCb, VecSimIndex *index_) - : AsyncJob{allocator, RAFT_TRANSFER_JOB, insertCb, index_}, - overwrite_allowed{overwrite_allowed_} {} + : AsyncJob{allocator, RAFT_TRANSFER_JOB, insertCb, index_} + { + } }; template -struct TieredIVFIndex : public VecSimTieredIndex { - auto addVector(const void *blob, labelType label, bool overwrite_allowed) { +struct TieredRaftIVFIndex : public VecSimTieredIndex { + int addVector(const void *blob, labelType label, void *auxiliaryCtx) override { auto frontend_lock = std::scoped_lock(this->flatIndexGuard); - auto result = this->frontendIndex->addVector(blob, label, overwrite_allowed); + auto result = this->frontendIndex->addVector(blob, label); if (this->frontendIndex->indexSize() >= this->flatBufferLimit) { - transferToBackend(overwrite_allowed); + transferToBackend(); } return result; } - auto deleteVector(labelType label) { + int deleteVector(labelType label) override { // TODO(wphicks) // If in flatIndex, delete // If being transferred to backend, wait for transfer // If in backendIndex, delete + return 0; } - auto indexSize() { + size_t indexSize() { auto frontend_lock = std::scoped_lock(this->flatIndexGuard); auto backend_lock = std::scoped_lock(this->mainIndexGuard); return (getBackendIndex().indexSize() + this->frontendIndex.indexSize()); } - auto indexLabelCount() const override { + size_t indexLabelCount() const override { // TODO(wphicks) Count unique labels between both indexes } - auto indexCapacity() const override { + size_t indexCapacity() const override { return (getBackendIndex().indexCapacity() + this->flatBufferLimit); } void increaseCapacity() override { getBackendIndex().increaseCapacity(); } - auto getDistanceFrom(labelType label, const void *blob) { + double getDistanceFrom_Unsafe(labelType label, const void *blob) const override { auto frontend_lock = std::unique_lock(this->flatIndexGuard); - auto flat_dist = this->frontendIndex->getDistanceFrom(label, blob); + auto flat_dist = this->frontendIndex->getDistanceFrom_Unsafe(label, blob); frontend_lock.unlock(); auto backend_lock = std::scoped_lock(this->mainIndexGuard); - auto raft_dist = getBackendIndex().getDistanceFrom(label, blob); + auto raft_dist = getBackendIndex().getDistanceFrom_Unsafe(label, blob); return std::fmin(flat_dist, raft_dist); } - void executeTransferJob(RAFTTransferJob *job) { transferToBackend(job->overwrite_allowed); } + void executeTransferJob(RAFTTransferJob *job) { transferToBackend(); } private: vecsim_stl::unordered_map> labelToTransferJobs; @@ -63,10 +64,14 @@ private: return *dynamic_cast *>(this->backendIndex); } - void transferToBackend(overwrite_allowed = true) { - auto dim = this->index->getDim(); + void transferToBackend() { auto frontend_lock = std::unique_lock(this->flatIndexGuard); auto nVectors = this->flatBuffer->indexSize(); + if (nVectors == 0) { + frontend_lock.unlock(); + return; + } + auto dim = this->index->getDim(); const auto &vectorBlocks = this->flatBuffer->getVectorBlocks(); auto vectorData = raft::make_host_matrix(getBackendIndex().get_resources(), nVectors, dim); @@ -82,6 +87,6 @@ private: this->flatBuffer->clear(); frontend_lock.unlock(); getBackendIndex().addVectorBatch(vectorData.data_handle(), this->flatBuffer->getLabels(), - nVectors, overwrite_allowed); + nVectors); } }; diff --git a/src/VecSim/index_factories/raft_ivf_factory.cu b/src/VecSim/index_factories/raft_ivf_factory.cu index 6a7e64271..355afebb4 100644 --- a/src/VecSim/index_factories/raft_ivf_factory.cu +++ b/src/VecSim/index_factories/raft_ivf_factory.cu @@ -1,5 +1,6 @@ #include "VecSim/index_factories/brute_force_factory.h" #include "VecSim/algorithms/raft_ivf/ivf.cuh" +#include namespace RaftIvfFactory { @@ -42,7 +43,6 @@ VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams) { } size_t EstimateInitialSize(const RaftIvfParams *raftIvfParams) { - size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); // Constant part (not effected by parameters). @@ -65,7 +65,18 @@ size_t EstimateInitialSize(const RaftIvfParams *raftIvfParams) { } size_t EstimateElementSize(const RaftIvfParams *raftIvfParams) { - // Vectors are stored on the GPU. - return 0; + // Those elements are stored only on GPU. + size_t est = 0; + if (!raftIvfParams->usePQ) { + // Size of vec + size of label. + est += raftIvfParams->dim * VecSimType_sizeof(raftIvfParams->type) + sizeof(labelType); + } else { + size_t pq_dim = raftIvfParams->pqDim; + if (pq_dim == 0) + pq_dim = raft::neighbors::ivf_pq::calculate_pq_dim(raftIvfParams->dim); + // Size of vec after compression + size of label + est += raftIvfParams->pqBits * pq_dim + sizeof(labelType); + } + return est; } }; // namespace RaftIvfFactory diff --git a/src/VecSim/index_factories/raft_ivf_tiered_factory.cu b/src/VecSim/index_factories/raft_ivf_tiered_factory.cu new file mode 100644 index 000000000..97f246f97 --- /dev/null +++ b/src/VecSim/index_factories/raft_ivf_tiered_factory.cu @@ -0,0 +1,64 @@ +#include "VecSim/index_factories/brute_force_factory.h" +#include "VecSim/algorithms/raft_ivf/ivf_tiered.cuh" +#include "VecSim/index_factories/tiered_factory.h" +#include "VecSim/index_factories/raft_ivf_factory.h" + +namespace TieredRaftIvfFactory { + +template +VecSimIndex *NewIndex(const TieredIndexParams *params) +{ + // initialize raft index + auto *raft_index = reinterpret_cast *>( + RaftIVFFactory::NewIndex(params->primaryIndexParams)); + // initialize brute force index + BFParams bf_params = { + .type = params->primaryIndexParams->algoParams.raftIvfParams.type, + .dim = params->primaryIndexParams->algoParams.raftIvfParams.dim, + .metric = params->primaryIndexParams->algoParams.raftIvfParams.metric, + .multi = params->primaryIndexParams->algoParams.raftIvfParams.multi, + .blockSize = params->primaryIndexParams->algoParams.raftIvfParams.blockSize}; + + std::shared_ptr flat_allocator = VecSimAllocator::newVecsimAllocator(); + AbstractIndexInitParams abstractInitParams = {.allocator = flat_allocator, + .dim = bf_params.dim, + .vecType = bf_params.type, + .metric = bf_params.metric, + .blockSize = bf_params.blockSize, + .multi = bf_params.multi, + .logCtx = params->primaryIndexParams->logCtx}; + auto frontendIndex = static_cast *>( + BruteForceFactory::NewIndex(&bf_params, abstractInitParams)); + + // Create new tiered hnsw index + std::shared_ptr management_layer_allocator = + VecSimAllocator::newVecsimAllocator(); + + return new (management_layer_allocator) TieredRaftIVFIndex( + raft_index, frontendIndex, *params, management_layer_allocator); +} + +// The size estimation is the sum of the buffer (brute force) and main index initial sizes +// estimations, plus the tiered index class size. Note it does not include the size of internal +// containers such as the job queue, as those depend on the user implementation. +inline size_t EstimateInitialSize(const TieredIndexParams *params) { + auto raft_ivf_params = params->primaryIndexParams->algoParams.raftIvfParams; + + // Add size estimation of VecSimTieredIndex sub indexes. + size_t est = RaftIvfFactory::EstimateInitialSize(&raft_ivf_params); + + // Management layer allocator overhead. + size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); + est += sizeof(VecSimAllocator) + allocations_overhead; + + // Size of the TieredRaftIVFIndex struct. + if (raft_ivf_params.type == VecSimType_FLOAT32) { + est += sizeof(TieredRaftIVFIndex); + } else if (raft_ivf_params.type == VecSimType_FLOAT64) { + est += sizeof(TieredRaftIVFIndex); + } + + return est; +} + +}; // namespace TieredRaftIvfFactory diff --git a/src/VecSim/index_factories/raft_ivf_tiered_factory.h b/src/VecSim/index_factories/raft_ivf_tiered_factory.h new file mode 100644 index 000000000..e38cdd204 --- /dev/null +++ b/src/VecSim/index_factories/raft_ivf_tiered_factory.h @@ -0,0 +1,19 @@ +#pragma once + +#include "VecSim/vec_sim.h" //typedef VecSimIndex +#include "VecSim/vec_sim_common.h" // RaftIvfParams +#include "VecSim/memory/vecsim_malloc.h" // VecSimAllocator +#include "VecSim/vec_sim_index.h" + +namespace TieredRaftIvfFactory { + +template +VecSimIndex *NewIndex(const TieredIndexParams *params); + +// The size estimation is the sum of the buffer (brute force) and main index initial sizes +// estimations, plus the tiered index class size. Note it does not include the size of internal +// containers such as the job queue, as those depend on the user implementation. +size_t EstimateInitialSize(const TieredIndexParams *params); +size_t EstimateElementSize(const TieredIndexParams *params); + +}; // namespace TieredRaftIvfFactory diff --git a/src/VecSim/index_factories/tiered_factory.cpp b/src/VecSim/index_factories/tiered_factory.cpp index 3e9a56253..76b922867 100644 --- a/src/VecSim/index_factories/tiered_factory.cpp +++ b/src/VecSim/index_factories/tiered_factory.cpp @@ -7,6 +7,8 @@ #include "VecSim/index_factories/tiered_factory.h" #include "VecSim/index_factories/hnsw_factory.h" #include "VecSim/index_factories/brute_force_factory.h" +#include "VecSim/index_factories/raft_ivf_tiered_factory.h" +#include "VecSim/index_factories/raft_ivf_factory.h" #include "VecSim/algorithms/hnsw/hnsw_tiered.h" @@ -89,7 +91,7 @@ VecSimIndex *NewIndex(const TieredIndexParams *params) { } else if (type == VecSimType_FLOAT64) { return TieredHNSWFactory::NewIndex(params); } - } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { + } else if (params->primaryIndexParams->algo == VecSimAlgo_RaftIVF) { VecSimType type = params->primaryIndexParams->algoParams.raftIvfParams.type; if (type == VecSimType_FLOAT32) { return TieredRaftIvfFactory::NewIndex(params); @@ -106,8 +108,8 @@ size_t EstimateInitialSize(const TieredIndexParams *params) { BFParams bf_params{}; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est += TieredHNSWFactory::EstimateInitialSize(params, bf_params); - } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { - est += TieredRaftIvfFactory::EstimateInitialSize(params, bf_params); + } else if (params->primaryIndexParams->algo == VecSimAlgo_RaftIVF) { + est += TieredRaftIvfFactory::EstimateInitialSize(params); } est += BruteForceFactory::EstimateInitialSize(&bf_params); @@ -118,7 +120,7 @@ size_t EstimateElementSize(const TieredIndexParams *params) { size_t est = 0; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est = HNSWFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.hnswParams); - } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { + } else if (params->primaryIndexParams->algo == VecSimAlgo_RaftIVF) { est = RaftIvfFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.raftIvfParams); } return est; From f3cd7b1df1735c15b5044cf0dfcae3fa158f4b55 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 20 Oct 2023 14:58:55 +0200 Subject: [PATCH 12/28] Add CUDA_ARCHITECTURE for half type --- CMakeLists.txt | 2 ++ src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh | 4 +--- src/VecSim/index_factories/raft_ivf_factory.cu | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0dd92f683..ab7e7bc6e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,6 +30,8 @@ include(cmake/san.cmake) project(VectorSimilarity) if (USE_CUDA) + # List of architectures to generate device code + set(CMAKE_CUDA_ARCHITECTURES 80;75;70;61) # Enable CUDA compilation for this project enable_language(CUDA) # Add definition for conditional compilation of CUDA components diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh b/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh index 12a5e54c1..7f8e48e57 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh @@ -31,7 +31,7 @@ struct TieredRaftIVFIndex : public VecSimTieredIndex { return 0; } - size_t indexSize() { + size_t indexSize() const override { auto frontend_lock = std::scoped_lock(this->flatIndexGuard); auto backend_lock = std::scoped_lock(this->mainIndexGuard); return (getBackendIndex().indexSize() + this->frontendIndex.indexSize()); @@ -45,8 +45,6 @@ struct TieredRaftIVFIndex : public VecSimTieredIndex { return (getBackendIndex().indexCapacity() + this->flatBufferLimit); } - void increaseCapacity() override { getBackendIndex().increaseCapacity(); } - double getDistanceFrom_Unsafe(labelType label, const void *blob) const override { auto frontend_lock = std::unique_lock(this->flatIndexGuard); auto flat_dist = this->frontendIndex->getDistanceFrom_Unsafe(label, blob); diff --git a/src/VecSim/index_factories/raft_ivf_factory.cu b/src/VecSim/index_factories/raft_ivf_factory.cu index 355afebb4..6a73104bd 100644 --- a/src/VecSim/index_factories/raft_ivf_factory.cu +++ b/src/VecSim/index_factories/raft_ivf_factory.cu @@ -72,8 +72,8 @@ size_t EstimateElementSize(const RaftIvfParams *raftIvfParams) { est += raftIvfParams->dim * VecSimType_sizeof(raftIvfParams->type) + sizeof(labelType); } else { size_t pq_dim = raftIvfParams->pqDim; - if (pq_dim == 0) - pq_dim = raft::neighbors::ivf_pq::calculate_pq_dim(raftIvfParams->dim); + if (pq_dim == 0) // Estimation. + pq_dim = raftIvfParams->dim >= 128 ? raftIvfParams->dim / 2 : raftIvfParams->dim; // Size of vec after compression + size of label est += raftIvfParams->pqBits * pq_dim + sizeof(labelType); } From 28765de49292132469590c4a744bf78ee6a8ee92 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 2 Nov 2023 16:13:14 +0100 Subject: [PATCH 13/28] Add Tiered index update and test --- CMakeLists.txt | 2 +- cmake/raft.cmake | 2 +- .../algorithms/brute_force/brute_force.h | 1 + src/VecSim/algorithms/raft_ivf/ivf.cuh | 25 ++-- src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh | 141 ++++++++++++++---- src/VecSim/index_factories/index_factory.cpp | 6 +- .../index_factories/raft_ivf_factory.cu | 10 +- .../raft_ivf_tiered_factory.cu | 26 ++-- .../index_factories/raft_ivf_tiered_factory.h | 2 - src/VecSim/index_factories/tiered_factory.cpp | 15 +- src/VecSim/utils/vec_utils.cpp | 2 +- src/VecSim/vec_sim_common.h | 8 +- src/VecSim/vec_sim_tiered_index.h | 2 +- tests/unit/CMakeLists.txt | 9 ++ tests/unit/test_raft_ivf_tiered.cpp | 132 ++++++++++++++++ 15 files changed, 303 insertions(+), 80 deletions(-) create mode 100644 tests/unit/test_raft_ivf_tiered.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index ab7e7bc6e..b04480539 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,7 +31,7 @@ project(VectorSimilarity) if (USE_CUDA) # List of architectures to generate device code - set(CMAKE_CUDA_ARCHITECTURES 80;75;70;61) + set(CMAKE_CUDA_ARCHITECTURES "native") # Enable CUDA compilation for this project enable_language(CUDA) # Add definition for conditional compilation of CUDA components diff --git a/cmake/raft.cmake b/cmake/raft.cmake index 46181397f..2c9ac5509 100644 --- a/cmake/raft.cmake +++ b/cmake/raft.cmake @@ -1,6 +1,6 @@ if(USE_CUDA) # Set which version of RAPIDS to use - set(RAPIDS_VERSION 23.10) + set(RAPIDS_VERSION 23.12) # Set which version of RAFT to use (defined separately for testing # minimal dependency changes if necessary) set(RAFT_VERSION "${RAPIDS_VERSION}") diff --git a/src/VecSim/algorithms/brute_force/brute_force.h b/src/VecSim/algorithms/brute_force/brute_force.h index c75546e8c..3e223320d 100644 --- a/src/VecSim/algorithms/brute_force/brute_force.h +++ b/src/VecSim/algorithms/brute_force/brute_force.h @@ -55,6 +55,7 @@ class BruteForceIndex : public VecSimIndexAbstract { VecSimQueryParams *queryParams) const override; bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const override; inline labelType getVectorLabel(idType id) const { return idToLabelMapping.at(id); } + inline vecsim_stl::vector getLabels() const { return idToLabelMapping; } inline const vecsim_stl::vector &getVectorBlocks() const { return vectorBlocks; } inline const labelType getLabelByInternalId(idType internal_id) const { diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index e8cf4658f..a56138e2d 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -40,13 +40,13 @@ inline auto constexpr GetRaftDistanceType(VecSimMetric vsm) { return result; } -inline auto constexpr GetRaftCodebookKind(IVFPQCodebookKind vss_codebook) { +inline auto constexpr GetRaftCodebookKind(RaftIVFPQCodebookKind vss_codebook) { auto result = raft::neighbors::ivf_pq::codebook_gen{}; switch (vss_codebook) { - case IVFPQCodebookKind_PerCluster: + case RaftIVFPQCodebookKind_PerCluster: result = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; break; - case IVFPQCodebookKind_PerSubspace: + case RaftIVFPQCodebookKind_PerSubspace: result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; break; default: @@ -74,7 +74,7 @@ inline auto constexpr GetCudaType(CudaType vss_type) { } template -struct RaftIVFIndex : public VecSimIndexAbstract { +struct RaftIvfIndex : public VecSimIndexAbstract { using data_type = DataType; using dist_type = DistType; @@ -90,7 +90,7 @@ private: using ann_index_t = std::variant; public: - RaftIVFIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &commonParams) + RaftIvfIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &commonParams) : VecSimIndexAbstract{commonParams}, res_{raft::device_resources_manager::get_device_resources()}, build_params_{raftIvfParams->usePQ ? build_params_t{std::in_place_index<1>} @@ -103,6 +103,7 @@ public: inner.metric = GetRaftDistanceType(raftIvfParams->metric); inner.n_lists = raftIvfParams->nLists; inner.kmeans_n_iters = raftIvfParams->kmeans_nIters; + inner.add_data_on_build = false; inner.kmeans_trainset_fraction = raftIvfParams->kmeans_trainsetFraction; inner.conservative_memory_allocation = raftIvfParams->conservativeMemoryAllocation; if constexpr (std::is_same_v(*index_)); + &std::get(*index_)); } else { if (!index_) { index_ = raft::neighbors::ivf_pq::build( @@ -169,7 +170,7 @@ public: raft::neighbors::ivf_pq::extend( res_, raft::make_const_mdspan(vector_data_gpu.view()), std::make_optional(raft::make_const_mdspan(label_gpu.view())), - std::get(*index_)); + &std::get(*index_)); } return batch_size; @@ -237,8 +238,8 @@ public: auto neighbors = vecsim_stl::vector(k, this->allocator); auto distances = vecsim_stl::vector(k, this->allocator); // Copy data back from device to host - raft::copy(neighbors.data(), neighbors_gpu.data_handle(), this->dim, res_.get_stream()); - raft::copy(distances.data(), distances_gpu.data_handle(), this->dim, res_.get_stream()); + raft::copy(neighbors.data(), neighbors_gpu.data_handle(), k, res_.get_stream()); + raft::copy(distances.data(), distances_gpu.data_handle(), k, res_.get_stream()); // Ensure search is complete and data have been copied back before // building query result objects on host @@ -268,8 +269,8 @@ public: auto &get_resources() const { return res_; } - auto nLists() { - return std::visit([](auto &¶ms) { return params.n_list; }, build_params_); + auto nLists() const { + return std::visit([](auto &¶ms) { return params.n_lists; }, build_params_); } size_t indexSize() const override { @@ -281,7 +282,7 @@ public: } VecSimIndexBasicInfo basicInfo() const override { VecSimIndexBasicInfo info = this->getBasicInfo(); - info.algo = VecSimAlgo_RaftIVF; + info.algo = VecSimAlgo_RAFTIVF; info.isTiered = false; return info; } diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh b/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh index 7f8e48e57..0df6e1cb9 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh @@ -13,28 +13,65 @@ struct RAFTTransferJob : public AsyncJob { }; template -struct TieredRaftIVFIndex : public VecSimTieredIndex { +struct TieredRaftIvfIndex : public VecSimTieredIndex { + TieredRaftIvfIndex(RaftIvfIndex* raftIvfIndex, + BruteForceIndex *bf_index, + const TieredIndexParams &tieredParams, + std::shared_ptr allocator) + : VecSimTieredIndex(raftIvfIndex, bf_index, tieredParams, allocator) + { + assert(raftIvfIndex->nLists() < this->flatBufferLimit && + "The flat buffer limit must be greater than the number of lists in the backend index"); + } + ~TieredRaftIvfIndex() { + // Delete all the pending jobs + } + int addVector(const void *blob, labelType label, void *auxiliaryCtx) override { - auto frontend_lock = std::scoped_lock(this->flatIndexGuard); - auto result = this->frontendIndex->addVector(blob, label); + int ret = 1; + // If the flat index is full, write to the backend index if (this->frontendIndex->indexSize() >= this->flatBufferLimit) { - transferToBackend(); + // If the backend index is empty, build it with all the vectors + // Otherwise, just add the vector to the backend index + if (this->backendIndex->indexSize() == 0) { + executeTransferJob(); + } else { + this->mainIndexGuard.lock(); + ret = this->backendIndex->addVector(blob, label); + this->mainIndexGuard.unlock(); + return ret; + } } - return result; + + // Add the vector to the flat index + this->flatIndexGuard.lock(); + ret = this->frontendIndex->addVector(blob, label); + this->flatIndexGuard.unlock(); + + // Submit a transfer job + AsyncJob *new_insert_job = new (this->allocator) + RAFTTransferJob(this->allocator, executeTransferJobWrapper, this); + this->submitSingleJob(new_insert_job); + return ret; } int deleteVector(labelType label) override { - // TODO(wphicks) - // If in flatIndex, delete - // If being transferred to backend, wait for transfer - // If in backendIndex, delete - return 0; + this->flatIndexGuard.lock(); + auto result = this->frontendIndex->deleteVector(label); + this->mainIndexGuard.lock(); + this->flatIndexGuard.unlock(); + result += this->backendIndex->deleteVector(label); + this->mainIndexGuard.unlock(); + return result; } size_t indexSize() const override { - auto frontend_lock = std::scoped_lock(this->flatIndexGuard); - auto backend_lock = std::scoped_lock(this->mainIndexGuard); - return (getBackendIndex().indexSize() + this->frontendIndex.indexSize()); + this->flatIndexGuard.lock_shared(); + this->mainIndexGuard.lock_shared(); + size_t result = (this->backendIndex->indexSize() + this->frontendIndex->indexSize()); + this->flatIndexGuard.unlock_shared(); + this->mainIndexGuard.unlock_shared(); + return result; } size_t indexLabelCount() const override { @@ -42,49 +79,87 @@ struct TieredRaftIVFIndex : public VecSimTieredIndex { } size_t indexCapacity() const override { - return (getBackendIndex().indexCapacity() + this->flatBufferLimit); + return (this->backendIndex->indexCapacity() + this->frontendIndex->indexCapacity()); } double getDistanceFrom_Unsafe(labelType label, const void *blob) const override { - auto frontend_lock = std::unique_lock(this->flatIndexGuard); auto flat_dist = this->frontendIndex->getDistanceFrom_Unsafe(label, blob); - frontend_lock.unlock(); - auto backend_lock = std::scoped_lock(this->mainIndexGuard); auto raft_dist = getBackendIndex().getDistanceFrom_Unsafe(label, blob); return std::fmin(flat_dist, raft_dist); } - void executeTransferJob(RAFTTransferJob *job) { transferToBackend(); } + static void executeTransferJobWrapper(AsyncJob *job) { + if (job->isValid) { + auto *transfer_job = reinterpret_cast(job); + auto *job_index = reinterpret_cast *>(transfer_job->index); + job_index->executeTransferJob(); + } + delete job; + } + + VecSimIndexBasicInfo basicInfo() const override{} + + VecSimBatchIterator *newBatchIterator(const void *queryBlob, + VecSimQueryParams *queryParams) const override {} + + inline void setLastSearchMode(VecSearchMode mode) override {} + + void runGC() override {} + + void acquireSharedLocks() override { + this->flatIndexGuard.lock_shared(); + this->mainIndexGuard.lock_shared(); + } + + void releaseSharedLocks() override { + this->flatIndexGuard.unlock_shared(); + this->mainIndexGuard.unlock_shared(); + } private: - vecsim_stl::unordered_map> labelToTransferJobs; - auto &getBackendIndex() { - return *dynamic_cast *>(this->backendIndex); + + inline auto &getBackendIndex() const { + return *dynamic_cast *>(this->backendIndex); } - void transferToBackend() { + void executeTransferJob() { auto frontend_lock = std::unique_lock(this->flatIndexGuard); - auto nVectors = this->flatBuffer->indexSize(); + auto nVectors = this->frontendIndex->indexSize(); + // No vectors to transfer if (nVectors == 0) { frontend_lock.unlock(); return; } - auto dim = this->index->getDim(); - const auto &vectorBlocks = this->flatBuffer->getVectorBlocks(); - auto vectorData = raft::make_host_matrix(getBackendIndex().get_resources(), nVectors, dim); - auto *out = vectorData.data_handle(); + // If the backend index is empty, don't transfer less than nLists vectors + this->mainIndexGuard.lock_shared(); + auto main_nVectors = this->backendIndex->indexSize(); + this->mainIndexGuard.unlock_shared(); + if (main_nVectors == 0) { + if (nVectors < getBackendIndex().nLists()) { + frontend_lock.unlock(); + return; + } + } + auto dim = this->backendIndex->getDim(); + const auto &vectorBlocks = this->frontendIndex->getVectorBlocks(); + auto* vectorData = (DataType *)this->allocator->allocate(nVectors * dim * sizeof (DataType)); + + // Transfer vectors to a contiguous buffer + auto *curr_ptr = vectorData; for (auto block_id = 0; block_id < vectorBlocks.size(); ++block_id) { - auto *in_begin = reinterpret_cast(vectorBlocks[block_id].getElement(0)); + const auto *in_begin = reinterpret_cast(vectorBlocks[block_id].getElement(0)); auto length = vectorBlocks[block_id].getLength(); - std::copy(in_begin, in_begin + length, out); - out += length; + std::copy(in_begin, in_begin + (length * dim), curr_ptr); + curr_ptr += length * dim; } + // Add the vectors to the backend index auto backend_lock = std::scoped_lock(this->mainIndexGuard); - this->flatBuffer->clear(); - frontend_lock.unlock(); - getBackendIndex().addVectorBatch(vectorData.data_handle(), this->flatBuffer->getLabels(), + getBackendIndex().addVectorBatch(vectorData, this->frontendIndex->getLabels().data(), nVectors); + this->frontendIndex->clear(); + frontend_lock.unlock(); + this->allocator->free_allocation(vectorData); } }; diff --git a/src/VecSim/index_factories/index_factory.cpp b/src/VecSim/index_factories/index_factory.cpp index f2d419c63..4f34e2eaf 100644 --- a/src/VecSim/index_factories/index_factory.cpp +++ b/src/VecSim/index_factories/index_factory.cpp @@ -26,7 +26,7 @@ VecSimIndex *NewIndex(const VecSimParams *params) { index = BruteForceFactory::NewIndex(params); break; } - case VecSimAlgo_RaftIVF: { + case VecSimAlgo_RAFTIVF: { index = RaftIvfFactory::NewIndex(¶ms->algoParams.raftIvfParams); break; } @@ -47,7 +47,7 @@ size_t EstimateInitialSize(const VecSimParams *params) { return HNSWFactory::EstimateInitialSize(¶ms->algoParams.hnswParams); case VecSimAlgo_BF: return BruteForceFactory::EstimateInitialSize(¶ms->algoParams.bfParams); - case VecSimAlgo_RaftIVF: + case VecSimAlgo_RAFTIVF: return RaftIvfFactory::EstimateInitialSize(¶ms->algoParams.raftIvfParams); case VecSimAlgo_TIERED: return TieredFactory::EstimateInitialSize(¶ms->algoParams.tieredParams); @@ -61,7 +61,7 @@ size_t EstimateElementSize(const VecSimParams *params) { return HNSWFactory::EstimateElementSize(¶ms->algoParams.hnswParams); case VecSimAlgo_BF: return BruteForceFactory::EstimateElementSize(¶ms->algoParams.bfParams); - case VecSimAlgo_RaftIVF: + case VecSimAlgo_RAFTIVF: return RaftIvfFactory::EstimateElementSize(¶ms->algoParams.raftIvfParams); case VecSimAlgo_TIERED: return TieredFactory::EstimateElementSize(¶ms->algoParams.tieredParams); diff --git a/src/VecSim/index_factories/raft_ivf_factory.cu b/src/VecSim/index_factories/raft_ivf_factory.cu index 6a73104bd..d1aa7705a 100644 --- a/src/VecSim/index_factories/raft_ivf_factory.cu +++ b/src/VecSim/index_factories/raft_ivf_factory.cu @@ -19,13 +19,11 @@ static AbstractIndexInitParams NewAbstractInitParams(const VecSimParams *params) } VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &abstractInitParams) { + assert(raftIvfParams->type == VecSimType_FLOAT32 && "Invalid IVF data type algorithm"); if (raftIvfParams->type == VecSimType_FLOAT32) { return new (abstractInitParams.allocator) - RaftIVFIndex(raftIvfParams, abstractInitParams); - } else if (raftIvfParams->type == VecSimType_FLOAT64) { - return new (abstractInitParams.allocator) - RaftIVFIndex(raftIvfParams, abstractInitParams); - } + RaftIvfIndex(raftIvfParams, abstractInitParams); + } // If we got here something is wrong. return NULL; @@ -47,7 +45,7 @@ size_t EstimateInitialSize(const RaftIvfParams *raftIvfParams) { // Constant part (not effected by parameters). size_t est = sizeof(VecSimAllocator) + allocations_overhead; - est += sizeof(RaftIVFIndex); // Object size + est += sizeof(RaftIvfIndex); // Object size if (!raftIvfParams->usePQ) { // Size of each cluster data est += raftIvfParams->nLists * sizeof(raft::neighbors::ivf_flat::list_data); diff --git a/src/VecSim/index_factories/raft_ivf_tiered_factory.cu b/src/VecSim/index_factories/raft_ivf_tiered_factory.cu index 97f246f97..dc75d7d3e 100644 --- a/src/VecSim/index_factories/raft_ivf_tiered_factory.cu +++ b/src/VecSim/index_factories/raft_ivf_tiered_factory.cu @@ -1,23 +1,29 @@ #include "VecSim/index_factories/brute_force_factory.h" #include "VecSim/algorithms/raft_ivf/ivf_tiered.cuh" +#include "VecSim/algorithms/raft_ivf/ivf.cuh" #include "VecSim/index_factories/tiered_factory.h" #include "VecSim/index_factories/raft_ivf_factory.h" namespace TieredRaftIvfFactory { -template VecSimIndex *NewIndex(const TieredIndexParams *params) { + assert(params->primaryIndexParams->algoParams.raftIvfParams.type == VecSimType_FLOAT32 && + "Invalid IVF data type algorithm"); + + using DataType = float; + using DistType = float; // initialize raft index - auto *raft_index = reinterpret_cast *>( - RaftIVFFactory::NewIndex(params->primaryIndexParams)); + auto *raft_index = reinterpret_cast *>( + RaftIvfFactory::NewIndex(params->primaryIndexParams)); // initialize brute force index BFParams bf_params = { .type = params->primaryIndexParams->algoParams.raftIvfParams.type, .dim = params->primaryIndexParams->algoParams.raftIvfParams.dim, .metric = params->primaryIndexParams->algoParams.raftIvfParams.metric, .multi = params->primaryIndexParams->algoParams.raftIvfParams.multi, - .blockSize = params->primaryIndexParams->algoParams.raftIvfParams.blockSize}; + //.blockSize = params->primaryIndexParams->algoParams.raftIvfParams.blockSize + }; std::shared_ptr flat_allocator = VecSimAllocator::newVecsimAllocator(); AbstractIndexInitParams abstractInitParams = {.allocator = flat_allocator, @@ -30,18 +36,18 @@ VecSimIndex *NewIndex(const TieredIndexParams *params) auto frontendIndex = static_cast *>( BruteForceFactory::NewIndex(&bf_params, abstractInitParams)); - // Create new tiered hnsw index + // Create new tiered RaftIVF index std::shared_ptr management_layer_allocator = VecSimAllocator::newVecsimAllocator(); - return new (management_layer_allocator) TieredRaftIVFIndex( + return new (management_layer_allocator) TieredRaftIvfIndex( raft_index, frontendIndex, *params, management_layer_allocator); } // The size estimation is the sum of the buffer (brute force) and main index initial sizes // estimations, plus the tiered index class size. Note it does not include the size of internal // containers such as the job queue, as those depend on the user implementation. -inline size_t EstimateInitialSize(const TieredIndexParams *params) { +size_t EstimateInitialSize(const TieredIndexParams *params) { auto raft_ivf_params = params->primaryIndexParams->algoParams.raftIvfParams; // Add size estimation of VecSimTieredIndex sub indexes. @@ -51,11 +57,11 @@ inline size_t EstimateInitialSize(const TieredIndexParams *params) { size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); est += sizeof(VecSimAllocator) + allocations_overhead; - // Size of the TieredRaftIVFIndex struct. + // Size of the TieredRaftIvfIndex struct. if (raft_ivf_params.type == VecSimType_FLOAT32) { - est += sizeof(TieredRaftIVFIndex); + est += sizeof(TieredRaftIvfIndex); } else if (raft_ivf_params.type == VecSimType_FLOAT64) { - est += sizeof(TieredRaftIVFIndex); + est += sizeof(TieredRaftIvfIndex); } return est; diff --git a/src/VecSim/index_factories/raft_ivf_tiered_factory.h b/src/VecSim/index_factories/raft_ivf_tiered_factory.h index e38cdd204..a89486531 100644 --- a/src/VecSim/index_factories/raft_ivf_tiered_factory.h +++ b/src/VecSim/index_factories/raft_ivf_tiered_factory.h @@ -7,13 +7,11 @@ namespace TieredRaftIvfFactory { -template VecSimIndex *NewIndex(const TieredIndexParams *params); // The size estimation is the sum of the buffer (brute force) and main index initial sizes // estimations, plus the tiered index class size. Note it does not include the size of internal // containers such as the job queue, as those depend on the user implementation. size_t EstimateInitialSize(const TieredIndexParams *params); -size_t EstimateElementSize(const TieredIndexParams *params); }; // namespace TieredRaftIvfFactory diff --git a/src/VecSim/index_factories/tiered_factory.cpp b/src/VecSim/index_factories/tiered_factory.cpp index 76b922867..f83b78103 100644 --- a/src/VecSim/index_factories/tiered_factory.cpp +++ b/src/VecSim/index_factories/tiered_factory.cpp @@ -7,8 +7,11 @@ #include "VecSim/index_factories/tiered_factory.h" #include "VecSim/index_factories/hnsw_factory.h" #include "VecSim/index_factories/brute_force_factory.h" + +#ifdef USE_CUDA #include "VecSim/index_factories/raft_ivf_tiered_factory.h" #include "VecSim/index_factories/raft_ivf_factory.h" +#endif #include "VecSim/algorithms/hnsw/hnsw_tiered.h" @@ -91,13 +94,13 @@ VecSimIndex *NewIndex(const TieredIndexParams *params) { } else if (type == VecSimType_FLOAT64) { return TieredHNSWFactory::NewIndex(params); } - } else if (params->primaryIndexParams->algo == VecSimAlgo_RaftIVF) { + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { VecSimType type = params->primaryIndexParams->algoParams.raftIvfParams.type; if (type == VecSimType_FLOAT32) { - return TieredRaftIvfFactory::NewIndex(params); - } else if (type == VecSimType_FLOAT64) { + return TieredRaftIvfFactory::NewIndex(params); + }/* else if (type == VecSimType_FLOAT64) { return TieredRaftIvfFactory::NewIndex(params); - } + }*/ } return nullptr; // Invalid algorithm or type. } @@ -108,7 +111,7 @@ size_t EstimateInitialSize(const TieredIndexParams *params) { BFParams bf_params{}; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est += TieredHNSWFactory::EstimateInitialSize(params, bf_params); - } else if (params->primaryIndexParams->algo == VecSimAlgo_RaftIVF) { + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { est += TieredRaftIvfFactory::EstimateInitialSize(params); } @@ -120,7 +123,7 @@ size_t EstimateElementSize(const TieredIndexParams *params) { size_t est = 0; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est = HNSWFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.hnswParams); - } else if (params->primaryIndexParams->algo == VecSimAlgo_RaftIVF) { + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { est = RaftIvfFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.raftIvfParams); } return est; diff --git a/src/VecSim/utils/vec_utils.cpp b/src/VecSim/utils/vec_utils.cpp index 27cae0e06..4f5764724 100644 --- a/src/VecSim/utils/vec_utils.cpp +++ b/src/VecSim/utils/vec_utils.cpp @@ -126,7 +126,7 @@ const char *VecSimAlgo_ToString(VecSimAlgo vecsimAlgo) { return VecSimCommonStrings::FLAT_STRING; case VecSimAlgo_HNSWLIB: return VecSimCommonStrings::HNSW_STRING; - case VecSimAlgo_RaftIVF: + case VecSimAlgo_RAFTIVF: return VecSimCommonStrings::RAFTIVF_STRING; case VecSimAlgo_TIERED: return VecSimCommonStrings::TIERED_STRING; diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index db4a00779..a4eafb9ce 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -38,13 +38,13 @@ typedef enum { } VecSimType; // Algorithm type/library. -typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_RaftIVF, VecSimAlgo_TIERED } VecSimAlgo; +typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_RAFTIVF, VecSimAlgo_TIERED } VecSimAlgo; // Distance metric typedef enum { VecSimMetric_L2, VecSimMetric_IP, VecSimMetric_Cosine } VecSimMetric; // Codebook kind for IVFPQ indexes -typedef enum { IVFPQCodebookKind_PerCluster, IVFPQCodebookKind_PerSubspace } IVFPQCodebookKind; +typedef enum { RaftIVFPQCodebookKind_PerCluster, RaftIVFPQCodebookKind_PerSubspace } RaftIVFPQCodebookKind; // CUDA types supported by GPU-accelerated indexes typedef enum { CUDAType_R_32F, CUDAType_R_16F, CUDAType_R_8U } CudaType; @@ -159,7 +159,7 @@ typedef struct { // compression. If set to 0, a heuristic will be used to // select the dimensionality. - IVFPQCodebookKind codebookKind; + RaftIVFPQCodebookKind codebookKind; CudaType lutType; CudaType internalDistanceType; double preferredShmemCarveout; // Fraction of GPU's unified memory / L1 @@ -309,7 +309,7 @@ typedef struct { union { bfInfoStruct bfInfo; hnswInfoStruct hnswInfo; - raftIvfInfoStruct raftInfo; + raftIvfInfoStruct raftIvfInfo; tieredInfoStruct tieredInfo; }; } VecSimIndexInfo; diff --git a/src/VecSim/vec_sim_tiered_index.h b/src/VecSim/vec_sim_tiered_index.h index 768bf02b8..6318bf8db 100644 --- a/src/VecSim/vec_sim_tiered_index.h +++ b/src/VecSim/vec_sim_tiered_index.h @@ -284,7 +284,7 @@ VecSimIndexInfo VecSimTieredIndex::info() const { case VecSimAlgo_HNSWLIB: info.tieredInfo.backendInfo.hnswInfo = backendInfo.hnswInfo; break; - case VecSimAlgo_RaftIVF: // TODO Add RaftIVF info + case VecSimAlgo_RAFTIVF: // TODO Add RaftIVF info case VecSimAlgo_BF: case VecSimAlgo_TIERED: assert(false && "Invalid backend algorithm"); diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index e3cb19bdb..0bd5f363b 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -35,6 +35,11 @@ add_executable(test_allocator test_allocator.cpp test_utils.cpp) add_executable(test_spaces test_spaces.cpp) add_executable(test_common ../utils/mock_thread_pool.cpp test_utils.cpp test_common.cpp) +if(USE_CUDA) + add_executable(test_raftivf ../utils/mock_thread_pool.cpp test_raft_ivf_tiered.cpp test_utils.cpp) + target_link_libraries(test_raftivf PUBLIC gtest_main VectorSimilarity PRIVATE raft::raft) +endif() + target_link_libraries(test_hnsw PUBLIC gtest_main VectorSimilarity) target_link_libraries(test_hnsw_parallel PUBLIC gtest_main VectorSimilarity) target_link_libraries(test_bruteforce PUBLIC gtest_main VectorSimilarity) @@ -50,3 +55,7 @@ gtest_discover_tests(test_bruteforce) gtest_discover_tests(test_allocator) gtest_discover_tests(test_spaces) gtest_discover_tests(test_common) + +if(USE_CUDA) + gtest_discover_tests(test_raftivf) +endif() diff --git a/tests/unit/test_raft_ivf_tiered.cpp b/tests/unit/test_raft_ivf_tiered.cpp new file mode 100644 index 000000000..e7a781251 --- /dev/null +++ b/tests/unit/test_raft_ivf_tiered.cpp @@ -0,0 +1,132 @@ +#include "gtest/gtest.h" +#include "VecSim/vec_sim.h" +#include "VecSim/vec_sim_common.h" +#include "VecSim/index_factories/tiered_factory.h" +#include "test_utils.h" +#include +#include +#include + +#include "mock_thread_pool.h" + + +template +class RaftIvfTieredTest : public ::testing::Test { +public: + using data_t = typename index_type_t::data_t; + using dist_t = typename index_type_t::dist_t; +}; + +VecSimParams createDefaultPQParams(size_t dim, size_t nLists = 3, size_t nProbes = 3) { + RaftIvfParams ivfparams = {.dim = dim, + .metric = VecSimMetric_L2, + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.5, + .nProbes = nProbes, + .usePQ = true, + .pqBits = 8, + .pqDim = 0, + .codebookKind = RaftIVFPQCodebookKind_PerSubspace, + .lutType = CUDAType_R_32F, + .internalDistanceType = CUDAType_R_32F, + .preferredShmemCarveout = 1.0}; + VecSimParams params{.algo = VecSimAlgo_RAFTIVF, .algoParams = {.raftIvfParams = ivfparams}}; + return params; +} + +VecSimParams createDefaultFlatParams(size_t dim, size_t nLists = 3, size_t nProbes = 3) { + RaftIvfParams ivfparams = {.dim = dim, + .metric = VecSimMetric_L2, + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.5, + .nProbes = nProbes, + .usePQ = false}; + VecSimParams params{.algo = VecSimAlgo_RAFTIVF, .algoParams = {.raftIvfParams = ivfparams}}; + return params; +} + +VecSimIndex* createTieredIndex(VecSimParams *params, + tieredIndexMock &mock_thread_pool, + size_t flat_buffer_limit = 0) { + TieredIndexParams params_tiered = { + .jobQueue = &mock_thread_pool.jobQ, + .jobQueueCtx = mock_thread_pool.ctx, + .submitCb = tieredIndexMock::submit_callback, + .flatBufferLimit = flat_buffer_limit, + .primaryIndexParams = params, + }; + auto *tiered_index = TieredFactory::NewIndex(¶ms_tiered); + // Set the created tiered index in the index external context (it will take ownership over + // the index, and we'll need to release the ctx at the end of the test. + mock_thread_pool.ctx->index_strong_ref.reset(tiered_index); + + return tiered_index; +} + +using DataTypeSetFloat = ::testing::Types>; + +TYPED_TEST_SUITE(RaftIvfTieredTest, DataTypeSetFloat); + +TYPED_TEST(RaftIvfTieredTest, RaftIVFTiered_PQ_add_sanity_test) { + size_t dim = 4; + size_t flat_buffer_limit = 3; + size_t nLists = 2; + + VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); + auto mock_thread_pool = tieredIndexMock(); + auto *index = createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + mock_thread_pool.init_threads(); + + VecSimQueryParams queryParams = {.batchSize = 1}; + + ASSERT_EQ(VecSimIndex_IndexSize(index), 0); + + TEST_DATA_T a[dim], b[dim], c[dim], d[dim], e[dim], zero[dim]; + std::vector a_vec(dim, (TEST_DATA_T)1); + std::vector b_vec(dim, (TEST_DATA_T)2); + std::vector c_vec(dim, (TEST_DATA_T)4); + std::vector d_vec(dim, (TEST_DATA_T)5); + std::vector zero_vec(dim, (TEST_DATA_T)0); + /*for (size_t i = 0; i < dim; i++) { + a[i] = (TEST_DATA_T)1; + b[i] = (TEST_DATA_T)2; + c[i] = (TEST_DATA_T)4; + d[i] = (TEST_DATA_T)5; + zero[i] = (TEST_DATA_T)0; + }*/ + auto inserted_vectors = std::vector>{a_vec, b_vec, c_vec, d_vec}; + + // Search for vectors when the index is empty. + runTopKSearchTest(index, a_vec.data(), 1, nullptr); + + // Add vectors. + VecSimIndex_AddVector(index, a_vec.data(), 0); + ASSERT_EQ(VecSimIndex_IndexSize(index), 1); + VecSimIndex_AddVector(index, b_vec.data(), 1); + VecSimIndex_AddVector(index, c_vec.data(), 2); + VecSimIndex_AddVector(index, d_vec.data(), 3); + ASSERT_EQ(VecSimIndex_IndexSize(index), 4); + + + mock_thread_pool.thread_pool_join(); + EXPECT_EQ(mock_thread_pool.jobQ.size(), 0); + // Callbacks for verifying results. + auto ver_res_0 = [&](size_t id, double score, size_t index) { + ASSERT_EQ(id, index); + ASSERT_DOUBLE_EQ(score, dim * inserted_vectors[id][0] * inserted_vectors[id][0]); + }; + size_t result_c[] = {2, 3, 1, 0}; // Order of results for query on c. + auto ver_res_c = [&](size_t id, double score, size_t index) { + ASSERT_EQ(id, result_c[index]); + double dist = inserted_vectors[id][0] - c_vec[0]; + ASSERT_DOUBLE_EQ(score, dim * dist * dist); + }; + + auto k = 4; + runTopKSearchTest(index, zero_vec.data(), k, ver_res_0); + runTopKSearchTest(index, c_vec.data(), k, ver_res_c); + VecSimIndex_Free(index); +} + From 54b895c68ffcd1cc525dc20a67b36c8cc4065e8c Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 8 Nov 2023 18:07:47 +0100 Subject: [PATCH 14/28] Separate cuda code and flat/pq --- CMakeLists.txt | 2 +- src/VecSim/CMakeLists.txt | 9 +- src/VecSim/algorithms/raft_ivf/ivf.cu | 103 +++++++++++++++ .../algorithms/raft_ivf/{ivf.cuh => ivf.h} | 118 +++--------------- .../raft_ivf/{ivf_tiered.cuh => ivf_tiered.h} | 68 ++++++---- src/VecSim/index_factories/index_factory.cpp | 9 +- ...ft_ivf_factory.cu => raft_ivf_factory.cpp} | 3 +- ...factory.cu => raft_ivf_tiered_factory.cpp} | 4 +- src/VecSim/index_factories/tiered_factory.cpp | 9 +- src/VecSim/utils/vec_utils.cpp | 9 +- src/VecSim/utils/vec_utils.h | 3 +- src/VecSim/vec_sim_common.h | 3 +- src/VecSim/vec_sim_tiered_index.h | 5 +- tests/unit/test_raft_ivf_tiered.cpp | 5 +- 14 files changed, 203 insertions(+), 147 deletions(-) create mode 100644 src/VecSim/algorithms/raft_ivf/ivf.cu rename src/VecSim/algorithms/raft_ivf/{ivf.cuh => ivf.h} (64%) rename src/VecSim/algorithms/raft_ivf/{ivf_tiered.cuh => ivf_tiered.h} (72%) rename src/VecSim/index_factories/{raft_ivf_factory.cu => raft_ivf_factory.cpp} (97%) rename src/VecSim/index_factories/{raft_ivf_tiered_factory.cu => raft_ivf_tiered_factory.cpp} (97%) diff --git a/CMakeLists.txt b/CMakeLists.txt index b04480539..8df0bbaa4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,7 +46,7 @@ endif() # Only do these if this is the main project, and not if it is included through add_subdirectory set_property(GLOBAL PROPERTY USE_FOLDERS ON) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions -fPIC -pthread ${CLANG_SAN_FLAGS} ${LLVM_CXX_FLAGS} ${COV_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions -fPIC -pthread ${CLANG_SAN_FLAGS} ${LLVM_CXX_FLAGS} ${COV_CXX_FLAGS} -lrt") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_LINKER_FLAGS} ${LLVM_LD_FLAGS}") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_LINKER_FLAGS} ${LLVM_LD_FLAGS}") diff --git a/src/VecSim/CMakeLists.txt b/src/VecSim/CMakeLists.txt index 323121d2d..35c655a1e 100644 --- a/src/VecSim/CMakeLists.txt +++ b/src/VecSim/CMakeLists.txt @@ -19,8 +19,9 @@ add_library(VectorSimilarity ${VECSIM_LIBTYPE} index_factories/hnsw_factory.cpp index_factories/tiered_factory.cpp index_factories/index_factory.cpp - $<$:index_factories/raft_ivf_factory.cu> - $<$:index_factories/raft_ivf_tiered_factory.cu> + $<$:index_factories/raft_ivf_factory.cpp> + $<$:index_factories/raft_ivf_tiered_factory.cpp> + $<$:algorithms/raft_ivf/ivf.cu> algorithms/hnsw/visited_nodes_handler.cpp vec_sim.cpp vec_sim_interface.cpp @@ -44,4 +45,8 @@ PUBLIC $<$:VectorSimilaritySerializer> PRIVATE $<$:raft::raft> + CUDA::cusolver + CUDA::cublas + CUDA::curand + CUDA::cusparse ) diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cu b/src/VecSim/algorithms/raft_ivf/ivf.cu new file mode 100644 index 000000000..a119d5f24 --- /dev/null +++ b/src/VecSim/algorithms/raft_ivf/ivf.cu @@ -0,0 +1,103 @@ +#include +#include +#include "ivf.h" + +template <> +int RaftIvfIndex::addVectorBatchAsync(const void *vector_data, labelType *label, + size_t batch_size, void *auxiliaryCtx) { + // Convert labels to internal data type + auto label_original = std::vector(label, label + batch_size); + auto label_converted = + std::vector(label_original.begin(), label_original.end()); + // Allocate memory on device to hold vectors to be added + auto vector_data_gpu = + raft::make_device_matrix(res_, batch_size, this->dim); + // Allocate memory on device to hold vector labels + auto label_gpu = raft::make_device_vector(res_, batch_size); + + // Copy vector data to previously allocated device buffer + raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), + this->dim * batch_size, res_.get_stream()); + // Copy label data to previously allocated device buffer + raft::copy(label_gpu.data_handle(), label_converted.data(), batch_size, res_.get_stream()); + + if (std::holds_alternative(build_params_)) { + if (!index_) { + index_ = raft::neighbors::ivf_flat::build( + res_, std::get(build_params_), + raft::make_const_mdspan(vector_data_gpu.view())); + } + raft::neighbors::ivf_flat::extend( + res_, raft::make_const_mdspan(vector_data_gpu.view()), + std::make_optional(raft::make_const_mdspan(label_gpu.view())), + &std::get(*index_)); + } else { + if (!index_) { + index_ = raft::neighbors::ivf_pq::build( + res_, std::get(build_params_), + raft::make_const_mdspan(vector_data_gpu.view())); + } + raft::neighbors::ivf_pq::extend( + res_, raft::make_const_mdspan(vector_data_gpu.view()), + std::make_optional(raft::make_const_mdspan(label_gpu.view())), + &std::get(*index_)); + } + + return batch_size; +} + +template <> +VecSimQueryReply * +RaftIvfIndex::topKQuery(const void *queryBlob, size_t k, + VecSimQueryParams *queryParams) const { + auto result_list = new VecSimQueryReply(this->allocator); + auto nVectors = this->indexSize(); + if (nVectors == 0 || k == 0 || !index_.has_value()) { + return result_list; + } + // Ensure we are not trying to retrieve more vectors than exist in the + // index + k = std::min(k, nVectors); + // Allocate memory on device for search vector + auto vector_data_gpu = raft::make_device_matrix(res_, 1, this->dim); + // Allocate memory on device for neighbor and distance results + auto neighbors_gpu = raft::make_device_matrix(res_, 1, k); + auto distances_gpu = raft::make_device_matrix(res_, 1, k); + // Copy query vector to device + raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), this->dim, + res_.get_stream()); + + // Perform correct search based on index type + if (std::holds_alternative(*index_)) { + raft::neighbors::ivf_flat::search( + res_, std::get(search_params_), + std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), + neighbors_gpu.view(), distances_gpu.view()); + // TODO ADD STREAM MANAGER + } else { + raft::neighbors::ivf_pq::search( + res_, std::get(search_params_), + std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), + neighbors_gpu.view(), distances_gpu.view()); + // TODO ADD STREAM MANAGER + } + + // Allocate host buffers to hold returned results + auto neighbors = vecsim_stl::vector(k, this->allocator); + auto distances = vecsim_stl::vector(k, this->allocator); + // Copy data back from device to host + raft::copy(neighbors.data(), neighbors_gpu.data_handle(), k, res_.get_stream()); + raft::copy(distances.data(), distances_gpu.data_handle(), k, res_.get_stream()); + + // Ensure search is complete and data have been copied back before + // building query result objects on host + res_.sync_stream(); + + result_list->results.resize(k); + for (auto i = 0; i < k; ++i) { + result_list->results[i].id = labelType{neighbors[i]}; + result_list->results[i].score = distances[i]; + } + + return result_list; +} diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.h similarity index 64% rename from src/VecSim/algorithms/raft_ivf/ivf.cuh rename to src/VecSim/algorithms/raft_ivf/ivf.h index a56138e2d..084f5657a 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.h @@ -20,9 +20,7 @@ #include #include #include -#include #include -#include #include inline auto constexpr GetRaftDistanceType(VecSimMetric vsm) { @@ -134,47 +132,7 @@ struct RaftIvfIndex : public VecSimIndexAbstract { return addVectorBatch(vector_data, &label, 1, auxiliaryCtx); } int addVectorBatchAsync(const void *vector_data, labelType *label, size_t batch_size, - void *auxiliaryCtx = nullptr) { - // Convert labels to internal data type - auto label_original = std::vector(label, label + batch_size); - auto label_converted = - std::vector(label_original.begin(), label_original.end()); - // Allocate memory on device to hold vectors to be added - auto vector_data_gpu = - raft::make_device_matrix(res_, batch_size, this->dim); - // Allocate memory on device to hold vector labels - auto label_gpu = raft::make_device_vector(res_, batch_size); - - // Copy vector data to previously allocated device buffer - raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), - this->dim * batch_size, res_.get_stream()); - // Copy label data to previously allocated device buffer - raft::copy(label_gpu.data_handle(), label_converted.data(), batch_size, res_.get_stream()); - - if (std::holds_alternative(build_params_)) { - if (!index_) { - index_ = raft::neighbors::ivf_flat::build( - res_, std::get(build_params_), - raft::make_const_mdspan(vector_data_gpu.view())); - } - raft::neighbors::ivf_flat::extend( - res_, raft::make_const_mdspan(vector_data_gpu.view()), - std::make_optional(raft::make_const_mdspan(label_gpu.view())), - &std::get(*index_)); - } else { - if (!index_) { - index_ = raft::neighbors::ivf_pq::build( - res_, std::get(build_params_), - raft::make_const_mdspan(vector_data_gpu.view())); - } - raft::neighbors::ivf_pq::extend( - res_, raft::make_const_mdspan(vector_data_gpu.view()), - std::make_optional(raft::make_const_mdspan(label_gpu.view())), - &std::get(*index_)); - } - - return batch_size; - } + void *auxiliaryCtx = nullptr); int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, void *auxiliaryCtx = nullptr) { auto result = addVectorBatchAsync(vector_data, label, batch_size, auxiliaryCtx); @@ -200,71 +158,25 @@ struct RaftIvfIndex : public VecSimIndexAbstract { return this->indexSize(); // TODO: Return unique counts } VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) const override { - auto result_list = new VecSimQueryReply(this->allocator); - auto nVectors = this->indexSize(); - if (nVectors == 0 || k == 0 || !index_.has_value()) { - return result_list; - } - // Ensure we are not trying to retrieve more vectors than exist in the - // index - k = std::min(k, nVectors); - // Allocate memory on device for search vector - auto vector_data_gpu = - raft::make_device_matrix(res_, 1, this->dim); - // Allocate memory on device for neighbor and distance results - auto neighbors_gpu = raft::make_device_matrix(res_, 1, k); - auto distances_gpu = raft::make_device_matrix(res_, 1, k); - // Copy query vector to device - raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), this->dim, - res_.get_stream()); - - // Perform correct search based on index type - if (std::holds_alternative(*index_)) { - raft::neighbors::ivf_flat::search( - res_, std::get(search_params_), - std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), - neighbors_gpu.view(), distances_gpu.view()); - // TODO ADD STREAM MANAGER - } else { - raft::neighbors::ivf_pq::search( - res_, std::get(search_params_), - std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), - neighbors_gpu.view(), distances_gpu.view()); - // TODO ADD STREAM MANAGER - } - - // Allocate host buffers to hold returned results - auto neighbors = vecsim_stl::vector(k, this->allocator); - auto distances = vecsim_stl::vector(k, this->allocator); - // Copy data back from device to host - raft::copy(neighbors.data(), neighbors_gpu.data_handle(), k, res_.get_stream()); - raft::copy(distances.data(), distances_gpu.data_handle(), k, res_.get_stream()); - - // Ensure search is complete and data have been copied back before - // building query result objects on host - res_.sync_stream(); - - result_list->results.resize(k); - for (auto i = 0; i < k; ++i) { - result_list->results[i].id = labelType{neighbors[i]}; - result_list->results[i].score = distances[i]; - } - - return result_list; - } + VecSimQueryParams *queryParams) const override; virtual VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, - VecSimQueryParams *queryParams) const override { + VecSimQueryParams *queryParams) const override { assert(!"RangeQuery not implemented"); + return nullptr; + } + VecSimInfoIterator *infoIterator() const override { + assert(!"infoIterator not implemented"); + return nullptr; } - VecSimInfoIterator *infoIterator() const override { assert(!"infoIterator not implemented"); } virtual VecSimBatchIterator *newBatchIterator(const void *queryBlob, VecSimQueryParams *queryParams) const override { assert(!"newBatchIterator not implemented"); + return nullptr; } bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const override { assert(!"preferAdHocSearch not implemented"); + return false; } auto &get_resources() const { return res_; } @@ -282,7 +194,11 @@ struct RaftIvfIndex : public VecSimIndexAbstract { } VecSimIndexBasicInfo basicInfo() const override { VecSimIndexBasicInfo info = this->getBasicInfo(); - info.algo = VecSimAlgo_RAFTIVF; + if (std::holds_alternative(build_params_)) { + info.algo = VecSimAlgo_RAFT_IVFFLAT; + } else { + info.algo = VecSimAlgo_RAFT_IVFPQ; + } info.isTiered = false; return info; } @@ -299,6 +215,10 @@ struct RaftIvfIndex : public VecSimIndexAbstract { return info; } + inline void setNProbes(uint32_t n_probes) { + std::visit([n_probes](auto &¶ms) { params.n_probes = n_probes; }, search_params_); + } + private: // An object used to manage common device resources that may be // expensive to build but frequently accessed diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h similarity index 72% rename from src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh rename to src/VecSim/algorithms/raft_ivf/ivf_tiered.h index 0df6e1cb9..67d7359c1 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h @@ -1,27 +1,25 @@ #pragma once #include -#include "VecSim/algorithms/raft_ivf/ivf.cuh" +#include "VecSim/algorithms/raft_ivf/ivf.h" #include "VecSim/vec_sim_tiered_index.h" struct RAFTTransferJob : public AsyncJob { - RAFTTransferJob(std::shared_ptr allocator, - JobCallback insertCb, VecSimIndex *index_) - : AsyncJob{allocator, RAFT_TRANSFER_JOB, insertCb, index_} - { - } + RAFTTransferJob(std::shared_ptr allocator, JobCallback insertCb, + VecSimIndex *index_) + : AsyncJob{allocator, RAFT_TRANSFER_JOB, insertCb, index_} {} }; template struct TieredRaftIvfIndex : public VecSimTieredIndex { - TieredRaftIvfIndex(RaftIvfIndex* raftIvfIndex, + TieredRaftIvfIndex(RaftIvfIndex *raftIvfIndex, BruteForceIndex *bf_index, const TieredIndexParams &tieredParams, std::shared_ptr allocator) - : VecSimTieredIndex(raftIvfIndex, bf_index, tieredParams, allocator) - { - assert(raftIvfIndex->nLists() < this->flatBufferLimit && - "The flat buffer limit must be greater than the number of lists in the backend index"); + : VecSimTieredIndex(raftIvfIndex, bf_index, tieredParams, allocator) { + assert( + raftIvfIndex->nLists() < this->flatBufferLimit && + "The flat buffer limit must be greater than the number of lists in the backend index"); } ~TieredRaftIvfIndex() { // Delete all the pending jobs @@ -49,8 +47,8 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { this->flatIndexGuard.unlock(); // Submit a transfer job - AsyncJob *new_insert_job = new (this->allocator) - RAFTTransferJob(this->allocator, executeTransferJobWrapper, this); + AsyncJob *new_insert_job = + new (this->allocator) RAFTTransferJob(this->allocator, executeTransferJobWrapper, this); this->submitSingleJob(new_insert_job); return ret; } @@ -76,6 +74,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { size_t indexLabelCount() const override { // TODO(wphicks) Count unique labels between both indexes + return 0; } size_t indexCapacity() const override { @@ -91,16 +90,24 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { static void executeTransferJobWrapper(AsyncJob *job) { if (job->isValid) { auto *transfer_job = reinterpret_cast(job); - auto *job_index = reinterpret_cast *>(transfer_job->index); + auto *job_index = + reinterpret_cast *>(transfer_job->index); job_index->executeTransferJob(); } delete job; } - VecSimIndexBasicInfo basicInfo() const override{} + VecSimIndexBasicInfo basicInfo() const override { + VecSimIndexBasicInfo info = this->backendIndex->getBasicInfo(); + info.isTiered = true; + return info; + } VecSimBatchIterator *newBatchIterator(const void *queryBlob, - VecSimQueryParams *queryParams) const override {} + VecSimQueryParams *queryParams) const override { + assert(!"newBatchIterator not implemented"); + return nullptr; + } inline void setLastSearchMode(VecSearchMode mode) override {} @@ -116,8 +123,13 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { this->mainIndexGuard.unlock_shared(); } -private: + inline void setNProbes(uint32_t n_probes) { + this->mainIndexGuard.lock(); + this->getBackendIndex().setNProbes(n_probes); + this->mainIndexGuard.unlock(); + } +private: inline auto &getBackendIndex() const { return *dynamic_cast *>(this->backendIndex); } @@ -135,20 +147,19 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { this->mainIndexGuard.lock_shared(); auto main_nVectors = this->backendIndex->indexSize(); this->mainIndexGuard.unlock_shared(); - if (main_nVectors == 0) { - if (nVectors < getBackendIndex().nLists()) { - frontend_lock.unlock(); - return; - } + if (main_nVectors == 0 && nVectors < getBackendIndex().nLists()) { + frontend_lock.unlock(); + return; } auto dim = this->backendIndex->getDim(); const auto &vectorBlocks = this->frontendIndex->getVectorBlocks(); - auto* vectorData = (DataType *)this->allocator->allocate(nVectors * dim * sizeof (DataType)); + auto *vectorData = (DataType *)this->allocator->allocate(nVectors * dim * sizeof(DataType)); - // Transfer vectors to a contiguous buffer + // Transfer vectors to a contiguous host buffer auto *curr_ptr = vectorData; - for (auto block_id = 0; block_id < vectorBlocks.size(); ++block_id) { - const auto *in_begin = reinterpret_cast(vectorBlocks[block_id].getElement(0)); + for (std::uint32_t block_id = 0; block_id < vectorBlocks.size(); ++block_id) { + const auto *in_begin = + reinterpret_cast(vectorBlocks[block_id].getElement(0)); auto length = vectorBlocks[block_id].getLength(); std::copy(in_begin, in_begin + (length * dim), curr_ptr); curr_ptr += length * dim; @@ -162,4 +173,9 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { frontend_lock.unlock(); this->allocator->free_allocation(vectorData); } + +#ifdef BUILD_TESTS + INDEX_TEST_FRIEND_CLASS(BM_VecSimBasics) + INDEX_TEST_FRIEND_CLASS(BM_VecSimCommon) +#endif }; diff --git a/src/VecSim/index_factories/index_factory.cpp b/src/VecSim/index_factories/index_factory.cpp index 4f34e2eaf..bcda2c7e7 100644 --- a/src/VecSim/index_factories/index_factory.cpp +++ b/src/VecSim/index_factories/index_factory.cpp @@ -26,7 +26,8 @@ VecSimIndex *NewIndex(const VecSimParams *params) { index = BruteForceFactory::NewIndex(params); break; } - case VecSimAlgo_RAFTIVF: { + case VecSimAlgo_RAFT_IVFFLAT: + case VecSimAlgo_RAFT_IVFPQ: { index = RaftIvfFactory::NewIndex(¶ms->algoParams.raftIvfParams); break; } @@ -47,7 +48,8 @@ size_t EstimateInitialSize(const VecSimParams *params) { return HNSWFactory::EstimateInitialSize(¶ms->algoParams.hnswParams); case VecSimAlgo_BF: return BruteForceFactory::EstimateInitialSize(¶ms->algoParams.bfParams); - case VecSimAlgo_RAFTIVF: + case VecSimAlgo_RAFT_IVFFLAT: + case VecSimAlgo_RAFT_IVFPQ: return RaftIvfFactory::EstimateInitialSize(¶ms->algoParams.raftIvfParams); case VecSimAlgo_TIERED: return TieredFactory::EstimateInitialSize(¶ms->algoParams.tieredParams); @@ -61,7 +63,8 @@ size_t EstimateElementSize(const VecSimParams *params) { return HNSWFactory::EstimateElementSize(¶ms->algoParams.hnswParams); case VecSimAlgo_BF: return BruteForceFactory::EstimateElementSize(¶ms->algoParams.bfParams); - case VecSimAlgo_RAFTIVF: + case VecSimAlgo_RAFT_IVFFLAT: + case VecSimAlgo_RAFT_IVFPQ: return RaftIvfFactory::EstimateElementSize(¶ms->algoParams.raftIvfParams); case VecSimAlgo_TIERED: return TieredFactory::EstimateElementSize(¶ms->algoParams.tieredParams); diff --git a/src/VecSim/index_factories/raft_ivf_factory.cu b/src/VecSim/index_factories/raft_ivf_factory.cpp similarity index 97% rename from src/VecSim/index_factories/raft_ivf_factory.cu rename to src/VecSim/index_factories/raft_ivf_factory.cpp index d1aa7705a..6c16324a3 100644 --- a/src/VecSim/index_factories/raft_ivf_factory.cu +++ b/src/VecSim/index_factories/raft_ivf_factory.cpp @@ -1,6 +1,5 @@ #include "VecSim/index_factories/brute_force_factory.h" -#include "VecSim/algorithms/raft_ivf/ivf.cuh" -#include +#include "VecSim/algorithms/raft_ivf/ivf.h" namespace RaftIvfFactory { diff --git a/src/VecSim/index_factories/raft_ivf_tiered_factory.cu b/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp similarity index 97% rename from src/VecSim/index_factories/raft_ivf_tiered_factory.cu rename to src/VecSim/index_factories/raft_ivf_tiered_factory.cpp index dc75d7d3e..37b797fd9 100644 --- a/src/VecSim/index_factories/raft_ivf_tiered_factory.cu +++ b/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp @@ -1,6 +1,6 @@ #include "VecSim/index_factories/brute_force_factory.h" -#include "VecSim/algorithms/raft_ivf/ivf_tiered.cuh" -#include "VecSim/algorithms/raft_ivf/ivf.cuh" +#include "VecSim/algorithms/raft_ivf/ivf_tiered.h" +#include "VecSim/algorithms/raft_ivf/ivf.h" #include "VecSim/index_factories/tiered_factory.h" #include "VecSim/index_factories/raft_ivf_factory.h" diff --git a/src/VecSim/index_factories/tiered_factory.cpp b/src/VecSim/index_factories/tiered_factory.cpp index f83b78103..03bd48ed0 100644 --- a/src/VecSim/index_factories/tiered_factory.cpp +++ b/src/VecSim/index_factories/tiered_factory.cpp @@ -94,7 +94,8 @@ VecSimIndex *NewIndex(const TieredIndexParams *params) { } else if (type == VecSimType_FLOAT64) { return TieredHNSWFactory::NewIndex(params); } - } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || + params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { VecSimType type = params->primaryIndexParams->algoParams.raftIvfParams.type; if (type == VecSimType_FLOAT32) { return TieredRaftIvfFactory::NewIndex(params); @@ -111,7 +112,8 @@ size_t EstimateInitialSize(const TieredIndexParams *params) { BFParams bf_params{}; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est += TieredHNSWFactory::EstimateInitialSize(params, bf_params); - } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || + params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { est += TieredRaftIvfFactory::EstimateInitialSize(params); } @@ -123,7 +125,8 @@ size_t EstimateElementSize(const TieredIndexParams *params) { size_t est = 0; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est = HNSWFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.hnswParams); - } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFTIVF) { + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || + params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { est = RaftIvfFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.raftIvfParams); } return est; diff --git a/src/VecSim/utils/vec_utils.cpp b/src/VecSim/utils/vec_utils.cpp index 4f5764724..d782210a3 100644 --- a/src/VecSim/utils/vec_utils.cpp +++ b/src/VecSim/utils/vec_utils.cpp @@ -15,7 +15,8 @@ const char *VecSimCommonStrings::ALGORITHM_STRING = "ALGORITHM"; const char *VecSimCommonStrings::FLAT_STRING = "FLAT"; const char *VecSimCommonStrings::HNSW_STRING = "HNSW"; -const char *VecSimCommonStrings::RAFTIVF_STRING = "RAFT_IVF"; +const char *VecSimCommonStrings::RAFTIVFFLAT_STRING = "RAFT_IVF_FLAT"; +const char *VecSimCommonStrings::RAFTIVFPQ_STRING = "RAFT_IVF_PQ"; const char *VecSimCommonStrings::TIERED_STRING = "TIERED"; const char *VecSimCommonStrings::TYPE_STRING = "TYPE"; @@ -126,8 +127,10 @@ const char *VecSimAlgo_ToString(VecSimAlgo vecsimAlgo) { return VecSimCommonStrings::FLAT_STRING; case VecSimAlgo_HNSWLIB: return VecSimCommonStrings::HNSW_STRING; - case VecSimAlgo_RAFTIVF: - return VecSimCommonStrings::RAFTIVF_STRING; + case VecSimAlgo_RAFT_IVFFLAT: + return VecSimCommonStrings::RAFTIVFFLAT_STRING; + case VecSimAlgo_RAFT_IVFPQ: + return VecSimCommonStrings::RAFTIVFPQ_STRING; case VecSimAlgo_TIERED: return VecSimCommonStrings::TIERED_STRING; } diff --git a/src/VecSim/utils/vec_utils.h b/src/VecSim/utils/vec_utils.h index 1916444b2..723a1c2b9 100644 --- a/src/VecSim/utils/vec_utils.h +++ b/src/VecSim/utils/vec_utils.h @@ -18,7 +18,8 @@ struct VecSimCommonStrings { static const char *ALGORITHM_STRING; static const char *FLAT_STRING; static const char *HNSW_STRING; - static const char *RAFTIVF_STRING; + static const char *RAFTIVFFLAT_STRING; + static const char *RAFTIVFPQ_STRING; static const char *TIERED_STRING; static const char *TYPE_STRING; diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index a4eafb9ce..0a0c49c94 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -38,7 +38,7 @@ typedef enum { } VecSimType; // Algorithm type/library. -typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_RAFTIVF, VecSimAlgo_TIERED } VecSimAlgo; +typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_RAFT_IVFFLAT, VecSimAlgo_RAFT_IVFPQ, VecSimAlgo_TIERED } VecSimAlgo; // Distance metric typedef enum { VecSimMetric_L2, VecSimMetric_IP, VecSimMetric_Cosine } VecSimMetric; @@ -286,6 +286,7 @@ typedef struct { // Since we cannot recursively have a struct that contains itself, we need this workaround. union { hnswInfoStruct hnswInfo; + raftIvfInfoStruct raftIvfInfo; } backendInfo; // The backend index info. union { HnswTieredInfo hnswTieredInfo; diff --git a/src/VecSim/vec_sim_tiered_index.h b/src/VecSim/vec_sim_tiered_index.h index 6318bf8db..775521f4d 100644 --- a/src/VecSim/vec_sim_tiered_index.h +++ b/src/VecSim/vec_sim_tiered_index.h @@ -284,7 +284,10 @@ VecSimIndexInfo VecSimTieredIndex::info() const { case VecSimAlgo_HNSWLIB: info.tieredInfo.backendInfo.hnswInfo = backendInfo.hnswInfo; break; - case VecSimAlgo_RAFTIVF: // TODO Add RaftIVF info + case VecSimAlgo_RAFT_IVFFLAT: + case VecSimAlgo_RAFT_IVFPQ: + info.tieredInfo.backendInfo.raftIvfInfo = backendInfo.raftIvfInfo; + break; case VecSimAlgo_BF: case VecSimAlgo_TIERED: assert(false && "Invalid backend algorithm"); diff --git a/tests/unit/test_raft_ivf_tiered.cpp b/tests/unit/test_raft_ivf_tiered.cpp index e7a781251..7d1164b01 100644 --- a/tests/unit/test_raft_ivf_tiered.cpp +++ b/tests/unit/test_raft_ivf_tiered.cpp @@ -31,7 +31,7 @@ VecSimParams createDefaultPQParams(size_t dim, size_t nLists = 3, size_t nProbes .lutType = CUDAType_R_32F, .internalDistanceType = CUDAType_R_32F, .preferredShmemCarveout = 1.0}; - VecSimParams params{.algo = VecSimAlgo_RAFTIVF, .algoParams = {.raftIvfParams = ivfparams}}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFPQ, .algoParams = {.raftIvfParams = ivfparams}}; return params; } @@ -43,7 +43,7 @@ VecSimParams createDefaultFlatParams(size_t dim, size_t nLists = 3, size_t nProb .kmeans_trainsetFraction = 0.5, .nProbes = nProbes, .usePQ = false}; - VecSimParams params{.algo = VecSimAlgo_RAFTIVF, .algoParams = {.raftIvfParams = ivfparams}}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFFLAT, .algoParams = {.raftIvfParams = ivfparams}}; return params; } @@ -127,6 +127,5 @@ TYPED_TEST(RaftIvfTieredTest, RaftIVFTiered_PQ_add_sanity_test) { auto k = 4; runTopKSearchTest(index, zero_vec.data(), k, ver_res_0); runTopKSearchTest(index, c_vec.data(), k, ver_res_c); - VecSimIndex_Free(index); } From 1f99e555bbb7bc67e2b15a20f9e8d12168957cfc Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 14 Nov 2023 14:54:12 +0100 Subject: [PATCH 15/28] Rework IVF, add stream manager, interface, benchmark --- src/VecSim/CMakeLists.txt | 3 +- src/VecSim/algorithms/raft_ivf/ivf.cu | 103 -------------- .../algorithms/raft_ivf/{ivf.h => ivf.cuh} | 131 +++++++++++++++--- .../algorithms/raft_ivf/ivf_interface.h | 20 +++ src/VecSim/algorithms/raft_ivf/ivf_tiered.h | 64 +++++---- ...ft_ivf_factory.cpp => raft_ivf_factory.cu} | 2 +- .../raft_ivf_tiered_factory.cpp | 4 +- src/VecSim/vec_sim_common.h | 2 +- tests/benchmark/bm_common.h | 75 ++++++++++ .../bm_basics_initialize_fp32.h | 5 + tests/benchmark/bm_vecsim_general.h | 31 +++++ tests/benchmark/bm_vecsim_index.h | 19 +++ tests/unit/test_raft_ivf_tiered.cpp | 53 ++++--- 13 files changed, 328 insertions(+), 184 deletions(-) delete mode 100644 src/VecSim/algorithms/raft_ivf/ivf.cu rename src/VecSim/algorithms/raft_ivf/{ivf.h => ivf.cuh} (59%) create mode 100644 src/VecSim/algorithms/raft_ivf/ivf_interface.h rename src/VecSim/index_factories/{raft_ivf_factory.cpp => raft_ivf_factory.cu} (98%) diff --git a/src/VecSim/CMakeLists.txt b/src/VecSim/CMakeLists.txt index 35c655a1e..279fc5322 100644 --- a/src/VecSim/CMakeLists.txt +++ b/src/VecSim/CMakeLists.txt @@ -19,9 +19,8 @@ add_library(VectorSimilarity ${VECSIM_LIBTYPE} index_factories/hnsw_factory.cpp index_factories/tiered_factory.cpp index_factories/index_factory.cpp - $<$:index_factories/raft_ivf_factory.cpp> + $<$:index_factories/raft_ivf_factory.cu> $<$:index_factories/raft_ivf_tiered_factory.cpp> - $<$:algorithms/raft_ivf/ivf.cu> algorithms/hnsw/visited_nodes_handler.cpp vec_sim.cpp vec_sim_interface.cpp diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cu b/src/VecSim/algorithms/raft_ivf/ivf.cu deleted file mode 100644 index a119d5f24..000000000 --- a/src/VecSim/algorithms/raft_ivf/ivf.cu +++ /dev/null @@ -1,103 +0,0 @@ -#include -#include -#include "ivf.h" - -template <> -int RaftIvfIndex::addVectorBatchAsync(const void *vector_data, labelType *label, - size_t batch_size, void *auxiliaryCtx) { - // Convert labels to internal data type - auto label_original = std::vector(label, label + batch_size); - auto label_converted = - std::vector(label_original.begin(), label_original.end()); - // Allocate memory on device to hold vectors to be added - auto vector_data_gpu = - raft::make_device_matrix(res_, batch_size, this->dim); - // Allocate memory on device to hold vector labels - auto label_gpu = raft::make_device_vector(res_, batch_size); - - // Copy vector data to previously allocated device buffer - raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), - this->dim * batch_size, res_.get_stream()); - // Copy label data to previously allocated device buffer - raft::copy(label_gpu.data_handle(), label_converted.data(), batch_size, res_.get_stream()); - - if (std::holds_alternative(build_params_)) { - if (!index_) { - index_ = raft::neighbors::ivf_flat::build( - res_, std::get(build_params_), - raft::make_const_mdspan(vector_data_gpu.view())); - } - raft::neighbors::ivf_flat::extend( - res_, raft::make_const_mdspan(vector_data_gpu.view()), - std::make_optional(raft::make_const_mdspan(label_gpu.view())), - &std::get(*index_)); - } else { - if (!index_) { - index_ = raft::neighbors::ivf_pq::build( - res_, std::get(build_params_), - raft::make_const_mdspan(vector_data_gpu.view())); - } - raft::neighbors::ivf_pq::extend( - res_, raft::make_const_mdspan(vector_data_gpu.view()), - std::make_optional(raft::make_const_mdspan(label_gpu.view())), - &std::get(*index_)); - } - - return batch_size; -} - -template <> -VecSimQueryReply * -RaftIvfIndex::topKQuery(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) const { - auto result_list = new VecSimQueryReply(this->allocator); - auto nVectors = this->indexSize(); - if (nVectors == 0 || k == 0 || !index_.has_value()) { - return result_list; - } - // Ensure we are not trying to retrieve more vectors than exist in the - // index - k = std::min(k, nVectors); - // Allocate memory on device for search vector - auto vector_data_gpu = raft::make_device_matrix(res_, 1, this->dim); - // Allocate memory on device for neighbor and distance results - auto neighbors_gpu = raft::make_device_matrix(res_, 1, k); - auto distances_gpu = raft::make_device_matrix(res_, 1, k); - // Copy query vector to device - raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), this->dim, - res_.get_stream()); - - // Perform correct search based on index type - if (std::holds_alternative(*index_)) { - raft::neighbors::ivf_flat::search( - res_, std::get(search_params_), - std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), - neighbors_gpu.view(), distances_gpu.view()); - // TODO ADD STREAM MANAGER - } else { - raft::neighbors::ivf_pq::search( - res_, std::get(search_params_), - std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), - neighbors_gpu.view(), distances_gpu.view()); - // TODO ADD STREAM MANAGER - } - - // Allocate host buffers to hold returned results - auto neighbors = vecsim_stl::vector(k, this->allocator); - auto distances = vecsim_stl::vector(k, this->allocator); - // Copy data back from device to host - raft::copy(neighbors.data(), neighbors_gpu.data_handle(), k, res_.get_stream()); - raft::copy(distances.data(), distances_gpu.data_handle(), k, res_.get_stream()); - - // Ensure search is complete and data have been copied back before - // building query result objects on host - res_.sync_stream(); - - result_list->results.resize(k); - for (auto i = 0; i < k; ++i) { - result_list->results[i].id = labelType{neighbors[i]}; - result_list->results[i].score = distances[i]; - } - - return result_list; -} diff --git a/src/VecSim/algorithms/raft_ivf/ivf.h b/src/VecSim/algorithms/raft_ivf/ivf.cuh similarity index 59% rename from src/VecSim/algorithms/raft_ivf/ivf.h rename to src/VecSim/algorithms/raft_ivf/ivf.cuh index 084f5657a..46e81f80e 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.h +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -14,6 +14,7 @@ // For VecSimIndexAbstract #include "VecSim/vec_sim_index.h" #include "VecSim/query_result_definitions.h" // VecSimQueryResult VecSimQueryReply +#include "VecSim/algorithms/raft_ivf/ivf_interface.h" // RaftIvfInterface #include "VecSim/memory/vecsim_malloc.h" #include @@ -72,7 +73,7 @@ inline auto constexpr GetCudaType(CudaType vss_type) { } template -struct RaftIvfIndex : public VecSimIndexAbstract { +struct RaftIvfIndex : public RaftIvfInterface { using data_type = DataType; using dist_type = DistType; @@ -82,15 +83,14 @@ struct RaftIvfIndex : public VecSimIndexAbstract { raft::neighbors::ivf_pq::index_params>; using search_params_t = std::variant; - using internal_idx_t = std::int64_t; - using index_flat_t = raft::neighbors::ivf_flat::index; - using index_pq_t = raft::neighbors::ivf_pq::index; + //using internal_idx_t = std::int64_t; + using index_flat_t = raft::neighbors::ivf_flat::index; + using index_pq_t = raft::neighbors::ivf_pq::index; using ann_index_t = std::variant; public: RaftIvfIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &commonParams) - : VecSimIndexAbstract{commonParams}, - res_{raft::device_resources_manager::get_device_resources()}, + : RaftIvfInterface{commonParams}, build_params_{raftIvfParams->usePQ ? build_params_t{std::in_place_index<1>} : build_params_t{std::in_place_index<0>}}, search_params_{raftIvfParams->usePQ ? search_params_t{std::in_place_index<1>} @@ -127,19 +127,61 @@ struct RaftIvfIndex : public VecSimIndexAbstract { } }, search_params_); + + raft::device_resources_manager::set_streams_per_device(16); // TODO: use env variable + raft::device_resources_manager::set_stream_pools_per_device(16); + // Create a 5 GB memory pool. Passing std::nullopt will allow + // the pool to grow to the available memory of the device. + raft::device_resources_manager::set_mem_pool(size_t{5000} << 20, std::nullopt); } int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override { return addVectorBatch(vector_data, &label, 1, auxiliaryCtx); } - int addVectorBatchAsync(const void *vector_data, labelType *label, size_t batch_size, - void *auxiliaryCtx = nullptr); - int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, - void *auxiliaryCtx = nullptr) { - auto result = addVectorBatchAsync(vector_data, label, batch_size, auxiliaryCtx); + virtual int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, + void *auxiliaryCtx = nullptr) override { + auto& res = raft::device_resources_manager::get_device_resources(); + // Convert labels to internal data type + /*auto label_original = std::vector(label, label + batch_size); + auto label_converted = + std::vector(label_original.begin(), label_original.end());*/ + // Allocate memory on device to hold vectors to be added + auto vector_data_gpu = + raft::make_device_matrix(res, batch_size, this->dim); + // Allocate memory on device to hold vector labels + auto label_gpu = raft::make_device_vector(res, batch_size); + + // Copy vector data to previously allocated device buffer + raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), + this->dim * batch_size, res.get_stream()); + // Copy label data to previously allocated device buffer + raft::copy(label_gpu.data_handle(), label, batch_size, res.get_stream()); + + if (std::holds_alternative(build_params_)) { + if (!index_) { + index_ = raft::neighbors::ivf_flat::build( + res, std::get(build_params_), + raft::make_const_mdspan(vector_data_gpu.view())); + } + raft::neighbors::ivf_flat::extend( + res, raft::make_const_mdspan(vector_data_gpu.view()), + std::make_optional(raft::make_const_mdspan(label_gpu.view())), + &std::get(*index_)); + } else { + if (!index_) { + index_ = raft::neighbors::ivf_pq::build( + res, std::get(build_params_), + raft::make_const_mdspan(vector_data_gpu.view())); + } + raft::neighbors::ivf_pq::extend( + res, raft::make_const_mdspan(vector_data_gpu.view()), + std::make_optional(raft::make_const_mdspan(label_gpu.view())), + &std::get(*index_)); + } + // Ensure that above operation has executed on device before // returning from this function on host - res_.sync_stream(); - return result; + res.sync_stream(); + return batch_size; } int deleteVector(labelType label) override { assert(!"deleteVector not implemented"); @@ -158,7 +200,57 @@ struct RaftIvfIndex : public VecSimIndexAbstract { return this->indexSize(); // TODO: Return unique counts } VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) const override; + VecSimQueryParams *queryParams) const override { + auto& res = raft::device_resources_manager::get_device_resources(); + auto result_list = new VecSimQueryReply(this->allocator); + auto nVectors = this->indexSize(); + if (nVectors == 0 || k == 0 || !index_.has_value()) { + return result_list; + } + // Ensure we are not trying to retrieve more vectors than exist in the + // index + k = std::min(k, nVectors); + // Allocate memory on device for search vector + auto vector_data_gpu = raft::make_device_matrix(res, 1, this->dim); + // Allocate memory on device for neighbor and distance results + auto neighbors_gpu = raft::make_device_matrix(res, 1, k); + auto distances_gpu = raft::make_device_matrix(res, 1, k); + // Copy query vector to device + raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), this->dim, + res.get_stream()); + + // Perform correct search based on index type + if (std::holds_alternative(*index_)) { + raft::neighbors::ivf_flat::search( + res, std::get(search_params_), + std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), + neighbors_gpu.view(), distances_gpu.view()); + } else { + raft::neighbors::ivf_pq::search( + res, std::get(search_params_), + std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), + neighbors_gpu.view(), distances_gpu.view()); + } + + // Allocate host buffers to hold returned results + auto neighbors = vecsim_stl::vector(k, this->allocator); + auto distances = vecsim_stl::vector(k, this->allocator); + // Copy data back from device to host + raft::copy(neighbors.data(), neighbors_gpu.data_handle(), k, res.get_stream()); + raft::copy(distances.data(), distances_gpu.data_handle(), k, res.get_stream()); + + // Ensure search is complete and data have been copied back before + // building query result objects on host + res.sync_stream(); + + result_list->results.resize(k); + for (auto i = 0; i < k; ++i) { + result_list->results[i].id = labelType{neighbors[i]}; + result_list->results[i].score = distances[i]; + } + + return result_list; + } virtual VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, VecSimQueryParams *queryParams) const override { @@ -179,9 +271,7 @@ struct RaftIvfIndex : public VecSimIndexAbstract { return false; } - auto &get_resources() const { return res_; } - - auto nLists() const { + virtual uint32_t nLists() const override { return std::visit([](auto &¶ms) { return params.n_lists; }, build_params_); } @@ -215,14 +305,11 @@ struct RaftIvfIndex : public VecSimIndexAbstract { return info; } - inline void setNProbes(uint32_t n_probes) { + virtual inline void setNProbes(uint32_t n_probes) override { std::visit([n_probes](auto &¶ms) { params.n_probes = n_probes; }, search_params_); } private: - // An object used to manage common device resources that may be - // expensive to build but frequently accessed - raft::device_resources res_; // Store build params to allow for index build on first batch // insertion build_params_t build_params_; @@ -232,4 +319,6 @@ struct RaftIvfIndex : public VecSimIndexAbstract { // Use a std::optional to allow building of the index on first batch // insertion std::optional index_; + // Bitset used for deleteVectors and search filtering. + //raft::core::bitset deleted_indices_; }; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_interface.h b/src/VecSim/algorithms/raft_ivf/ivf_interface.h new file mode 100644 index 000000000..165c5455e --- /dev/null +++ b/src/VecSim/algorithms/raft_ivf/ivf_interface.h @@ -0,0 +1,20 @@ +#pragma once + +#include "VecSim/vec_sim.h" +// For VecSimIndexAbstract +#include "VecSim/vec_sim_index.h" +// For labelType +#include "VecSim/vec_sim_common.h" + +// Non-CUDA Interface of the RaftIVF index to avoid importing CUDA code +// in the tiered index. +template +struct RaftIvfInterface : public VecSimIndexAbstract +{ + RaftIvfInterface(const AbstractIndexInitParams ¶ms) : VecSimIndexAbstract(params) {} + virtual uint32_t nLists() const = 0; + virtual inline void setNProbes(uint32_t n_probes) = 0; + + virtual int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, + void *auxiliaryCtx = nullptr) = 0; +}; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h index 67d7359c1..4a66be88b 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h @@ -1,18 +1,19 @@ #pragma once #include -#include "VecSim/algorithms/raft_ivf/ivf.h" +#include "VecSim/algorithms/raft_ivf/ivf_interface.h" #include "VecSim/vec_sim_tiered_index.h" struct RAFTTransferJob : public AsyncJob { + bool force_ = false; RAFTTransferJob(std::shared_ptr allocator, JobCallback insertCb, - VecSimIndex *index_) - : AsyncJob{allocator, RAFT_TRANSFER_JOB, insertCb, index_} {} + VecSimIndex *index_, bool force = false) + : AsyncJob{allocator, RAFT_TRANSFER_JOB, insertCb, index_}, force_{force} {} }; template struct TieredRaftIvfIndex : public VecSimTieredIndex { - TieredRaftIvfIndex(RaftIvfIndex *raftIvfIndex, + TieredRaftIvfIndex(RaftIvfInterface *raftIvfIndex, BruteForceIndex *bf_index, const TieredIndexParams &tieredParams, std::shared_ptr allocator) @@ -32,7 +33,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { // If the backend index is empty, build it with all the vectors // Otherwise, just add the vector to the backend index if (this->backendIndex->indexSize() == 0) { - executeTransferJob(); + executeTransferJob(true); } else { this->mainIndexGuard.lock(); ret = this->backendIndex->addVector(blob, label); @@ -83,7 +84,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { double getDistanceFrom_Unsafe(labelType label, const void *blob) const override { auto flat_dist = this->frontendIndex->getDistanceFrom_Unsafe(label, blob); - auto raft_dist = getBackendIndex().getDistanceFrom_Unsafe(label, blob); + auto raft_dist = this->backendIndex->getDistanceFrom_Unsafe(label, blob); return std::fmin(flat_dist, raft_dist); } @@ -92,7 +93,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { auto *transfer_job = reinterpret_cast(job); auto *job_index = reinterpret_cast *>(transfer_job->index); - job_index->executeTransferJob(); + job_index->executeTransferJob(transfer_job->force_); } delete job; } @@ -125,35 +126,43 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { inline void setNProbes(uint32_t n_probes) { this->mainIndexGuard.lock(); - this->getBackendIndex().setNProbes(n_probes); + this->getBackendIndex()->setNProbes(n_probes); this->mainIndexGuard.unlock(); } private: - inline auto &getBackendIndex() const { - return *dynamic_cast *>(this->backendIndex); + inline auto* getBackendIndex() const { + return dynamic_cast *>(this->backendIndex); } - void executeTransferJob() { - auto frontend_lock = std::unique_lock(this->flatIndexGuard); - auto nVectors = this->frontendIndex->indexSize(); + void executeTransferJob(bool force = false) { + size_t nVectors = this->frontendIndex->indexSize(); // No vectors to transfer if (nVectors == 0) { - frontend_lock.unlock(); return; } - // If the backend index is empty, don't transfer less than nLists vectors - this->mainIndexGuard.lock_shared(); - auto main_nVectors = this->backendIndex->indexSize(); - this->mainIndexGuard.unlock_shared(); - if (main_nVectors == 0 && nVectors < getBackendIndex().nLists()) { - frontend_lock.unlock(); + // Don't transfer less than nLists vectors + if (!force) { + auto main_nVectors = this->backendIndex->indexSize(); + size_t min_nVectors = getBackendIndex()->nLists(); + if (nVectors < min_nVectors) { + return; + } + } + + // Check that there are still vectors to transfer after exclusive lock + this->flatIndexGuard.lock(); + nVectors = this->frontendIndex->indexSize(); + if (nVectors == 0) { + this->flatIndexGuard.unlock(); return; } + auto dim = this->backendIndex->getDim(); const auto &vectorBlocks = this->frontendIndex->getVectorBlocks(); auto *vectorData = (DataType *)this->allocator->allocate(nVectors * dim * sizeof(DataType)); + auto *labelData = (labelType *)this->allocator->allocate(nVectors * sizeof(labelType)); // Transfer vectors to a contiguous host buffer auto *curr_ptr = vectorData; @@ -165,13 +174,18 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { curr_ptr += length * dim; } - // Add the vectors to the backend index - auto backend_lock = std::scoped_lock(this->mainIndexGuard); - getBackendIndex().addVectorBatch(vectorData, this->frontendIndex->getLabels().data(), - nVectors); + std::copy(labelData, labelData + nVectors, this->frontendIndex->getLabels().data()); this->frontendIndex->clear(); - frontend_lock.unlock(); + + // Lock the main index before unlocking the front index so that both indexes are not empty at the same time + this->mainIndexGuard.lock(); + this->flatIndexGuard.unlock(); + + // Add the vectors to the backend index + getBackendIndex()->addVectorBatch(vectorData, labelData, nVectors); + this->mainIndexGuard.unlock(); this->allocator->free_allocation(vectorData); + this->allocator->free_allocation(labelData); } #ifdef BUILD_TESTS diff --git a/src/VecSim/index_factories/raft_ivf_factory.cpp b/src/VecSim/index_factories/raft_ivf_factory.cu similarity index 98% rename from src/VecSim/index_factories/raft_ivf_factory.cpp rename to src/VecSim/index_factories/raft_ivf_factory.cu index 6c16324a3..7817a282d 100644 --- a/src/VecSim/index_factories/raft_ivf_factory.cpp +++ b/src/VecSim/index_factories/raft_ivf_factory.cu @@ -1,5 +1,5 @@ #include "VecSim/index_factories/brute_force_factory.h" -#include "VecSim/algorithms/raft_ivf/ivf.h" +#include "VecSim/algorithms/raft_ivf/ivf.cuh" namespace RaftIvfFactory { diff --git a/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp b/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp index 37b797fd9..7ecfcb8b7 100644 --- a/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp +++ b/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp @@ -1,6 +1,6 @@ #include "VecSim/index_factories/brute_force_factory.h" #include "VecSim/algorithms/raft_ivf/ivf_tiered.h" -#include "VecSim/algorithms/raft_ivf/ivf.h" +#include "VecSim/algorithms/raft_ivf/ivf_interface.h" #include "VecSim/index_factories/tiered_factory.h" #include "VecSim/index_factories/raft_ivf_factory.h" @@ -14,7 +14,7 @@ VecSimIndex *NewIndex(const TieredIndexParams *params) using DataType = float; using DistType = float; // initialize raft index - auto *raft_index = reinterpret_cast *>( + auto *raft_index = reinterpret_cast *>( RaftIvfFactory::NewIndex(params->primaryIndexParams)); // initialize brute force index BFParams bf_params = { diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index 0a0c49c94..f61823028 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -38,7 +38,7 @@ typedef enum { } VecSimType; // Algorithm type/library. -typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_RAFT_IVFFLAT, VecSimAlgo_RAFT_IVFPQ, VecSimAlgo_TIERED } VecSimAlgo; +typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_TIERED, VecSimAlgo_RAFT_IVFFLAT, VecSimAlgo_RAFT_IVFPQ } VecSimAlgo; // Distance metric typedef enum { VecSimMetric_L2, VecSimMetric_IP, VecSimMetric_Cosine } VecSimMetric; diff --git a/tests/benchmark/bm_common.h b/tests/benchmark/bm_common.h index 1a3388670..b7670107d 100644 --- a/tests/benchmark/bm_common.h +++ b/tests/benchmark/bm_common.h @@ -1,6 +1,8 @@ #pragma once #include "bm_vecsim_index.h" +#include "VecSim/algorithms/raft_ivf/ivf_tiered.h" + size_t BM_VecSimGeneral::block_size = 1024; @@ -17,6 +19,10 @@ class BM_VecSimCommon : public BM_VecSimIndex { static void RunTopK_HNSW(benchmark::State &st, size_t ef, size_t iter, size_t k, std::atomic_int &correct, unsigned short index_offset = 0, bool is_tiered = false); + static void RunTopK_TieredRaftIVFFlat(benchmark::State &st, size_t iter, size_t k, std::atomic_int &correct, + unsigned short index_offset = 0, bool is_tiered = true); + static void RunTopK_TieredRaftIVFPQ(benchmark::State &st, size_t iter, size_t k, std::atomic_int &correct, + unsigned short index_offset = 0, bool is_tiered = true); // Search for the K closest vectors to the query in the index. K is defined in the // test registration (initialization file). @@ -25,6 +31,12 @@ class BM_VecSimCommon : public BM_VecSimIndex { // with respect to the results returned by the flat index. static void TopK_HNSW(benchmark::State &st, unsigned short index_offset = 0); static void TopK_Tiered(benchmark::State &st, unsigned short index_offset = 0); + // Run TopK using Raft IVF Flat tiered and flat index and calculate the recall of the Raft IVF + // Flat algorithm with respect to the results returned by the flat index. + static void TopK_TieredRaftIVFFlat(benchmark::State &st, unsigned short index_offset = 0); + // Run TopK using both Raft IVF PQ Tiered and flat index and calculate the recall of the Raft IVF + // PQ algorithm with respect to the results returned by the flat index. + static void TopK_TieredRaftIVFPQ(benchmark::State &st, unsigned short index_offset = 0); // Does nothing but returning the index memory. static void Memory_FLAT(benchmark::State &st, unsigned short index_offset = 0); @@ -157,6 +169,54 @@ void BM_VecSimCommon::TopK_Tiered(benchmark::State &st, unsigned s st.counters["num_threads"] = (double)BM_VecSimGeneral::mock_thread_pool.thread_pool_size; } +template +void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, unsigned short index_offset) { + size_t k = st.range(0); + size_t n_probes = st.range(1); + std::atomic_int correct = 0; + std::atomic_int iter = 0; + auto *tiered_index = //reinterpret_cast *>(INDICES[VecSimAlgo_RAFT_IVFFLAT]); + reinterpret_cast *>(INDICES[VecSimAlgo_RAFT_IVFFLAT]); + size_t total_iters = 50; + tiered_index->setNProbes(n_probes); + VecSimQueryReply *all_results[total_iters]; + + auto parallel_knn_search = [](AsyncJob *job) { + auto *search_job = reinterpret_cast(job); + VecSimQueryParams query_params { .batchSize = 1 }; + size_t cur_iter = search_job->iter; + auto results = + VecSimIndex_TopKQuery(INDICES[VecSimAlgo_RAFT_IVFFLAT], QUERIES[cur_iter % N_QUERIES].data(), + search_job->k, &query_params, BY_SCORE); + search_job->all_results[cur_iter] = results; + delete job; + }; + + for (auto _ : st) { + auto search_job = new (tiered_index->getAllocator()) + tieredIndexMock::SearchJobMock(tiered_index->getAllocator(), parallel_knn_search, + tiered_index, k, 0, iter++, all_results); + tiered_index->submitSingleJob(search_job); + if (iter == total_iters) { + BM_VecSimGeneral::mock_thread_pool_raft.thread_pool_wait(); + } + } + + // Measure recall + for (iter = 0; iter < total_iters; iter++) { + auto bf_results = + VecSimIndex_TopKQuery(INDICES[VecSimAlgo_BF + index_offset], + QUERIES[iter % N_QUERIES].data(), k, nullptr, BY_SCORE); + BM_VecSimGeneral::MeasureRecall(all_results[iter], bf_results, correct); + + VecSimQueryReply_Free(bf_results); + VecSimQueryReply_Free(all_results[iter]); + } + + st.counters["Recall"] = (float)correct / (float)(k * iter); + st.counters["num_threads"] = (double)BM_VecSimGeneral::mock_thread_pool.thread_pool_size; +} + #define REGISTER_TopK_BF(BM_CLASS, BM_FUNC) \ BENCHMARK_REGISTER_F(BM_CLASS, BM_FUNC) \ ->Arg(10) \ @@ -189,3 +249,18 @@ void BM_VecSimCommon::TopK_Tiered(benchmark::State &st, unsigned s ->ArgNames({"ef_runtime", "k"}) \ ->Iterations(50) \ ->Unit(benchmark::kMillisecond) + +#define REGISTER_TopK_TieredRaftIVF(BM_CLASS, BM_FUNC) \ + BENCHMARK_REGISTER_F(BM_CLASS, BM_FUNC) \ + ->Args({10, 20}) \ + ->Args({10, 50}) \ + ->Args({10, 150}) \ + ->Args({100, 20}) \ + ->Args({100, 50}) \ + ->Args({100, 150}) \ + ->Args({200, 20}) \ + ->Args({200, 50}) \ + ->Args({200, 150}) \ + ->ArgNames({"k", "n_probes"}) \ + ->Iterations(50) \ + ->Unit(benchmark::kMillisecond) diff --git a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h index ae15b3c9e..aa70a83b6 100644 --- a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h +++ b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h @@ -44,6 +44,11 @@ REGISTER_TopK_HNSW(BM_VecSimCommon, BM_FUNC_NAME(TopK, HNSW)); BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(TopK, Tiered), fp32_index_t) (benchmark::State &st) { TopK_Tiered(st); } REGISTER_TopK_Tiered(BM_VecSimCommon, BM_FUNC_NAME(TopK, Tiered)); +// TopK Tiered RAFT IVF Flat +BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVF), fp32_index_t) +(benchmark::State &st) { TopK_TieredRaftIVFFlat(st); } +REGISTER_TopK_TieredRaftIVF(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVF)); + // Range BF BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimBasics, BM_FUNC_NAME(Range, BF), fp32_index_t) diff --git a/tests/benchmark/bm_vecsim_general.h b/tests/benchmark/bm_vecsim_general.h index 256273faf..d1870461d 100644 --- a/tests/benchmark/bm_vecsim_general.h +++ b/tests/benchmark/bm_vecsim_general.h @@ -70,6 +70,37 @@ class BM_VecSimGeneral : public benchmark::Fixture { return params; } + static VecSimParams createDefaultRaftIvfPQParams(size_t dim, uint32_t nLists = 1024, uint32_t nProbes = 20) { + RaftIvfParams ivfparams = {.dim = dim, + .metric = VecSimMetric_L2, // TODO Cosine + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.5, + .nProbes = nProbes, + .usePQ = true, + .pqBits = 8, + .pqDim = 0, + .codebookKind = RaftIVFPQCodebookKind_PerSubspace, + .lutType = CUDAType_R_32F, + .internalDistanceType = CUDAType_R_32F, + .preferredShmemCarveout = 1.0}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFPQ, .algoParams = {.raftIvfParams = ivfparams}}; + return params; + } + + static VecSimParams createDefaultRaftIvfFlatParams(size_t dim, uint32_t nLists = 1024, uint32_t nProbes = 20, bool adaptiveCenters = true) { + RaftIvfParams ivfparams = {.dim = dim, + .metric = VecSimMetric_L2, // TODO Cosine + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.9, + .nProbes = nProbes, + .usePQ = false, + .adaptiveCenters = adaptiveCenters}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFFLAT, .algoParams = {.raftIvfParams = ivfparams}}; + return params; + } + // Gets HNSWParams or BFParams parameters struct, and creates new VecSimIndex. template static inline VecSimIndex *CreateNewIndex(IndexParams &index_params) { diff --git a/tests/benchmark/bm_vecsim_index.h b/tests/benchmark/bm_vecsim_index.h index 8b2406a81..ad1a1a2af 100644 --- a/tests/benchmark/bm_vecsim_index.h +++ b/tests/benchmark/bm_vecsim_index.h @@ -2,6 +2,7 @@ #include "bm_vecsim_general.h" #include "VecSim/index_factories/tiered_factory.h" +#include "VecSim/index_factories/raft_ivf_tiered_factory.h" template class BM_VecSimIndex : public BM_VecSimGeneral { @@ -111,12 +112,30 @@ void BM_VecSimIndex::Initialize() { // Launch the BG threads loop that takes jobs from the queue and executes them. mock_thread_pool.init_threads(); + // Create RAFFT IVF Flat tiered index. + auto &mock_thread_pool_ivf_flat = BM_VecSimGeneral::mock_thread_pool_raft; + + VecSimParams params_flat = createDefaultRaftIvfFlatParams(dim, 900, 100); + tiered_params = {.jobQueue = &BM_VecSimGeneral::mock_thread_pool_raft.jobQ, + .jobQueueCtx = mock_thread_pool_ivf_flat.ctx, + .submitCb = tieredIndexMock::submit_callback, + .flatBufferLimit = params_flat.algoParams.raftIvfParams.nLists * 5000, + .primaryIndexParams = ¶ms_flat}; + + auto *tiered_raft_ivf_flat_index = + TieredRaftIvfFactory::NewIndex(&tiered_params); + mock_thread_pool_ivf_flat.ctx->index_strong_ref.reset(tiered_raft_ivf_flat_index); + mock_thread_pool_ivf_flat.init_threads(); + + indices.push_back(tiered_raft_ivf_flat_index); + // Add the same vectors to Flat index. for (size_t i = 0; i < n_vectors; ++i) { const char *blob = GetHNSWDataByInternalId(i); // Fot multi value indices, the internal id is not necessarily equal the label. size_t label = CastToHNSW(indices[VecSimAlgo_HNSWLIB])->getExternalLabel(i); VecSimIndex_AddVector(indices[VecSimAlgo_BF], blob, label); + VecSimIndex_AddVector(indices[VecSimAlgo_RAFT_IVFFLAT], blob, label); } // Load the test query vectors form file. Index file path is relative to repository root dir. diff --git a/tests/unit/test_raft_ivf_tiered.cpp b/tests/unit/test_raft_ivf_tiered.cpp index 7d1164b01..bd5bdc7c4 100644 --- a/tests/unit/test_raft_ivf_tiered.cpp +++ b/tests/unit/test_raft_ivf_tiered.cpp @@ -1,6 +1,7 @@ #include "gtest/gtest.h" #include "VecSim/vec_sim.h" #include "VecSim/vec_sim_common.h" +#include "VecSim/algorithms/raft_ivf/ivf_tiered.h" #include "VecSim/index_factories/tiered_factory.h" #include "test_utils.h" #include @@ -15,9 +16,27 @@ class RaftIvfTieredTest : public ::testing::Test { public: using data_t = typename index_type_t::data_t; using dist_t = typename index_type_t::dist_t; + + TieredRaftIvfIndex* createTieredIndex(VecSimParams *params, + tieredIndexMock &mock_thread_pool, + size_t flat_buffer_limit = 0) { + TieredIndexParams params_tiered = { + .jobQueue = &mock_thread_pool.jobQ, + .jobQueueCtx = mock_thread_pool.ctx, + .submitCb = tieredIndexMock::submit_callback, + .flatBufferLimit = flat_buffer_limit, + .primaryIndexParams = params, + }; + auto *tiered_index = TieredFactory::NewIndex(¶ms_tiered); + // Set the created tiered index in the index external context (it will take ownership over + // the index, and we'll need to release the ctx at the end of the test. + mock_thread_pool.ctx->index_strong_ref.reset(tiered_index); + + return reinterpret_cast *>(tiered_index); + } }; -VecSimParams createDefaultPQParams(size_t dim, size_t nLists = 3, size_t nProbes = 3) { +VecSimParams createDefaultPQParams(size_t dim, uint32_t nLists = 3, uint32_t nProbes = 3) { RaftIvfParams ivfparams = {.dim = dim, .metric = VecSimMetric_L2, .nLists = nLists, @@ -35,7 +54,7 @@ VecSimParams createDefaultPQParams(size_t dim, size_t nLists = 3, size_t nProbes return params; } -VecSimParams createDefaultFlatParams(size_t dim, size_t nLists = 3, size_t nProbes = 3) { +VecSimParams createDefaultFlatParams(size_t dim, uint32_t nLists = 3, uint32_t nProbes = 3) { RaftIvfParams ivfparams = {.dim = dim, .metric = VecSimMetric_L2, .nLists = nLists, @@ -47,36 +66,18 @@ VecSimParams createDefaultFlatParams(size_t dim, size_t nLists = 3, size_t nProb return params; } -VecSimIndex* createTieredIndex(VecSimParams *params, - tieredIndexMock &mock_thread_pool, - size_t flat_buffer_limit = 0) { - TieredIndexParams params_tiered = { - .jobQueue = &mock_thread_pool.jobQ, - .jobQueueCtx = mock_thread_pool.ctx, - .submitCb = tieredIndexMock::submit_callback, - .flatBufferLimit = flat_buffer_limit, - .primaryIndexParams = params, - }; - auto *tiered_index = TieredFactory::NewIndex(¶ms_tiered); - // Set the created tiered index in the index external context (it will take ownership over - // the index, and we'll need to release the ctx at the end of the test. - mock_thread_pool.ctx->index_strong_ref.reset(tiered_index); - - return tiered_index; -} - using DataTypeSetFloat = ::testing::Types>; TYPED_TEST_SUITE(RaftIvfTieredTest, DataTypeSetFloat); -TYPED_TEST(RaftIvfTieredTest, RaftIVFTiered_PQ_add_sanity_test) { +TYPED_TEST(RaftIvfTieredTest, end_to_end) { size_t dim = 4; size_t flat_buffer_limit = 3; size_t nLists = 2; VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); auto mock_thread_pool = tieredIndexMock(); - auto *index = createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + auto *index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); mock_thread_pool.init_threads(); VecSimQueryParams queryParams = {.batchSize = 1}; @@ -89,13 +90,7 @@ TYPED_TEST(RaftIvfTieredTest, RaftIVFTiered_PQ_add_sanity_test) { std::vector c_vec(dim, (TEST_DATA_T)4); std::vector d_vec(dim, (TEST_DATA_T)5); std::vector zero_vec(dim, (TEST_DATA_T)0); - /*for (size_t i = 0; i < dim; i++) { - a[i] = (TEST_DATA_T)1; - b[i] = (TEST_DATA_T)2; - c[i] = (TEST_DATA_T)4; - d[i] = (TEST_DATA_T)5; - zero[i] = (TEST_DATA_T)0; - }*/ + auto inserted_vectors = std::vector>{a_vec, b_vec, c_vec, d_vec}; // Search for vectors when the index is empty. From 6bab5e387dd404692702def886ee3444e9086ac9 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 20 Nov 2023 12:13:22 +0100 Subject: [PATCH 16/28] Update Tiered vector ingestion --- CMakeLists.txt | 5 ++- src/VecSim/algorithms/raft_ivf/ivf.cuh | 42 ++++++++----------- src/VecSim/algorithms/raft_ivf/ivf_tiered.h | 16 ++++--- src/VecSim/vec_sim_common.h | 6 +++ tests/benchmark/CMakeLists.txt | 2 +- tests/benchmark/bm_common.h | 11 +++++ .../bm_basics_initialize_fp32.h | 5 +++ tests/benchmark/bm_vecsim_general.h | 2 +- tests/benchmark/bm_vecsim_index.h | 5 ++- 9 files changed, 60 insertions(+), 34 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8df0bbaa4..0fbef417f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,10 @@ if (USE_CUDA) include(cmake/raft.cmake) # Required flags for compiling RAFT set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -std=c++17") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -std=c++17") + set(CMAKE_CUDA_FLAGS_RELEASE "-O3") + set(CMAKE_CUDA_FLAGS_DEBUG "-g") + endif() # Only do these if this is the main project, and not if it is included through add_subdirectory diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index 46e81f80e..96cdd1e63 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -83,9 +83,9 @@ private: raft::neighbors::ivf_pq::index_params>; using search_params_t = std::variant; - //using internal_idx_t = std::int64_t; - using index_flat_t = raft::neighbors::ivf_flat::index; - using index_pq_t = raft::neighbors::ivf_pq::index; + using internal_idx_t = std::uint64_t; + using index_flat_t = raft::neighbors::ivf_flat::index; + using index_pq_t = raft::neighbors::ivf_pq::index; using ann_index_t = std::variant; public: @@ -128,27 +128,18 @@ public: }, search_params_); - raft::device_resources_manager::set_streams_per_device(16); // TODO: use env variable - raft::device_resources_manager::set_stream_pools_per_device(16); - // Create a 5 GB memory pool. Passing std::nullopt will allow - // the pool to grow to the available memory of the device. - raft::device_resources_manager::set_mem_pool(size_t{5000} << 20, std::nullopt); } int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override { return addVectorBatch(vector_data, &label, 1, auxiliaryCtx); } - virtual int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, + int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, void *auxiliaryCtx = nullptr) override { - auto& res = raft::device_resources_manager::get_device_resources(); - // Convert labels to internal data type - /*auto label_original = std::vector(label, label + batch_size); - auto label_converted = - std::vector(label_original.begin(), label_original.end());*/ + const auto& res = raft::device_resources_manager::get_device_resources(); // Allocate memory on device to hold vectors to be added auto vector_data_gpu = - raft::make_device_matrix(res, batch_size, this->dim); + raft::make_device_matrix(res, batch_size, this->dim); // Allocate memory on device to hold vector labels - auto label_gpu = raft::make_device_vector(res, batch_size); + auto label_gpu = raft::make_device_vector(res, batch_size); // Copy vector data to previously allocated device buffer raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), @@ -201,7 +192,7 @@ public: } VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, VecSimQueryParams *queryParams) const override { - auto& res = raft::device_resources_manager::get_device_resources(); + const auto& res = raft::device_resources_manager::get_device_resources(); auto result_list = new VecSimQueryReply(this->allocator); auto nVectors = this->indexSize(); if (nVectors == 0 || k == 0 || !index_.has_value()) { @@ -211,29 +202,29 @@ public: // index k = std::min(k, nVectors); // Allocate memory on device for search vector - auto vector_data_gpu = raft::make_device_matrix(res, 1, this->dim); + auto vector_data_gpu = raft::make_device_matrix(res, 1, this->dim); // Allocate memory on device for neighbor and distance results - auto neighbors_gpu = raft::make_device_matrix(res, 1, k); - auto distances_gpu = raft::make_device_matrix(res, 1, k); + auto neighbors_gpu = raft::make_device_matrix(res, 1, k); + auto distances_gpu = raft::make_device_matrix(res, 1, k); // Copy query vector to device raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), this->dim, res.get_stream()); // Perform correct search based on index type if (std::holds_alternative(*index_)) { - raft::neighbors::ivf_flat::search( + raft::neighbors::ivf_flat::search( res, std::get(search_params_), std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), neighbors_gpu.view(), distances_gpu.view()); } else { - raft::neighbors::ivf_pq::search( + raft::neighbors::ivf_pq::search( res, std::get(search_params_), std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), neighbors_gpu.view(), distances_gpu.view()); } // Allocate host buffers to hold returned results - auto neighbors = vecsim_stl::vector(k, this->allocator); + auto neighbors = vecsim_stl::vector(k, this->allocator); auto distances = vecsim_stl::vector(k, this->allocator); // Copy data back from device to host raft::copy(neighbors.data(), neighbors_gpu.data_handle(), k, res.get_stream()); @@ -297,10 +288,13 @@ public: info.commonInfo = this->getCommonInfo(); info.raftIvfInfo.nLists = nLists(); if (std::holds_alternative(build_params_)) { + info.commonInfo.basicInfo.algo = VecSimAlgo_RAFT_IVFPQ; const auto build_params_pq = std::get(build_params_); info.raftIvfInfo.pqBits = build_params_pq.pq_bits; info.raftIvfInfo.pqDim = build_params_pq.pq_dim; + } else { + info.commonInfo.basicInfo.algo = VecSimAlgo_RAFT_IVFFLAT; } return info; } @@ -320,5 +314,5 @@ private: // insertion std::optional index_; // Bitset used for deleteVectors and search filtering. - //raft::core::bitset deleted_indices_; + std::optional> deleted_indices_; }; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h index 4a66be88b..cb719d8a6 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h @@ -21,6 +21,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { assert( raftIvfIndex->nLists() < this->flatBufferLimit && "The flat buffer limit must be greater than the number of lists in the backend index"); + this->minVectorsInit = std::max((size_t)1, tieredParams.specificParams.tieredRaftIvfParams.minVectorsInit); } ~TieredRaftIvfIndex() { // Delete all the pending jobs @@ -74,8 +75,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { } size_t indexLabelCount() const override { - // TODO(wphicks) Count unique labels between both indexes - return 0; + return indexSize(); } size_t indexCapacity() const override { @@ -131,6 +131,8 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { } private: + size_t minVectorsInit = 1; + inline auto* getBackendIndex() const { return dynamic_cast *>(this->backendIndex); } @@ -142,10 +144,14 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { return; } - // Don't transfer less than nLists vectors + // Don't transfer less than nLists vectors, unless the backend index is not + // empty (for kmeans initialization purposes) and force is true if (!force) { auto main_nVectors = this->backendIndex->indexSize(); size_t min_nVectors = getBackendIndex()->nLists(); + if (main_nVectors == 0) + min_nVectors *= this->minVectorsInit; + if (nVectors < min_nVectors) { return; } @@ -170,11 +176,11 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { const auto *in_begin = reinterpret_cast(vectorBlocks[block_id].getElement(0)); auto length = vectorBlocks[block_id].getLength(); - std::copy(in_begin, in_begin + (length * dim), curr_ptr); + std::copy_n(in_begin, length * dim, curr_ptr); curr_ptr += length * dim; } - std::copy(labelData, labelData + nVectors, this->frontendIndex->getLabels().data()); + std::copy_n(this->frontendIndex->getLabels().data(), nVectors, labelData); this->frontendIndex->clear(); // Lock the main index before unlocking the front index so that both indexes are not empty at the same time diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index f61823028..63335a674 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -124,6 +124,11 @@ typedef struct { // all the ready swap jobs in a batch. } TieredHNSWParams; +// A struct that contains Raft IVF tiered index specific params. +typedef struct { + size_t minVectorsInit; // Min. number of vectors per list in Tiered index to init IVF index +} TieredRAFTIVFParams; + // A struct that contains the common tiered index params. typedef struct { void *jobQueue; // External queue that holds the jobs. @@ -134,6 +139,7 @@ typedef struct { VecSimParams *primaryIndexParams; // Parameters to initialize the index. union { TieredHNSWParams tieredHnswParams; + TieredRAFTIVFParams tieredRaftIvfParams; } specificParams; } TieredIndexParams; diff --git a/tests/benchmark/CMakeLists.txt b/tests/benchmark/CMakeLists.txt index 8ef952187..5f5256b4f 100644 --- a/tests/benchmark/CMakeLists.txt +++ b/tests/benchmark/CMakeLists.txt @@ -22,7 +22,7 @@ foreach(benchmark IN ITEMS ${BENCHMARKS}) # NOTE: mock_thread_pool.cpp should appear *before* the benchmark files, so we can ensure that the thread pool # globals are initialized before we use them in the benchmark classes (as globals initialization is done by order). add_executable(bm_${benchmark} ../utils/mock_thread_pool.cpp bm_vecsim_general.cpp run_files/bm_${benchmark}.cpp) - target_link_libraries(bm_${benchmark} VectorSimilarity benchmark::benchmark) + target_link_libraries(bm_${benchmark} VectorSimilarity benchmark::benchmark $<$:raft::raft>) endforeach() # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # diff --git a/tests/benchmark/bm_common.h b/tests/benchmark/bm_common.h index b7670107d..7d01abd08 100644 --- a/tests/benchmark/bm_common.h +++ b/tests/benchmark/bm_common.h @@ -42,6 +42,7 @@ class BM_VecSimCommon : public BM_VecSimIndex { static void Memory_FLAT(benchmark::State &st, unsigned short index_offset = 0); static void Memory_HNSW(benchmark::State &st, unsigned short index_offset = 0); static void Memory_Tiered(benchmark::State &st, unsigned short index_offset = 0); + static void Memory_TieredRaftIVFFlat(benchmark::State &st, unsigned short index_offset = 0); }; template @@ -94,6 +95,16 @@ void BM_VecSimCommon::Memory_Tiered(benchmark::State &st, st.counters["memory"] = (double)VecSimIndex_Info(INDICES[VecSimAlgo_TIERED + index_offset]).commonInfo.memory; } +template +void BM_VecSimCommon::Memory_TieredRaftIVFFlat(benchmark::State &st, + unsigned short index_offset) { + + for (auto _ : st) { + // Do nothing... + } + st.counters["memory"] = + (double)VecSimIndex_Info(INDICES[VecSimAlgo_RAFT_IVFFLAT + index_offset]).commonInfo.memory; +} // TopK search BM diff --git a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h index aa70a83b6..4b22e77e6 100644 --- a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h +++ b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h @@ -20,6 +20,11 @@ BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, Tiered), fp32_ (benchmark::State &st) { Memory_Tiered(st); } BENCHMARK_REGISTER_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, Tiered))->Iterations(1); +// Memory TieredRaftIVFFlat +BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFFlat), fp32_index_t) +(benchmark::State &st) { Memory_TieredRaftIVFFlat(st); } +BENCHMARK_REGISTER_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFFlat))->Iterations(1); + // AddLabel BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimBasics, BM_ADD_LABEL, fp32_index_t) (benchmark::State &st) { AddLabel(st); } diff --git a/tests/benchmark/bm_vecsim_general.h b/tests/benchmark/bm_vecsim_general.h index d1870461d..1d7790856 100644 --- a/tests/benchmark/bm_vecsim_general.h +++ b/tests/benchmark/bm_vecsim_general.h @@ -93,7 +93,7 @@ class BM_VecSimGeneral : public benchmark::Fixture { .metric = VecSimMetric_L2, // TODO Cosine .nLists = nLists, .kmeans_nIters = 20, - .kmeans_trainsetFraction = 0.9, + .kmeans_trainsetFraction = 0.5, .nProbes = nProbes, .usePQ = false, .adaptiveCenters = adaptiveCenters}; diff --git a/tests/benchmark/bm_vecsim_index.h b/tests/benchmark/bm_vecsim_index.h index ad1a1a2af..45e281983 100644 --- a/tests/benchmark/bm_vecsim_index.h +++ b/tests/benchmark/bm_vecsim_index.h @@ -115,12 +115,13 @@ void BM_VecSimIndex::Initialize() { // Create RAFFT IVF Flat tiered index. auto &mock_thread_pool_ivf_flat = BM_VecSimGeneral::mock_thread_pool_raft; - VecSimParams params_flat = createDefaultRaftIvfFlatParams(dim, 900, 100); + VecSimParams params_flat = createDefaultRaftIvfFlatParams(dim, 1000, 100); tiered_params = {.jobQueue = &BM_VecSimGeneral::mock_thread_pool_raft.jobQ, .jobQueueCtx = mock_thread_pool_ivf_flat.ctx, .submitCb = tieredIndexMock::submit_callback, .flatBufferLimit = params_flat.algoParams.raftIvfParams.nLists * 5000, - .primaryIndexParams = ¶ms_flat}; + .primaryIndexParams = ¶ms_flat, + .specificParams = {.tieredRaftIvfParams = {.minVectorsInit = 100}}}; auto *tiered_raft_ivf_flat_index = TieredRaftIvfFactory::NewIndex(&tiered_params); From d072d3ab6fa6068538602c63f2126fdb3733112c Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 21 Nov 2023 17:57:18 +0100 Subject: [PATCH 17/28] Update Tiered index. Add vector deletion code --- src/VecSim/algorithms/raft_ivf/ivf.cuh | 76 ++++++++++-- .../algorithms/raft_ivf/ivf_interface.h | 1 + src/VecSim/algorithms/raft_ivf/ivf_tiered.h | 49 ++++++-- tests/benchmark/bm_vecsim_general.h | 2 +- tests/benchmark/bm_vecsim_index.h | 1 + tests/unit/test_raft_ivf_tiered.cpp | 117 ++++++++++++++++++ 6 files changed, 223 insertions(+), 23 deletions(-) diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index 96cdd1e63..29f1e221f 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -17,6 +17,7 @@ #include "VecSim/algorithms/raft_ivf/ivf_interface.h" // RaftIvfInterface #include "VecSim/memory/vecsim_malloc.h" +#include #include #include #include @@ -72,6 +73,16 @@ inline auto constexpr GetCudaType(CudaType vss_type) { return result; } +void init_raft_resources() { + auto static init_flag = std::once_flag{}; + std::call_once(init_flag, []() { + raft::device_resources_manager::set_streams_per_device(8); // TODO: use env variable + raft::device_resources_manager::set_stream_pools_per_device(8); + // Create a memory pool with half of the available GPU memory. + raft::device_resources_manager::set_mem_pool(); + }); +} + template struct RaftIvfIndex : public RaftIvfInterface { using data_type = DataType; @@ -95,7 +106,10 @@ public: : build_params_t{std::in_place_index<0>}}, search_params_{raftIvfParams->usePQ ? search_params_t{std::in_place_index<1>} : search_params_t{std::in_place_index<0>}}, - index_{std::nullopt} { + index_{std::nullopt}, + deleted_indices_{std::nullopt}, + idToLabelLookup_{this->allocator}, + labelToIdLookup_{this->allocator} { std::visit( [raftIvfParams](auto &&inner) { inner.metric = GetRaftDistanceType(raftIvfParams->metric); @@ -128,6 +142,7 @@ public: }, search_params_); + init_raft_resources(); } int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override { return addVectorBatch(vector_data, &label, 1, auxiliaryCtx); @@ -138,45 +153,72 @@ public: // Allocate memory on device to hold vectors to be added auto vector_data_gpu = raft::make_device_matrix(res, batch_size, this->dim); - // Allocate memory on device to hold vector labels - auto label_gpu = raft::make_device_vector(res, batch_size); // Copy vector data to previously allocated device buffer raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), this->dim * batch_size, res.get_stream()); - // Copy label data to previously allocated device buffer - raft::copy(label_gpu.data_handle(), label, batch_size, res.get_stream()); + std::optional> label_opt = std::nullopt; if (std::holds_alternative(build_params_)) { if (!index_) { index_ = raft::neighbors::ivf_flat::build( res, std::get(build_params_), raft::make_const_mdspan(vector_data_gpu.view())); + deleted_indices_ = {raft::core::bitset(res, 0)}; } raft::neighbors::ivf_flat::extend( res, raft::make_const_mdspan(vector_data_gpu.view()), - std::make_optional(raft::make_const_mdspan(label_gpu.view())), + label_opt, &std::get(*index_)); } else { if (!index_) { index_ = raft::neighbors::ivf_pq::build( res, std::get(build_params_), raft::make_const_mdspan(vector_data_gpu.view())); + deleted_indices_ = {raft::core::bitset(res, 0)}; } raft::neighbors::ivf_pq::extend( res, raft::make_const_mdspan(vector_data_gpu.view()), - std::make_optional(raft::make_const_mdspan(label_gpu.view())), + label_opt, &std::get(*index_)); } + internal_idx_t last_id = this->indexSize(); + internal_idx_t first_id = last_id - batch_size; + + // Add labels to internal idToLabelLookup_ mapping + this->idToLabelLookup_.insert(this->idToLabelLookup_.end(), label, label + batch_size); + for (auto i = 0; i < batch_size; ++i) { + this->labelToIdLookup_[label[i]] = first_id + i; + } + + // Update the size of the deleted indices bitset + deleted_indices_->resize(res, deleted_indices_->size() + batch_size); + // Ensure that above operation has executed on device before // returning from this function on host res.sync_stream(); return batch_size; } int deleteVector(labelType label) override { - assert(!"deleteVector not implemented"); - return 0; + auto search = labelToIdLookup_.find(label); + if (search == labelToIdLookup_.end()) { + return 0; + } + const auto& res = raft::device_resources_manager::get_device_resources(); + // Create GPU vector to hold ids to mark as deleted + internal_idx_t id = search->second; + auto id_gpu = raft::make_device_vector(res, 1); + raft::copy(id_gpu.data_handle(), &id, 1, res.get_stream()); + // Mark the id as deleted + deleted_indices_->set(res, raft::make_const_mdspan(id_gpu.view()), false); + + // Remove label from internal labelToIdLookup_ mapping + labelToIdLookup_.erase(search); + // Ensure that above operation has executed on device before + // returning from this function on host + res.sync_stream(); + return 1; } double getDistanceFrom_Unsafe(labelType label, const void *vector_data) const override { assert(!"getDistanceFrom not implemented"); @@ -186,9 +228,16 @@ public: assert(!"indexCapacity not implemented"); return 0; } + inline vecsim_stl::set getLabelsSet() const override { + vecsim_stl::set result(this->allocator); + for (auto const &pair : labelToIdLookup_) { + result.insert(pair.first); + } + return result; + } // void increaseCapacity() override { assert(!"increaseCapacity not implemented"); } inline size_t indexLabelCount() const override { - return this->indexSize(); // TODO: Return unique counts + return this->labelToIdLookup_.size(); } VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, VecSimQueryParams *queryParams) const override { @@ -236,7 +285,7 @@ public: result_list->results.resize(k); for (auto i = 0; i < k; ++i) { - result_list->results[i].id = labelType{neighbors[i]}; + result_list->results[i].id = idToLabelLookup_[neighbors[i]]; result_list->results[i].score = distances[i]; } @@ -314,5 +363,8 @@ private: // insertion std::optional index_; // Bitset used for deleteVectors and search filtering. - std::optional> deleted_indices_; + std::optional> deleted_indices_; + + vecsim_stl::vector idToLabelLookup_; + vecsim_stl::unordered_map labelToIdLookup_; }; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_interface.h b/src/VecSim/algorithms/raft_ivf/ivf_interface.h index 165c5455e..db32b510f 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_interface.h +++ b/src/VecSim/algorithms/raft_ivf/ivf_interface.h @@ -17,4 +17,5 @@ struct RaftIvfInterface : public VecSimIndexAbstract virtual int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, void *auxiliaryCtx = nullptr) = 0; + virtual vecsim_stl::set getLabelsSet() const = 0; }; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h index cb719d8a6..ab6e03aaa 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h @@ -56,13 +56,27 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { } int deleteVector(labelType label) override { - this->flatIndexGuard.lock(); - auto result = this->frontendIndex->deleteVector(label); + int num_deleted_vectors = 0; + this->flatIndexGuard.lock_shared(); + if (this->frontendIndex->isLabelExists(label)) { + this->flatIndexGuard.unlock_shared(); + this->flatIndexGuard.lock(); + // Check again if the label exists, as it may have been removed while we released the lock. + if (this->frontendIndex->isLabelExists(label)) { + // Remove every id that corresponds the label from the flat buffer. + auto updated_ids = this->frontendIndex->deleteVectorAndGetUpdatedIds(label); + num_deleted_vectors += updated_ids.size(); + } + this->flatIndexGuard.unlock(); + } else { + this->flatIndexGuard.unlock_shared(); + } + + // delete in place. TODO: Add async job for this this->mainIndexGuard.lock(); - this->flatIndexGuard.unlock(); - result += this->backendIndex->deleteVector(label); + num_deleted_vectors += this->backendIndex->deleteVector(label); this->mainIndexGuard.unlock(); - return result; + return num_deleted_vectors; } size_t indexSize() const override { @@ -75,7 +89,16 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { } size_t indexLabelCount() const override { - return indexSize(); + this->flatIndexGuard.lock_shared(); + this->mainIndexGuard.lock_shared(); + auto flat_labels = this->frontendIndex->getLabelsSet(); + auto raft_ivf_labels = this->getBackendIndex()->getLabelsSet(); + this->flatIndexGuard.unlock_shared(); + this->mainIndexGuard.unlock_shared(); + std::vector output; + std::set_union(flat_labels.begin(), flat_labels.end(), raft_ivf_labels.begin(), raft_ivf_labels.end(), + std::back_inserter(output)); + return output.size(); } size_t indexCapacity() const override { @@ -144,13 +167,13 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { return; } - // Don't transfer less than nLists vectors, unless the backend index is not - // empty (for kmeans initialization purposes) and force is true + // Don't transfer less than nLists * minVectorsInit vectors if the backend index is empty + // (for kmeans initialization purposes) if (!force) { auto main_nVectors = this->backendIndex->indexSize(); - size_t min_nVectors = getBackendIndex()->nLists(); + size_t min_nVectors = 1; if (main_nVectors == 0) - min_nVectors *= this->minVectorsInit; + min_nVectors = this->minVectorsInit * getBackendIndex()->nLists(); if (nVectors < min_nVectors) { return; @@ -197,5 +220,11 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { #ifdef BUILD_TESTS INDEX_TEST_FRIEND_CLASS(BM_VecSimBasics) INDEX_TEST_FRIEND_CLASS(BM_VecSimCommon) + INDEX_TEST_FRIEND_CLASS(BM_VecSimIndex); + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_transferJob_Test) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_transferJobAsync_Test) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_transferJob_inplace_Test) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_deleteVector_backend_Test) #endif }; diff --git a/tests/benchmark/bm_vecsim_general.h b/tests/benchmark/bm_vecsim_general.h index 1d7790856..8f60aaa92 100644 --- a/tests/benchmark/bm_vecsim_general.h +++ b/tests/benchmark/bm_vecsim_general.h @@ -72,7 +72,7 @@ class BM_VecSimGeneral : public benchmark::Fixture { static VecSimParams createDefaultRaftIvfPQParams(size_t dim, uint32_t nLists = 1024, uint32_t nProbes = 20) { RaftIvfParams ivfparams = {.dim = dim, - .metric = VecSimMetric_L2, // TODO Cosine + .metric = VecSimMetric_L2, .nLists = nLists, .kmeans_nIters = 20, .kmeans_trainsetFraction = 0.5, diff --git a/tests/benchmark/bm_vecsim_index.h b/tests/benchmark/bm_vecsim_index.h index 45e281983..20a66b353 100644 --- a/tests/benchmark/bm_vecsim_index.h +++ b/tests/benchmark/bm_vecsim_index.h @@ -138,6 +138,7 @@ void BM_VecSimIndex::Initialize() { VecSimIndex_AddVector(indices[VecSimAlgo_BF], blob, label); VecSimIndex_AddVector(indices[VecSimAlgo_RAFT_IVFFLAT], blob, label); } + mock_thread_pool_ivf_flat.thread_pool_wait(100); // Load the test query vectors form file. Index file path is relative to repository root dir. loadTestVectors(AttachRootPath(test_queries_file), type); diff --git a/tests/unit/test_raft_ivf_tiered.cpp b/tests/unit/test_raft_ivf_tiered.cpp index bd5bdc7c4..978a24e9a 100644 --- a/tests/unit/test_raft_ivf_tiered.cpp +++ b/tests/unit/test_raft_ivf_tiered.cpp @@ -124,3 +124,120 @@ TYPED_TEST(RaftIvfTieredTest, end_to_end) { runTopKSearchTest(index, c_vec.data(), k, ver_res_c); } +TYPED_TEST(RaftIvfTieredTest, transferJob) { + // Create RAFT Tiered index instance with a mock queue. + + size_t dim = 4; + size_t flat_buffer_limit = 3; + size_t nLists = 1; + + VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + auto allocator = tiered_index->getAllocator(); + + VecSimQueryParams queryParams = {.batchSize = 1}; + + + // Create a vector and add it to the tiered index. + labelType vec_label = 1; + TEST_DATA_T vector[dim]; + GenerateVector(vector, dim, vec_label); + VecSimIndex_AddVector(tiered_index, vector, vec_label); + ASSERT_EQ(tiered_index->indexSize(), 1); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 1); + ASSERT_EQ(tiered_index->frontendIndex->getDistanceFrom_Unsafe(vec_label, vector), 0); + + // Execute the insert job manually (in a synchronous manner). + ASSERT_EQ(mock_thread_pool.jobQ.size(), 1); + auto *insertion_job = reinterpret_cast(mock_thread_pool.jobQ.front().job); + ASSERT_EQ(insertion_job->jobType, RAFT_TRANSFER_JOB); + + mock_thread_pool.thread_iteration(); + ASSERT_EQ(tiered_index->indexSize(), 1); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), 1); + // RAFT IVF index should have allocated a single block, while flat index should remove the + // block. + ASSERT_EQ(tiered_index->frontendIndex->indexCapacity(), 0); + // After the execution, the job should be removed from the job queue. + ASSERT_EQ(mock_thread_pool.jobQ.size(), 0); +} + +TYPED_TEST(RaftIvfTieredTest, transferJobAsync) { + size_t dim = 32; + size_t n = 500; + size_t nLists = 120; + size_t flat_buffer_limit = 160; + + size_t k = 1; + + // Create RaftIvfTiered index instance with a mock queue. + VecSimParams params = createDefaultFlatParams(dim, nLists, 20); + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + + // Launch the BG threads loop that takes jobs from the queue and executes them. + mock_thread_pool.init_threads(); + // Insert vectors + for (size_t i = 0; i < n; i++) { + GenerateAndAddVector(tiered_index, dim, i, i); + } + + mock_thread_pool.thread_pool_join(); + // Verify that the vectors were inserted to RaftIvf as expected, that the jobqueue is empty, + ASSERT_EQ(tiered_index->indexSize(), n); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), n); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); + ASSERT_EQ(mock_thread_pool.jobQ.size(), 0); + // Verify that the vectors were inserted to RaftIvf as expected + for (size_t i = 0; i < size_t{n / 10}; i++) { + TEST_DATA_T expected_vector[dim]; + GenerateVector(expected_vector, dim, i); + VecSimQueryReply *res = VecSimIndex_TopKQuery(tiered_index->backendIndex, expected_vector, k, nullptr, BY_SCORE); + ASSERT_EQ(VecSimQueryReply_GetCode(res), VecSim_QueryReply_OK); + ASSERT_EQ(VecSimQueryReply_Len(res), k) + ASSERT_EQ(res->results[0].id, i); + ASSERT_EQ(res->results[0].score, 0); + VecSimQueryReply_Free(res); + } +} + +TYPED_TEST(RaftIvfTieredTest, transferJob_inplace) { + size_t dim = 32; + size_t n = 200; + size_t nLists = 120; + size_t flat_buffer_limit = 160; + + size_t k = 1; + + // Create RaftIvfTiered index instance with a mock queue. + VecSimParams params = createDefaultFlatParams(dim, nLists, 20); + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + + // In the absence of BG threads to takes jobs from the queue, the tiered index should + // transfer in place when flat_buffer is over the limit. + for (size_t i = 0; i < n; i++) { + GenerateAndAddVector(tiered_index, dim, i, i); + } + + ASSERT_EQ(tiered_index->indexSize(), n); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), flat_buffer_limit); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), n - flat_buffer_limit); + + // Run another batch of insertion. The tiered index should transfer inplace again. + for (size_t i = n; i < n * 2; i++) { + GenerateAndAddVector(tiered_index, dim, i, i); + } + ASSERT_EQ(tiered_index->indexSize(), 2 * n); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), flat_buffer_limit * 2); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 2 * (n - flat_buffer_limit)); + + // Run a thread loop iteration. The thread should transfer the rest of the vectors to the backend index. + mock_thread_pool.thread_iteration(); + ASSERT_EQ(tiered_index->indexSize(), 2 * n); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), 2 * n); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); +} + From cfa190b85fc9911db04d0428788f116fcf4c05b0 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 23 Nov 2023 15:00:42 +0100 Subject: [PATCH 18/28] Add search bitset filter --- src/VecSim/algorithms/raft_ivf/ivf.cuh | 31 ++++++++---- src/VecSim/algorithms/raft_ivf/ivf_tiered.h | 20 ++++---- tests/unit/test_raft_ivf_tiered.cpp | 52 ++++++++++++++++++++- 3 files changed, 83 insertions(+), 20 deletions(-) diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index 29f1e221f..61df8507d 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -22,8 +22,10 @@ #include #include #include +#include #include #include +#include inline auto constexpr GetRaftDistanceType(VecSimMetric vsm) { auto result = raft::distance::DistanceType{}; @@ -108,6 +110,7 @@ public: : search_params_t{std::in_place_index<0>}}, index_{std::nullopt}, deleted_indices_{std::nullopt}, + numDeleted_{0}, idToLabelLookup_{this->allocator}, labelToIdLookup_{this->allocator} { std::visit( @@ -157,8 +160,14 @@ public: // Copy vector data to previously allocated device buffer raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), this->dim * batch_size, res.get_stream()); - std::optional> label_opt = std::nullopt; + // Create GPU vector to hold ids + internal_idx_t first_id = this->indexSize(); + internal_idx_t last_id = first_id + batch_size; + auto ids = raft::make_device_vector(res, batch_size); + raft::linalg::range(ids.data_handle(), first_id, last_id, res.get_stream()); + + // Build index if it does not exist, and extend it with the new vectors and their ids if (std::holds_alternative(build_params_)) { if (!index_) { index_ = raft::neighbors::ivf_flat::build( @@ -168,7 +177,7 @@ public: } raft::neighbors::ivf_flat::extend( res, raft::make_const_mdspan(vector_data_gpu.view()), - label_opt, + {raft::make_const_mdspan(ids.view())}, &std::get(*index_)); } else { if (!index_) { @@ -179,12 +188,10 @@ public: } raft::neighbors::ivf_pq::extend( res, raft::make_const_mdspan(vector_data_gpu.view()), - label_opt, + {raft::make_const_mdspan(ids.view())}, &std::get(*index_)); } - internal_idx_t last_id = this->indexSize(); - internal_idx_t first_id = last_id - batch_size; // Add labels to internal idToLabelLookup_ mapping this->idToLabelLookup_.insert(this->idToLabelLookup_.end(), label, label + batch_size); @@ -201,6 +208,7 @@ public: return batch_size; } int deleteVector(labelType label) override { + // Check if label exists in internal labelToIdLookup_ mapping auto search = labelToIdLookup_.find(label); if (search == labelToIdLookup_.end()) { return 0; @@ -218,6 +226,7 @@ public: // Ensure that above operation has executed on device before // returning from this function on host res.sync_stream(); + this->numDeleted_ += 1; return 1; } double getDistanceFrom_Unsafe(labelType label, const void *vector_data) const override { @@ -258,18 +267,19 @@ public: // Copy query vector to device raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), this->dim, res.get_stream()); + auto bitset_filter = raft::neighbors::filtering::bitset_filter(deleted_indices_->view()); // Perform correct search based on index type if (std::holds_alternative(*index_)) { - raft::neighbors::ivf_flat::search( + raft::neighbors::ivf_flat::search_with_filtering( res, std::get(search_params_), std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), - neighbors_gpu.view(), distances_gpu.view()); + neighbors_gpu.view(), distances_gpu.view(), bitset_filter); } else { - raft::neighbors::ivf_pq::search( + raft::neighbors::ivf_pq::search_with_filtering( res, std::get(search_params_), std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), - neighbors_gpu.view(), distances_gpu.view()); + neighbors_gpu.view(), distances_gpu.view(), bitset_filter); } // Allocate host buffers to hold returned results @@ -320,7 +330,7 @@ public: if (index_) { result = std::visit([](auto &&index) { return index.size(); }, *index_); } - return result; + return result - this->numDeleted_; } VecSimIndexBasicInfo basicInfo() const override { VecSimIndexBasicInfo info = this->getBasicInfo(); @@ -364,6 +374,7 @@ private: std::optional index_; // Bitset used for deleteVectors and search filtering. std::optional> deleted_indices_; + internal_idx_t numDeleted_ = 0; vecsim_stl::vector idToLabelLookup_; vecsim_stl::unordered_map labelToIdLookup_; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h index ab6e03aaa..9625d19f4 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h @@ -33,17 +33,19 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { if (this->frontendIndex->indexSize() >= this->flatBufferLimit) { // If the backend index is empty, build it with all the vectors // Otherwise, just add the vector to the backend index - if (this->backendIndex->indexSize() == 0) { - executeTransferJob(true); - } else { - this->mainIndexGuard.lock(); - ret = this->backendIndex->addVector(blob, label); - this->mainIndexGuard.unlock(); - return ret; - } + executeTransferJob(true); + } + + // If the backend index is already built and that the write mode is in place + // add the vector to the backend index + if (this->backendIndex->indexSize() > 0 && this->getWriteMode() == VecSim_WriteInPlace) { + this->mainIndexGuard.lock(); + ret = this->backendIndex->addVector(blob, label); + this->mainIndexGuard.unlock(); + return ret; } - // Add the vector to the flat index + // Otherwise, add the vector to the flat index this->flatIndexGuard.lock(); ret = this->frontendIndex->addVector(blob, label); this->flatIndexGuard.unlock(); diff --git a/tests/unit/test_raft_ivf_tiered.cpp b/tests/unit/test_raft_ivf_tiered.cpp index 978a24e9a..f02b4645f 100644 --- a/tests/unit/test_raft_ivf_tiered.cpp +++ b/tests/unit/test_raft_ivf_tiered.cpp @@ -196,7 +196,7 @@ TYPED_TEST(RaftIvfTieredTest, transferJobAsync) { GenerateVector(expected_vector, dim, i); VecSimQueryReply *res = VecSimIndex_TopKQuery(tiered_index->backendIndex, expected_vector, k, nullptr, BY_SCORE); ASSERT_EQ(VecSimQueryReply_GetCode(res), VecSim_QueryReply_OK); - ASSERT_EQ(VecSimQueryReply_Len(res), k) + ASSERT_EQ(VecSimQueryReply_Len(res), k); ASSERT_EQ(res->results[0].id, i); ASSERT_EQ(res->results[0].score, 0); VecSimQueryReply_Free(res); @@ -241,3 +241,53 @@ TYPED_TEST(RaftIvfTieredTest, transferJob_inplace) { ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); } +TYPED_TEST(RaftIvfTieredTest, deleteVector_backend) { + size_t dim = 32; + size_t n = 500; + size_t nLists = 120; + size_t nDelete = 10; + size_t flat_buffer_limit = 1000; + + size_t k = 1; + + // Create RaftIvfTiered index instance with a mock queue. + VecSimParams params = createDefaultFlatParams(dim, nLists, 20); + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + + labelType vec_label = 0; + // Delete from an empty index. + ASSERT_EQ(VecSimIndex_DeleteVector(tiered_index, vec_label), 0); + + // Insert vectors + for (size_t i = 0; i < n; i++) { + GenerateAndAddVector(tiered_index, dim, i, i); + } + // Use just one thread to transfer all the vectors + mock_thread_pool.thread_iteration(); + + // Check that the backend index has the first 12 vectors + ASSERT_EQ(tiered_index->indexSize(), n); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), n); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); + for (size_t i = 0; i < nDelete + 2; i++) { + TEST_DATA_T expected_vector[dim]; + GenerateVector(expected_vector, dim, i); + VecSimQueryReply *res = VecSimIndex_TopKQuery(tiered_index->backendIndex, expected_vector, k, nullptr, BY_SCORE); + ASSERT_EQ(VecSimQueryReply_GetCode(res), VecSim_QueryReply_OK); + ASSERT_EQ(VecSimQueryReply_Len(res), k); + ASSERT_EQ(res->results[0].id, i); + ASSERT_EQ(res->results[0].score, 0); + VecSimQueryReply_Free(res); + } + + // Delete 10 first vectors + for (size_t i = 0; i < nDelete; i++) { + VecSimIndex_DeleteVector(tiered_index, i); + } + + ASSERT_EQ(tiered_index->indexSize(), n - nDelete); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), n - nDelete); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); +} + From 17263188d2ce59c38dedad29d156d45afb67a9bd Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 24 Nov 2023 17:28:37 +0100 Subject: [PATCH 19/28] Add USE_CUDA guardrails for compilation --- CMakeLists.txt | 14 +++++++------- cmake/raft.cmake | 16 ++++++++-------- src/VecSim/index_factories/index_factory.cpp | 14 ++++++++++++++ src/VecSim/index_factories/tiered_factory.cpp | 19 ++++++++++++++----- tests/benchmark/bm_common.h | 14 +++++++++----- .../bm_basics_initialize_fp32.h | 3 +++ tests/benchmark/bm_vecsim_index.h | 7 +++++++ 7 files changed, 62 insertions(+), 25 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0fbef417f..b3e9e1ffe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,15 +32,15 @@ project(VectorSimilarity) if (USE_CUDA) # List of architectures to generate device code set(CMAKE_CUDA_ARCHITECTURES "native") - # Enable CUDA compilation for this project - enable_language(CUDA) - # Add definition for conditional compilation of CUDA components + # Enable CUDA compilation for this project + enable_language(CUDA) + # Add definition for conditional compilation of CUDA components add_definitions(-DUSE_CUDA) - # Perform all RAFT-specific CMake setup + # Perform all RAFT-specific CMake setup include(cmake/raft.cmake) - # Required flags for compiling RAFT - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -std=c++17") + # Required flags for compiling RAFT + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -std=c++17") set(CMAKE_CUDA_FLAGS_RELEASE "-O3") set(CMAKE_CUDA_FLAGS_DEBUG "-g") diff --git a/cmake/raft.cmake b/cmake/raft.cmake index 2c9ac5509..b9d1c15cd 100644 --- a/cmake/raft.cmake +++ b/cmake/raft.cmake @@ -4,8 +4,8 @@ if(USE_CUDA) # Set which version of RAFT to use (defined separately for testing # minimal dependency changes if necessary) set(RAFT_VERSION "${RAPIDS_VERSION}") - set(RAFT_FORK "rapidsai") - set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + set(RAFT_FORK "rapidsai") + set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") # Download CMake file for bootstrapping RAPIDS-CMake, a utility that # simplifies handling of complex RAPIDS dependencies @@ -13,7 +13,7 @@ if(USE_CUDA) file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) endif() - include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) + include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) # General tool for orchestrating RAPIDS dependencies include(rapids-cmake) @@ -43,13 +43,13 @@ if(USE_CUDA) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) - set(RAFT_COMPONENTS "") - if(PKG_COMPILE_LIBRARY) + set(RAFT_COMPONENTS "") + if(PKG_COMPILE_LIBRARY) string(APPEND RAFT_COMPONENTS " compiled") - endif() - # Invoke CPM find_package() + endif() + # Invoke CPM find_package() # (From rapids-cpm) - rapids_cpm_find(raft ${PKG_VERSION} + rapids_cpm_find(raft ${PKG_VERSION} GLOBAL_TARGETS raft::raft BUILD_EXPORT_SET VectorSimilarity-exports INSTALL_EXPORT_SET VectorSimilarity-exports diff --git a/src/VecSim/index_factories/index_factory.cpp b/src/VecSim/index_factories/index_factory.cpp index bcda2c7e7..cffdd6f5d 100644 --- a/src/VecSim/index_factories/index_factory.cpp +++ b/src/VecSim/index_factories/index_factory.cpp @@ -8,7 +8,9 @@ #include "hnsw_factory.h" #include "brute_force_factory.h" #include "tiered_factory.h" +#ifdef USE_CUDA #include "raft_ivf_factory.h" +#endif #include "VecSim/vec_sim_index.h" namespace VecSimFactory { @@ -28,7 +30,11 @@ VecSimIndex *NewIndex(const VecSimParams *params) { } case VecSimAlgo_RAFT_IVFFLAT: case VecSimAlgo_RAFT_IVFPQ: { +#ifdef USE_CUDA index = RaftIvfFactory::NewIndex(¶ms->algoParams.raftIvfParams); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif break; } case VecSimAlgo_TIERED: { @@ -50,7 +56,11 @@ size_t EstimateInitialSize(const VecSimParams *params) { return BruteForceFactory::EstimateInitialSize(¶ms->algoParams.bfParams); case VecSimAlgo_RAFT_IVFFLAT: case VecSimAlgo_RAFT_IVFPQ: +#ifdef USE_CUDA return RaftIvfFactory::EstimateInitialSize(¶ms->algoParams.raftIvfParams); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif case VecSimAlgo_TIERED: return TieredFactory::EstimateInitialSize(¶ms->algoParams.tieredParams); } @@ -65,7 +75,11 @@ size_t EstimateElementSize(const VecSimParams *params) { return BruteForceFactory::EstimateElementSize(¶ms->algoParams.bfParams); case VecSimAlgo_RAFT_IVFFLAT: case VecSimAlgo_RAFT_IVFPQ: +#ifdef USE_CUDA return RaftIvfFactory::EstimateElementSize(¶ms->algoParams.raftIvfParams); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif case VecSimAlgo_TIERED: return TieredFactory::EstimateElementSize(¶ms->algoParams.tieredParams); } diff --git a/src/VecSim/index_factories/tiered_factory.cpp b/src/VecSim/index_factories/tiered_factory.cpp index 03bd48ed0..bad304469 100644 --- a/src/VecSim/index_factories/tiered_factory.cpp +++ b/src/VecSim/index_factories/tiered_factory.cpp @@ -96,12 +96,13 @@ VecSimIndex *NewIndex(const TieredIndexParams *params) { } } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { +#ifdef USE_CUDA VecSimType type = params->primaryIndexParams->algoParams.raftIvfParams.type; - if (type == VecSimType_FLOAT32) { - return TieredRaftIvfFactory::NewIndex(params); - }/* else if (type == VecSimType_FLOAT64) { - return TieredRaftIvfFactory::NewIndex(params); - }*/ + assert(type == VecSimType_FLOAT32); // TODO: support float64 + return TieredRaftIvfFactory::NewIndex(params); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif } return nullptr; // Invalid algorithm or type. } @@ -114,7 +115,11 @@ size_t EstimateInitialSize(const TieredIndexParams *params) { est += TieredHNSWFactory::EstimateInitialSize(params, bf_params); } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { +#ifdef USE_CUDA est += TieredRaftIvfFactory::EstimateInitialSize(params); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif } est += BruteForceFactory::EstimateInitialSize(&bf_params); @@ -127,7 +132,11 @@ size_t EstimateElementSize(const TieredIndexParams *params) { est = HNSWFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.hnswParams); } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { +#ifdef USE_CUDA est = RaftIvfFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.raftIvfParams); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif } return est; } diff --git a/tests/benchmark/bm_common.h b/tests/benchmark/bm_common.h index 7d01abd08..0827caa05 100644 --- a/tests/benchmark/bm_common.h +++ b/tests/benchmark/bm_common.h @@ -1,7 +1,9 @@ #pragma once #include "bm_vecsim_index.h" +#ifdef USE_CUDA #include "VecSim/algorithms/raft_ivf/ivf_tiered.h" +#endif size_t BM_VecSimGeneral::block_size = 1024; @@ -19,10 +21,6 @@ class BM_VecSimCommon : public BM_VecSimIndex { static void RunTopK_HNSW(benchmark::State &st, size_t ef, size_t iter, size_t k, std::atomic_int &correct, unsigned short index_offset = 0, bool is_tiered = false); - static void RunTopK_TieredRaftIVFFlat(benchmark::State &st, size_t iter, size_t k, std::atomic_int &correct, - unsigned short index_offset = 0, bool is_tiered = true); - static void RunTopK_TieredRaftIVFPQ(benchmark::State &st, size_t iter, size_t k, std::atomic_int &correct, - unsigned short index_offset = 0, bool is_tiered = true); // Search for the K closest vectors to the query in the index. K is defined in the // test registration (initialization file). @@ -31,12 +29,14 @@ class BM_VecSimCommon : public BM_VecSimIndex { // with respect to the results returned by the flat index. static void TopK_HNSW(benchmark::State &st, unsigned short index_offset = 0); static void TopK_Tiered(benchmark::State &st, unsigned short index_offset = 0); +#ifdef USE_CUDA // Run TopK using Raft IVF Flat tiered and flat index and calculate the recall of the Raft IVF // Flat algorithm with respect to the results returned by the flat index. static void TopK_TieredRaftIVFFlat(benchmark::State &st, unsigned short index_offset = 0); // Run TopK using both Raft IVF PQ Tiered and flat index and calculate the recall of the Raft IVF // PQ algorithm with respect to the results returned by the flat index. static void TopK_TieredRaftIVFPQ(benchmark::State &st, unsigned short index_offset = 0); +#endif // Does nothing but returning the index memory. static void Memory_FLAT(benchmark::State &st, unsigned short index_offset = 0); @@ -180,13 +180,14 @@ void BM_VecSimCommon::TopK_Tiered(benchmark::State &st, unsigned s st.counters["num_threads"] = (double)BM_VecSimGeneral::mock_thread_pool.thread_pool_size; } +#ifdef USE_CUDA template void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, unsigned short index_offset) { size_t k = st.range(0); size_t n_probes = st.range(1); std::atomic_int correct = 0; std::atomic_int iter = 0; - auto *tiered_index = //reinterpret_cast *>(INDICES[VecSimAlgo_RAFT_IVFFLAT]); + auto *tiered_index = reinterpret_cast *>(INDICES[VecSimAlgo_RAFT_IVFFLAT]); size_t total_iters = 50; tiered_index->setNProbes(n_probes); @@ -227,6 +228,7 @@ void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, st.counters["Recall"] = (float)correct / (float)(k * iter); st.counters["num_threads"] = (double)BM_VecSimGeneral::mock_thread_pool.thread_pool_size; } +#endif #define REGISTER_TopK_BF(BM_CLASS, BM_FUNC) \ BENCHMARK_REGISTER_F(BM_CLASS, BM_FUNC) \ @@ -261,6 +263,7 @@ void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, ->Iterations(50) \ ->Unit(benchmark::kMillisecond) +#ifdef USE_CUDA #define REGISTER_TopK_TieredRaftIVF(BM_CLASS, BM_FUNC) \ BENCHMARK_REGISTER_F(BM_CLASS, BM_FUNC) \ ->Args({10, 20}) \ @@ -275,3 +278,4 @@ void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, ->ArgNames({"k", "n_probes"}) \ ->Iterations(50) \ ->Unit(benchmark::kMillisecond) +#endif diff --git a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h index 4b22e77e6..1ca8f1f29 100644 --- a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h +++ b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h @@ -49,11 +49,14 @@ REGISTER_TopK_HNSW(BM_VecSimCommon, BM_FUNC_NAME(TopK, HNSW)); BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(TopK, Tiered), fp32_index_t) (benchmark::State &st) { TopK_Tiered(st); } REGISTER_TopK_Tiered(BM_VecSimCommon, BM_FUNC_NAME(TopK, Tiered)); + +#ifdef USE_CUDA // TopK Tiered RAFT IVF Flat BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVF), fp32_index_t) (benchmark::State &st) { TopK_TieredRaftIVFFlat(st); } REGISTER_TopK_TieredRaftIVF(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVF)); +#endif // Range BF BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimBasics, BM_FUNC_NAME(Range, BF), fp32_index_t) diff --git a/tests/benchmark/bm_vecsim_index.h b/tests/benchmark/bm_vecsim_index.h index 20a66b353..0bf0c27ee 100644 --- a/tests/benchmark/bm_vecsim_index.h +++ b/tests/benchmark/bm_vecsim_index.h @@ -2,7 +2,10 @@ #include "bm_vecsim_general.h" #include "VecSim/index_factories/tiered_factory.h" +#ifdef USE_CUDA #include "VecSim/index_factories/raft_ivf_tiered_factory.h" +#include "VecSim/algorithms/raft_ivf/ivf_tiered.h" +#endif template class BM_VecSimIndex : public BM_VecSimGeneral { @@ -112,6 +115,7 @@ void BM_VecSimIndex::Initialize() { // Launch the BG threads loop that takes jobs from the queue and executes them. mock_thread_pool.init_threads(); +#ifdef USE_CUDA // Create RAFFT IVF Flat tiered index. auto &mock_thread_pool_ivf_flat = BM_VecSimGeneral::mock_thread_pool_raft; @@ -129,6 +133,7 @@ void BM_VecSimIndex::Initialize() { mock_thread_pool_ivf_flat.init_threads(); indices.push_back(tiered_raft_ivf_flat_index); +#endif // Add the same vectors to Flat index. for (size_t i = 0; i < n_vectors; ++i) { @@ -136,7 +141,9 @@ void BM_VecSimIndex::Initialize() { // Fot multi value indices, the internal id is not necessarily equal the label. size_t label = CastToHNSW(indices[VecSimAlgo_HNSWLIB])->getExternalLabel(i); VecSimIndex_AddVector(indices[VecSimAlgo_BF], blob, label); +#ifdef USE_CUDA VecSimIndex_AddVector(indices[VecSimAlgo_RAFT_IVFFLAT], blob, label); +#endif } mock_thread_pool_ivf_flat.thread_pool_wait(100); From 8ba37678091b4557807580b624c92976218f3ec8 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 24 Nov 2023 17:36:08 +0100 Subject: [PATCH 20/28] Fix style --- src/VecSim/algorithms/raft_ivf/ivf.cuh | 51 +++++++-------- .../algorithms/raft_ivf/ivf_interface.h | 6 +- src/VecSim/algorithms/raft_ivf/ivf_tiered.h | 19 +++--- src/VecSim/index_factories/index_factory.cpp | 3 +- .../index_factories/raft_ivf_factory.cu | 32 +++++----- .../raft_ivf_tiered_factory.cpp | 5 +- src/VecSim/index_factories/tiered_factory.cpp | 6 +- src/VecSim/vec_sim_common.h | 13 +++- tests/benchmark/bm_common.h | 16 ++--- tests/benchmark/bm_vecsim_general.h | 51 ++++++++------- tests/benchmark/bm_vecsim_index.h | 4 +- tests/unit/test_raft_ivf_tiered.cpp | 63 ++++++++++--------- 12 files changed, 143 insertions(+), 126 deletions(-) diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index 61df8507d..fd70cd372 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -108,11 +108,8 @@ public: : build_params_t{std::in_place_index<0>}}, search_params_{raftIvfParams->usePQ ? search_params_t{std::in_place_index<1>} : search_params_t{std::in_place_index<0>}}, - index_{std::nullopt}, - deleted_indices_{std::nullopt}, - numDeleted_{0}, - idToLabelLookup_{this->allocator}, - labelToIdLookup_{this->allocator} { + index_{std::nullopt}, deleted_indices_{std::nullopt}, numDeleted_{0}, + idToLabelLookup_{this->allocator}, labelToIdLookup_{this->allocator} { std::visit( [raftIvfParams](auto &&inner) { inner.metric = GetRaftDistanceType(raftIvfParams->metric); @@ -151,15 +148,15 @@ public: return addVectorBatch(vector_data, &label, 1, auxiliaryCtx); } int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, - void *auxiliaryCtx = nullptr) override { - const auto& res = raft::device_resources_manager::get_device_resources(); + void *auxiliaryCtx = nullptr) override { + const auto &res = raft::device_resources_manager::get_device_resources(); // Allocate memory on device to hold vectors to be added auto vector_data_gpu = raft::make_device_matrix(res, batch_size, this->dim); // Copy vector data to previously allocated device buffer raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), - this->dim * batch_size, res.get_stream()); + this->dim * batch_size, res.get_stream()); // Create GPU vector to hold ids internal_idx_t first_id = this->indexSize(); @@ -175,10 +172,9 @@ public: raft::make_const_mdspan(vector_data_gpu.view())); deleted_indices_ = {raft::core::bitset(res, 0)}; } - raft::neighbors::ivf_flat::extend( - res, raft::make_const_mdspan(vector_data_gpu.view()), - {raft::make_const_mdspan(ids.view())}, - &std::get(*index_)); + raft::neighbors::ivf_flat::extend(res, raft::make_const_mdspan(vector_data_gpu.view()), + {raft::make_const_mdspan(ids.view())}, + &std::get(*index_)); } else { if (!index_) { index_ = raft::neighbors::ivf_pq::build( @@ -186,13 +182,11 @@ public: raft::make_const_mdspan(vector_data_gpu.view())); deleted_indices_ = {raft::core::bitset(res, 0)}; } - raft::neighbors::ivf_pq::extend( - res, raft::make_const_mdspan(vector_data_gpu.view()), - {raft::make_const_mdspan(ids.view())}, - &std::get(*index_)); + raft::neighbors::ivf_pq::extend(res, raft::make_const_mdspan(vector_data_gpu.view()), + {raft::make_const_mdspan(ids.view())}, + &std::get(*index_)); } - // Add labels to internal idToLabelLookup_ mapping this->idToLabelLookup_.insert(this->idToLabelLookup_.end(), label, label + batch_size); for (auto i = 0; i < batch_size; ++i) { @@ -201,7 +195,7 @@ public: // Update the size of the deleted indices bitset deleted_indices_->resize(res, deleted_indices_->size() + batch_size); - + // Ensure that above operation has executed on device before // returning from this function on host res.sync_stream(); @@ -213,7 +207,7 @@ public: if (search == labelToIdLookup_.end()) { return 0; } - const auto& res = raft::device_resources_manager::get_device_resources(); + const auto &res = raft::device_resources_manager::get_device_resources(); // Create GPU vector to hold ids to mark as deleted internal_idx_t id = search->second; auto id_gpu = raft::make_device_vector(res, 1); @@ -245,12 +239,10 @@ public: return result; } // void increaseCapacity() override { assert(!"increaseCapacity not implemented"); } - inline size_t indexLabelCount() const override { - return this->labelToIdLookup_.size(); - } + inline size_t indexLabelCount() const override { return this->labelToIdLookup_.size(); } VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, VecSimQueryParams *queryParams) const override { - const auto& res = raft::device_resources_manager::get_device_resources(); + const auto &res = raft::device_resources_manager::get_device_resources(); auto result_list = new VecSimQueryReply(this->allocator); auto nVectors = this->indexSize(); if (nVectors == 0 || k == 0 || !index_.has_value()) { @@ -260,23 +252,26 @@ public: // index k = std::min(k, nVectors); // Allocate memory on device for search vector - auto vector_data_gpu = raft::make_device_matrix(res, 1, this->dim); + auto vector_data_gpu = + raft::make_device_matrix(res, 1, this->dim); // Allocate memory on device for neighbor and distance results auto neighbors_gpu = raft::make_device_matrix(res, 1, k); auto distances_gpu = raft::make_device_matrix(res, 1, k); // Copy query vector to device - raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), this->dim, - res.get_stream()); + raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), + this->dim, res.get_stream()); auto bitset_filter = raft::neighbors::filtering::bitset_filter(deleted_indices_->view()); // Perform correct search based on index type if (std::holds_alternative(*index_)) { - raft::neighbors::ivf_flat::search_with_filtering( + raft::neighbors::ivf_flat::search_with_filtering( res, std::get(search_params_), std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), neighbors_gpu.view(), distances_gpu.view(), bitset_filter); } else { - raft::neighbors::ivf_pq::search_with_filtering( + raft::neighbors::ivf_pq::search_with_filtering( res, std::get(search_params_), std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), neighbors_gpu.view(), distances_gpu.view(), bitset_filter); diff --git a/src/VecSim/algorithms/raft_ivf/ivf_interface.h b/src/VecSim/algorithms/raft_ivf/ivf_interface.h index db32b510f..35c2faa34 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_interface.h +++ b/src/VecSim/algorithms/raft_ivf/ivf_interface.h @@ -9,9 +9,9 @@ // Non-CUDA Interface of the RaftIVF index to avoid importing CUDA code // in the tiered index. template -struct RaftIvfInterface : public VecSimIndexAbstract -{ - RaftIvfInterface(const AbstractIndexInitParams ¶ms) : VecSimIndexAbstract(params) {} +struct RaftIvfInterface : public VecSimIndexAbstract { + RaftIvfInterface(const AbstractIndexInitParams ¶ms) + : VecSimIndexAbstract(params) {} virtual uint32_t nLists() const = 0; virtual inline void setNProbes(uint32_t n_probes) = 0; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h index 9625d19f4..2a36754b4 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h @@ -21,7 +21,8 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { assert( raftIvfIndex->nLists() < this->flatBufferLimit && "The flat buffer limit must be greater than the number of lists in the backend index"); - this->minVectorsInit = std::max((size_t)1, tieredParams.specificParams.tieredRaftIvfParams.minVectorsInit); + this->minVectorsInit = + std::max((size_t)1, tieredParams.specificParams.tieredRaftIvfParams.minVectorsInit); } ~TieredRaftIvfIndex() { // Delete all the pending jobs @@ -35,7 +36,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { // Otherwise, just add the vector to the backend index executeTransferJob(true); } - + // If the backend index is already built and that the write mode is in place // add the vector to the backend index if (this->backendIndex->indexSize() > 0 && this->getWriteMode() == VecSim_WriteInPlace) { @@ -63,7 +64,8 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { if (this->frontendIndex->isLabelExists(label)) { this->flatIndexGuard.unlock_shared(); this->flatIndexGuard.lock(); - // Check again if the label exists, as it may have been removed while we released the lock. + // Check again if the label exists, as it may have been removed while we released the + // lock. if (this->frontendIndex->isLabelExists(label)) { // Remove every id that corresponds the label from the flat buffer. auto updated_ids = this->frontendIndex->deleteVectorAndGetUpdatedIds(label); @@ -98,8 +100,8 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { this->flatIndexGuard.unlock_shared(); this->mainIndexGuard.unlock_shared(); std::vector output; - std::set_union(flat_labels.begin(), flat_labels.end(), raft_ivf_labels.begin(), raft_ivf_labels.end(), - std::back_inserter(output)); + std::set_union(flat_labels.begin(), flat_labels.end(), raft_ivf_labels.begin(), + raft_ivf_labels.end(), std::back_inserter(output)); return output.size(); } @@ -158,7 +160,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { private: size_t minVectorsInit = 1; - inline auto* getBackendIndex() const { + inline auto *getBackendIndex() const { return dynamic_cast *>(this->backendIndex); } @@ -181,7 +183,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { return; } } - + // Check that there are still vectors to transfer after exclusive lock this->flatIndexGuard.lock(); nVectors = this->frontendIndex->indexSize(); @@ -208,7 +210,8 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { std::copy_n(this->frontendIndex->getLabels().data(), nVectors, labelData); this->frontendIndex->clear(); - // Lock the main index before unlocking the front index so that both indexes are not empty at the same time + // Lock the main index before unlocking the front index so that both indexes are not empty + // at the same time this->mainIndexGuard.lock(); this->flatIndexGuard.unlock(); diff --git a/src/VecSim/index_factories/index_factory.cpp b/src/VecSim/index_factories/index_factory.cpp index cffdd6f5d..60365b45c 100644 --- a/src/VecSim/index_factories/index_factory.cpp +++ b/src/VecSim/index_factories/index_factory.cpp @@ -33,7 +33,8 @@ VecSimIndex *NewIndex(const VecSimParams *params) { #ifdef USE_CUDA index = RaftIvfFactory::NewIndex(¶ms->algoParams.raftIvfParams); #else - throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); + throw std::runtime_error( + "RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); #endif break; } diff --git a/src/VecSim/index_factories/raft_ivf_factory.cu b/src/VecSim/index_factories/raft_ivf_factory.cu index 7817a282d..e499cea89 100644 --- a/src/VecSim/index_factories/raft_ivf_factory.cu +++ b/src/VecSim/index_factories/raft_ivf_factory.cu @@ -6,23 +6,24 @@ namespace RaftIvfFactory { static AbstractIndexInitParams NewAbstractInitParams(const VecSimParams *params) { const RaftIvfParams *raftIvfParams = ¶ms->algoParams.raftIvfParams; - AbstractIndexInitParams abstractInitParams = {.allocator = - VecSimAllocator::newVecsimAllocator(), - .dim = raftIvfParams->dim, - .vecType = raftIvfParams->type, - .metric = raftIvfParams->metric, - //.multi = raftIvfParams->multi, - //.logCtx = params->logCtx - }; + AbstractIndexInitParams abstractInitParams = { + .allocator = VecSimAllocator::newVecsimAllocator(), + .dim = raftIvfParams->dim, + .vecType = raftIvfParams->type, + .metric = raftIvfParams->metric, + //.multi = raftIvfParams->multi, + //.logCtx = params->logCtx + }; return abstractInitParams; } -VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &abstractInitParams) { +VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams, + const AbstractIndexInitParams &abstractInitParams) { assert(raftIvfParams->type == VecSimType_FLOAT32 && "Invalid IVF data type algorithm"); if (raftIvfParams->type == VecSimType_FLOAT32) { return new (abstractInitParams.allocator) RaftIvfIndex(raftIvfParams, abstractInitParams); - } + } // If we got here something is wrong. return NULL; @@ -44,19 +45,22 @@ size_t EstimateInitialSize(const RaftIvfParams *raftIvfParams) { // Constant part (not effected by parameters). size_t est = sizeof(VecSimAllocator) + allocations_overhead; - est += sizeof(RaftIvfIndex); // Object size + est += sizeof(RaftIvfIndex); // Object size if (!raftIvfParams->usePQ) { // Size of each cluster data - est += raftIvfParams->nLists * sizeof(raft::neighbors::ivf_flat::list_data); + est += raftIvfParams->nLists * + sizeof(raft::neighbors::ivf_flat::list_data); // Vector of shared ptr to cluster - est += raftIvfParams->nLists * sizeof(std::shared_ptr>); + est += raftIvfParams->nLists * + sizeof(std::shared_ptr>); } else { // Size of each cluster data est += raftIvfParams->nLists * sizeof(raft::neighbors::ivf_pq::list_data); // accum_sorted_sizes_ Array est += raftIvfParams->nLists * sizeof(std::int64_t); // vector of shared ptr to cluster - est += raftIvfParams->nLists * sizeof(std::shared_ptr>); + est += raftIvfParams->nLists * + sizeof(std::shared_ptr>); } return est; } diff --git a/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp b/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp index 7ecfcb8b7..2befa3a13 100644 --- a/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp +++ b/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp @@ -6,10 +6,9 @@ namespace TieredRaftIvfFactory { -VecSimIndex *NewIndex(const TieredIndexParams *params) -{ +VecSimIndex *NewIndex(const TieredIndexParams *params) { assert(params->primaryIndexParams->algoParams.raftIvfParams.type == VecSimType_FLOAT32 && - "Invalid IVF data type algorithm"); + "Invalid IVF data type algorithm"); using DataType = float; using DistType = float; diff --git a/src/VecSim/index_factories/tiered_factory.cpp b/src/VecSim/index_factories/tiered_factory.cpp index bad304469..bae640672 100644 --- a/src/VecSim/index_factories/tiered_factory.cpp +++ b/src/VecSim/index_factories/tiered_factory.cpp @@ -98,7 +98,6 @@ VecSimIndex *NewIndex(const TieredIndexParams *params) { params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { #ifdef USE_CUDA VecSimType type = params->primaryIndexParams->algoParams.raftIvfParams.type; - assert(type == VecSimType_FLOAT32); // TODO: support float64 return TieredRaftIvfFactory::NewIndex(params); #else throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); @@ -113,7 +112,7 @@ size_t EstimateInitialSize(const TieredIndexParams *params) { BFParams bf_params{}; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est += TieredHNSWFactory::EstimateInitialSize(params, bf_params); - } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { #ifdef USE_CUDA est += TieredRaftIvfFactory::EstimateInitialSize(params); @@ -133,7 +132,8 @@ size_t EstimateElementSize(const TieredIndexParams *params) { } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { #ifdef USE_CUDA - est = RaftIvfFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.raftIvfParams); + est = RaftIvfFactory::EstimateElementSize( + ¶ms->primaryIndexParams->algoParams.raftIvfParams); #else throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); #endif diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index 63335a674..60d8a23c9 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -38,13 +38,22 @@ typedef enum { } VecSimType; // Algorithm type/library. -typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_TIERED, VecSimAlgo_RAFT_IVFFLAT, VecSimAlgo_RAFT_IVFPQ } VecSimAlgo; +typedef enum { + VecSimAlgo_BF, + VecSimAlgo_HNSWLIB, + VecSimAlgo_TIERED, + VecSimAlgo_RAFT_IVFFLAT, + VecSimAlgo_RAFT_IVFPQ +} VecSimAlgo; // Distance metric typedef enum { VecSimMetric_L2, VecSimMetric_IP, VecSimMetric_Cosine } VecSimMetric; // Codebook kind for IVFPQ indexes -typedef enum { RaftIVFPQCodebookKind_PerCluster, RaftIVFPQCodebookKind_PerSubspace } RaftIVFPQCodebookKind; +typedef enum { + RaftIVFPQCodebookKind_PerCluster, + RaftIVFPQCodebookKind_PerSubspace +} RaftIVFPQCodebookKind; // CUDA types supported by GPU-accelerated indexes typedef enum { CUDAType_R_32F, CUDAType_R_16F, CUDAType_R_8U } CudaType; diff --git a/tests/benchmark/bm_common.h b/tests/benchmark/bm_common.h index 0827caa05..b5f86decd 100644 --- a/tests/benchmark/bm_common.h +++ b/tests/benchmark/bm_common.h @@ -5,7 +5,6 @@ #include "VecSim/algorithms/raft_ivf/ivf_tiered.h" #endif - size_t BM_VecSimGeneral::block_size = 1024; // Class for common bm for basic index and updated index. @@ -33,8 +32,8 @@ class BM_VecSimCommon : public BM_VecSimIndex { // Run TopK using Raft IVF Flat tiered and flat index and calculate the recall of the Raft IVF // Flat algorithm with respect to the results returned by the flat index. static void TopK_TieredRaftIVFFlat(benchmark::State &st, unsigned short index_offset = 0); - // Run TopK using both Raft IVF PQ Tiered and flat index and calculate the recall of the Raft IVF - // PQ algorithm with respect to the results returned by the flat index. + // Run TopK using both Raft IVF PQ Tiered and flat index and calculate the recall of the Raft + // IVF PQ algorithm with respect to the results returned by the flat index. static void TopK_TieredRaftIVFPQ(benchmark::State &st, unsigned short index_offset = 0); #endif @@ -182,7 +181,8 @@ void BM_VecSimCommon::TopK_Tiered(benchmark::State &st, unsigned s #ifdef USE_CUDA template -void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, unsigned short index_offset) { +void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, + unsigned short index_offset) { size_t k = st.range(0); size_t n_probes = st.range(1); std::atomic_int correct = 0; @@ -195,11 +195,11 @@ void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, auto parallel_knn_search = [](AsyncJob *job) { auto *search_job = reinterpret_cast(job); - VecSimQueryParams query_params { .batchSize = 1 }; + VecSimQueryParams query_params{.batchSize = 1}; size_t cur_iter = search_job->iter; - auto results = - VecSimIndex_TopKQuery(INDICES[VecSimAlgo_RAFT_IVFFLAT], QUERIES[cur_iter % N_QUERIES].data(), - search_job->k, &query_params, BY_SCORE); + auto results = VecSimIndex_TopKQuery(INDICES[VecSimAlgo_RAFT_IVFFLAT], + QUERIES[cur_iter % N_QUERIES].data(), search_job->k, + &query_params, BY_SCORE); search_job->all_results[cur_iter] = results; delete job; }; diff --git a/tests/benchmark/bm_vecsim_general.h b/tests/benchmark/bm_vecsim_general.h index 8f60aaa92..6544dcf6f 100644 --- a/tests/benchmark/bm_vecsim_general.h +++ b/tests/benchmark/bm_vecsim_general.h @@ -70,34 +70,39 @@ class BM_VecSimGeneral : public benchmark::Fixture { return params; } - static VecSimParams createDefaultRaftIvfPQParams(size_t dim, uint32_t nLists = 1024, uint32_t nProbes = 20) { + static VecSimParams createDefaultRaftIvfPQParams(size_t dim, uint32_t nLists = 1024, + uint32_t nProbes = 20) { RaftIvfParams ivfparams = {.dim = dim, - .metric = VecSimMetric_L2, - .nLists = nLists, - .kmeans_nIters = 20, - .kmeans_trainsetFraction = 0.5, - .nProbes = nProbes, - .usePQ = true, - .pqBits = 8, - .pqDim = 0, - .codebookKind = RaftIVFPQCodebookKind_PerSubspace, - .lutType = CUDAType_R_32F, - .internalDistanceType = CUDAType_R_32F, - .preferredShmemCarveout = 1.0}; - VecSimParams params{.algo = VecSimAlgo_RAFT_IVFPQ, .algoParams = {.raftIvfParams = ivfparams}}; + .metric = VecSimMetric_L2, + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.8, + .nProbes = nProbes, + .usePQ = true, + .pqBits = 8, + .pqDim = 0, + .codebookKind = RaftIVFPQCodebookKind_PerSubspace, + .lutType = CUDAType_R_32F, + .internalDistanceType = CUDAType_R_32F, + .preferredShmemCarveout = 1.0}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFPQ, + .algoParams = {.raftIvfParams = ivfparams}}; return params; } - static VecSimParams createDefaultRaftIvfFlatParams(size_t dim, uint32_t nLists = 1024, uint32_t nProbes = 20, bool adaptiveCenters = true) { + static VecSimParams createDefaultRaftIvfFlatParams(size_t dim, uint32_t nLists = 1024, + uint32_t nProbes = 20, + bool adaptiveCenters = true) { RaftIvfParams ivfparams = {.dim = dim, - .metric = VecSimMetric_L2, // TODO Cosine - .nLists = nLists, - .kmeans_nIters = 20, - .kmeans_trainsetFraction = 0.5, - .nProbes = nProbes, - .usePQ = false, - .adaptiveCenters = adaptiveCenters}; - VecSimParams params{.algo = VecSimAlgo_RAFT_IVFFLAT, .algoParams = {.raftIvfParams = ivfparams}}; + .metric = VecSimMetric_L2, + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.5, + .nProbes = nProbes, + .usePQ = false, + .adaptiveCenters = adaptiveCenters}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFFLAT, + .algoParams = {.raftIvfParams = ivfparams}}; return params; } diff --git a/tests/benchmark/bm_vecsim_index.h b/tests/benchmark/bm_vecsim_index.h index 0bf0c27ee..2b7e0278c 100644 --- a/tests/benchmark/bm_vecsim_index.h +++ b/tests/benchmark/bm_vecsim_index.h @@ -127,8 +127,8 @@ void BM_VecSimIndex::Initialize() { .primaryIndexParams = ¶ms_flat, .specificParams = {.tieredRaftIvfParams = {.minVectorsInit = 100}}}; - auto *tiered_raft_ivf_flat_index = - TieredRaftIvfFactory::NewIndex(&tiered_params); + auto *tiered_raft_ivf_flat_index = reinterpret_cast *>( + TieredRaftIvfFactory::NewIndex(&tiered_params)); mock_thread_pool_ivf_flat.ctx->index_strong_ref.reset(tiered_raft_ivf_flat_index); mock_thread_pool_ivf_flat.init_threads(); diff --git a/tests/unit/test_raft_ivf_tiered.cpp b/tests/unit/test_raft_ivf_tiered.cpp index f02b4645f..d50412ec1 100644 --- a/tests/unit/test_raft_ivf_tiered.cpp +++ b/tests/unit/test_raft_ivf_tiered.cpp @@ -10,16 +10,15 @@ #include "mock_thread_pool.h" - template class RaftIvfTieredTest : public ::testing::Test { public: using data_t = typename index_type_t::data_t; using dist_t = typename index_type_t::dist_t; - TieredRaftIvfIndex* createTieredIndex(VecSimParams *params, - tieredIndexMock &mock_thread_pool, - size_t flat_buffer_limit = 0) { + TieredRaftIvfIndex *createTieredIndex(VecSimParams *params, + tieredIndexMock &mock_thread_pool, + size_t flat_buffer_limit = 0) { TieredIndexParams params_tiered = { .jobQueue = &mock_thread_pool.jobQ, .jobQueueCtx = mock_thread_pool.ctx, @@ -38,31 +37,32 @@ class RaftIvfTieredTest : public ::testing::Test { VecSimParams createDefaultPQParams(size_t dim, uint32_t nLists = 3, uint32_t nProbes = 3) { RaftIvfParams ivfparams = {.dim = dim, - .metric = VecSimMetric_L2, - .nLists = nLists, - .kmeans_nIters = 20, - .kmeans_trainsetFraction = 0.5, - .nProbes = nProbes, - .usePQ = true, - .pqBits = 8, - .pqDim = 0, - .codebookKind = RaftIVFPQCodebookKind_PerSubspace, - .lutType = CUDAType_R_32F, - .internalDistanceType = CUDAType_R_32F, - .preferredShmemCarveout = 1.0}; + .metric = VecSimMetric_L2, + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.5, + .nProbes = nProbes, + .usePQ = true, + .pqBits = 8, + .pqDim = 0, + .codebookKind = RaftIVFPQCodebookKind_PerSubspace, + .lutType = CUDAType_R_32F, + .internalDistanceType = CUDAType_R_32F, + .preferredShmemCarveout = 1.0}; VecSimParams params{.algo = VecSimAlgo_RAFT_IVFPQ, .algoParams = {.raftIvfParams = ivfparams}}; return params; } VecSimParams createDefaultFlatParams(size_t dim, uint32_t nLists = 3, uint32_t nProbes = 3) { RaftIvfParams ivfparams = {.dim = dim, - .metric = VecSimMetric_L2, - .nLists = nLists, - .kmeans_nIters = 20, - .kmeans_trainsetFraction = 0.5, - .nProbes = nProbes, - .usePQ = false}; - VecSimParams params{.algo = VecSimAlgo_RAFT_IVFFLAT, .algoParams = {.raftIvfParams = ivfparams}}; + .metric = VecSimMetric_L2, + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.5, + .nProbes = nProbes, + .usePQ = false}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFFLAT, + .algoParams = {.raftIvfParams = ivfparams}}; return params; } @@ -79,7 +79,7 @@ TYPED_TEST(RaftIvfTieredTest, end_to_end) { auto mock_thread_pool = tieredIndexMock(); auto *index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); mock_thread_pool.init_threads(); - + VecSimQueryParams queryParams = {.batchSize = 1}; ASSERT_EQ(VecSimIndex_IndexSize(index), 0); @@ -104,7 +104,6 @@ TYPED_TEST(RaftIvfTieredTest, end_to_end) { VecSimIndex_AddVector(index, d_vec.data(), 3); ASSERT_EQ(VecSimIndex_IndexSize(index), 4); - mock_thread_pool.thread_pool_join(); EXPECT_EQ(mock_thread_pool.jobQ.size(), 0); // Callbacks for verifying results. @@ -135,9 +134,8 @@ TYPED_TEST(RaftIvfTieredTest, transferJob) { auto mock_thread_pool = tieredIndexMock(); auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); auto allocator = tiered_index->getAllocator(); - - VecSimQueryParams queryParams = {.batchSize = 1}; + VecSimQueryParams queryParams = {.batchSize = 1}; // Create a vector and add it to the tiered index. labelType vec_label = 1; @@ -194,7 +192,8 @@ TYPED_TEST(RaftIvfTieredTest, transferJobAsync) { for (size_t i = 0; i < size_t{n / 10}; i++) { TEST_DATA_T expected_vector[dim]; GenerateVector(expected_vector, dim, i); - VecSimQueryReply *res = VecSimIndex_TopKQuery(tiered_index->backendIndex, expected_vector, k, nullptr, BY_SCORE); + VecSimQueryReply *res = VecSimIndex_TopKQuery(tiered_index->backendIndex, expected_vector, + k, nullptr, BY_SCORE); ASSERT_EQ(VecSimQueryReply_GetCode(res), VecSim_QueryReply_OK); ASSERT_EQ(VecSimQueryReply_Len(res), k); ASSERT_EQ(res->results[0].id, i); @@ -234,7 +233,8 @@ TYPED_TEST(RaftIvfTieredTest, transferJob_inplace) { ASSERT_EQ(tiered_index->backendIndex->indexSize(), flat_buffer_limit * 2); ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 2 * (n - flat_buffer_limit)); - // Run a thread loop iteration. The thread should transfer the rest of the vectors to the backend index. + // Run a thread loop iteration. The thread should transfer the rest of the vectors to the + // backend index. mock_thread_pool.thread_iteration(); ASSERT_EQ(tiered_index->indexSize(), 2 * n); ASSERT_EQ(tiered_index->backendIndex->indexSize(), 2 * n); @@ -265,7 +265,7 @@ TYPED_TEST(RaftIvfTieredTest, deleteVector_backend) { } // Use just one thread to transfer all the vectors mock_thread_pool.thread_iteration(); - + // Check that the backend index has the first 12 vectors ASSERT_EQ(tiered_index->indexSize(), n); ASSERT_EQ(tiered_index->backendIndex->indexSize(), n); @@ -273,7 +273,8 @@ TYPED_TEST(RaftIvfTieredTest, deleteVector_backend) { for (size_t i = 0; i < nDelete + 2; i++) { TEST_DATA_T expected_vector[dim]; GenerateVector(expected_vector, dim, i); - VecSimQueryReply *res = VecSimIndex_TopKQuery(tiered_index->backendIndex, expected_vector, k, nullptr, BY_SCORE); + VecSimQueryReply *res = VecSimIndex_TopKQuery(tiered_index->backendIndex, expected_vector, + k, nullptr, BY_SCORE); ASSERT_EQ(VecSimQueryReply_GetCode(res), VecSim_QueryReply_OK); ASSERT_EQ(VecSimQueryReply_Len(res), k); ASSERT_EQ(res->results[0].id, i); From 2071218e8695fc3f97a2a76d5bcfab09d4482e29 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 24 Nov 2023 18:59:34 +0100 Subject: [PATCH 21/28] Remaining USE_CUDA guards --- CMakeLists.txt | 2 +- src/VecSim/index_factories/tiered_factory.cpp | 1 - tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h | 2 ++ 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b3e9e1ffe..2821a02b9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -option(USE_CUDA "Build Cuda code" On) +option(USE_CUDA "Build Cuda code" OFF) if(USE_CUDA) cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) else() diff --git a/src/VecSim/index_factories/tiered_factory.cpp b/src/VecSim/index_factories/tiered_factory.cpp index bae640672..fbc5f8d05 100644 --- a/src/VecSim/index_factories/tiered_factory.cpp +++ b/src/VecSim/index_factories/tiered_factory.cpp @@ -97,7 +97,6 @@ VecSimIndex *NewIndex(const TieredIndexParams *params) { } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { #ifdef USE_CUDA - VecSimType type = params->primaryIndexParams->algoParams.raftIvfParams.type; return TieredRaftIvfFactory::NewIndex(params); #else throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); diff --git a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h index 1ca8f1f29..20826335d 100644 --- a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h +++ b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h @@ -21,9 +21,11 @@ BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, Tiered), fp32_ BENCHMARK_REGISTER_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, Tiered))->Iterations(1); // Memory TieredRaftIVFFlat +#ifdef USE_CUDA BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFFlat), fp32_index_t) (benchmark::State &st) { Memory_TieredRaftIVFFlat(st); } BENCHMARK_REGISTER_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFFlat))->Iterations(1); +#endif // AddLabel BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimBasics, BM_ADD_LABEL, fp32_index_t) From b735b2072cfd5283c73ca36c540252e96b817054 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 5 Dec 2023 13:25:20 +0100 Subject: [PATCH 22/28] Fix thread pool benchmark --- .../algorithms/brute_force/brute_force_multi.h | 2 ++ .../algorithms/brute_force/brute_force_single.h | 2 ++ tests/benchmark/bm_common.h | 2 +- tests/benchmark/bm_vecsim_basics.h | 5 ++++- tests/benchmark/bm_vecsim_index.h | 13 +++++-------- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/VecSim/algorithms/brute_force/brute_force_multi.h b/src/VecSim/algorithms/brute_force/brute_force_multi.h index c93776096..6933c4c20 100644 --- a/src/VecSim/algorithms/brute_force/brute_force_multi.h +++ b/src/VecSim/algorithms/brute_force/brute_force_multi.h @@ -26,7 +26,9 @@ class BruteForceIndex_Multi : public BruteForceIndex { void clear() override { this->labelToIdsLookup.clear(); this->idToLabelMapping.clear(); + this->idToLabelMapping.shrink_to_fit(); this->vectorBlocks.clear(); + this->vectorBlocks.shrink_to_fit(); this->count = idType{}; } int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override; diff --git a/src/VecSim/algorithms/brute_force/brute_force_single.h b/src/VecSim/algorithms/brute_force/brute_force_single.h index 88237c338..ea0adc3ff 100644 --- a/src/VecSim/algorithms/brute_force/brute_force_single.h +++ b/src/VecSim/algorithms/brute_force/brute_force_single.h @@ -24,7 +24,9 @@ class BruteForceIndex_Single : public BruteForceIndex { void clear() override { this->labelToIdLookup.clear(); this->idToLabelMapping.clear(); + this->idToLabelMapping.shrink_to_fit(); this->vectorBlocks.clear(); + this->vectorBlocks.shrink_to_fit(); this->count = idType{}; } int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override; diff --git a/tests/benchmark/bm_common.h b/tests/benchmark/bm_common.h index b5f86decd..858bf75fe 100644 --- a/tests/benchmark/bm_common.h +++ b/tests/benchmark/bm_common.h @@ -210,7 +210,7 @@ void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, tiered_index, k, 0, iter++, all_results); tiered_index->submitSingleJob(search_job); if (iter == total_iters) { - BM_VecSimGeneral::mock_thread_pool_raft.thread_pool_wait(); + BM_VecSimGeneral::mock_thread_pool.thread_pool_wait(); } } diff --git a/tests/benchmark/bm_vecsim_basics.h b/tests/benchmark/bm_vecsim_basics.h index 0ddfb240d..a57fb3ac1 100644 --- a/tests/benchmark/bm_vecsim_basics.h +++ b/tests/benchmark/bm_vecsim_basics.h @@ -123,7 +123,10 @@ void BM_VecSimBasics::AddLabel_AsyncIngest(benchmark::State &st) { size_t new_label_count = (INDICES[st.range(0)])->indexLabelCount(); // Remove directly inplace from the underline HNSW index. for (size_t label_ = initial_label_count; label_ < new_label_count; label_++) { - VecSimIndex_DeleteVector(INDICES[VecSimAlgo_HNSWLIB], label_); + if (st.range(0) == VecSimAlgo_TIERED) + VecSimIndex_DeleteVector(INDICES[VecSimAlgo_HNSWLIB], label_); + else + VecSimIndex_DeleteVector(INDICES[st.range(0)], label_); } assert(VecSimIndex_IndexSize(INDICES[st.range(0)]) == N_VECTORS); diff --git a/tests/benchmark/bm_vecsim_index.h b/tests/benchmark/bm_vecsim_index.h index 2b7e0278c..8ae115d47 100644 --- a/tests/benchmark/bm_vecsim_index.h +++ b/tests/benchmark/bm_vecsim_index.h @@ -117,11 +117,10 @@ void BM_VecSimIndex::Initialize() { #ifdef USE_CUDA // Create RAFFT IVF Flat tiered index. - auto &mock_thread_pool_ivf_flat = BM_VecSimGeneral::mock_thread_pool_raft; - - VecSimParams params_flat = createDefaultRaftIvfFlatParams(dim, 1000, 100); - tiered_params = {.jobQueue = &BM_VecSimGeneral::mock_thread_pool_raft.jobQ, - .jobQueueCtx = mock_thread_pool_ivf_flat.ctx, + // Use one unique thread pool for the tiered index by changing the thread pool context. + VecSimParams params_flat = createDefaultRaftIvfFlatParams(dim, 10000, 100, false); + tiered_params = {.jobQueue = &BM_VecSimGeneral::mock_thread_pool.jobQ, + .jobQueueCtx = mock_thread_pool.ctx, .submitCb = tieredIndexMock::submit_callback, .flatBufferLimit = params_flat.algoParams.raftIvfParams.nLists * 5000, .primaryIndexParams = ¶ms_flat, @@ -129,8 +128,6 @@ void BM_VecSimIndex::Initialize() { auto *tiered_raft_ivf_flat_index = reinterpret_cast *>( TieredRaftIvfFactory::NewIndex(&tiered_params)); - mock_thread_pool_ivf_flat.ctx->index_strong_ref.reset(tiered_raft_ivf_flat_index); - mock_thread_pool_ivf_flat.init_threads(); indices.push_back(tiered_raft_ivf_flat_index); #endif @@ -145,7 +142,7 @@ void BM_VecSimIndex::Initialize() { VecSimIndex_AddVector(indices[VecSimAlgo_RAFT_IVFFLAT], blob, label); #endif } - mock_thread_pool_ivf_flat.thread_pool_wait(100); + mock_thread_pool.thread_pool_wait(100); // Load the test query vectors form file. Index file path is relative to repository root dir. loadTestVectors(AttachRootPath(test_queries_file), type); From 07909823109e93750e5c254eaaa634b65ece3c6a Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 5 Dec 2023 13:26:18 +0100 Subject: [PATCH 23/28] USE_CUDA fix --- src/VecSim/CMakeLists.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/VecSim/CMakeLists.txt b/src/VecSim/CMakeLists.txt index 279fc5322..46d89cbc3 100644 --- a/src/VecSim/CMakeLists.txt +++ b/src/VecSim/CMakeLists.txt @@ -44,8 +44,8 @@ PUBLIC $<$:VectorSimilaritySerializer> PRIVATE $<$:raft::raft> - CUDA::cusolver - CUDA::cublas - CUDA::curand - CUDA::cusparse + $<$:CUDA::cusolver> + $<$:CUDA::cublas> + $<$:CUDA::curand> + $<$:CUDA::cusparse> ) From 4c9024862aeb6ed6473486b2cd5aaec212e1abe3 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 7 Dec 2023 17:02:45 +0100 Subject: [PATCH 24/28] Fix compilation --- CMakeLists.txt | 2 +- src/VecSim/algorithms/raft_ivf/ivf.cuh | 2 ++ src/VecSim/memory/vecsim_malloc.h | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2821a02b9..2e6569186 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,9 +40,9 @@ if (USE_CUDA) include(cmake/raft.cmake) # Required flags for compiling RAFT set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -std=c++17") set(CMAKE_CUDA_FLAGS_RELEASE "-O3") set(CMAKE_CUDA_FLAGS_DEBUG "-g") + set(CMAKE_CUDA_STANDARD 17) endif() diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index fd70cd372..f05fe007d 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -24,7 +24,9 @@ #include #include #include +#include #include +#include #include inline auto constexpr GetRaftDistanceType(VecSimMetric vsm) { diff --git a/src/VecSim/memory/vecsim_malloc.h b/src/VecSim/memory/vecsim_malloc.h index e25cf6e6b..56f681ac3 100644 --- a/src/VecSim/memory/vecsim_malloc.h +++ b/src/VecSim/memory/vecsim_malloc.h @@ -25,7 +25,7 @@ struct VecSimAllocator { static size_t allocation_header_size; static VecSimMemoryFunctions memFunctions; - VecSimAllocator() : allocated(std::atomic_uint64_t(sizeof(VecSimAllocator))) {} + VecSimAllocator() : allocated((sizeof(VecSimAllocator))) {} public: static std::shared_ptr newVecsimAllocator(); From 7805aa174dbb931fa86a02c53f8b048ea98f70df Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 22 Dec 2023 01:06:00 +0100 Subject: [PATCH 25/28] Add ivfpq bench --- CMakeLists.txt | 5 +- src/VecSim/algorithms/raft_ivf/ivf.cuh | 2 +- src/VecSim/algorithms/raft_ivf/ivf_tiered.h | 25 +++++-- tests/benchmark/bm_common.h | 74 ++++++++++++------- .../bm_basics_initialize_fp32.h | 35 +++++++-- tests/benchmark/bm_vecsim_general.h | 2 +- tests/benchmark/bm_vecsim_index.h | 22 +++++- tests/unit/test_raft_ivf_tiered.cpp | 8 +- 8 files changed, 123 insertions(+), 50 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2e6569186..70945bf0d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,7 +43,10 @@ if (USE_CUDA) set(CMAKE_CUDA_FLAGS_RELEASE "-O3") set(CMAKE_CUDA_FLAGS_DEBUG "-g") set(CMAKE_CUDA_STANDARD 17) - + if(${CUDAToolkit_VERSION_MAJOR} GREATER 10) + # cuda11 support --threads for compile some large .cu more efficient + add_compile_options($<$:--threads=4>) + endif() endif() # Only do these if this is the main project, and not if it is included through add_subdirectory diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index f05fe007d..35c7ed854 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -286,11 +286,11 @@ public: raft::copy(neighbors.data(), neighbors_gpu.data_handle(), k, res.get_stream()); raft::copy(distances.data(), distances_gpu.data_handle(), k, res.get_stream()); + result_list->results.resize(k); // Ensure search is complete and data have been copied back before // building query result objects on host res.sync_stream(); - result_list->results.resize(k); for (auto i = 0; i < k; ++i) { result_list->results[i].id = idToLabelLookup_[neighbors[i]]; result_list->results[i].score = distances[i]; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h index 2a36754b4..a930f5863 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h @@ -34,7 +34,8 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { if (this->frontendIndex->indexSize() >= this->flatBufferLimit) { // If the backend index is empty, build it with all the vectors // Otherwise, just add the vector to the backend index - executeTransferJob(true); + auto temp_job = RAFTTransferJob(this->allocator, executeTransferJobWrapper, this, true); + executeTransferJob(&temp_job); } // If the backend index is already built and that the write mode is in place @@ -49,12 +50,15 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { // Otherwise, add the vector to the flat index this->flatIndexGuard.lock(); ret = this->frontendIndex->addVector(blob, label); - this->flatIndexGuard.unlock(); // Submit a transfer job AsyncJob *new_insert_job = new (this->allocator) RAFTTransferJob(this->allocator, executeTransferJobWrapper, this); this->submitSingleJob(new_insert_job); + + // Update the pointer to the latest transfer job + this->pendingTransferJob = reinterpret_cast(new_insert_job); + this->flatIndexGuard.unlock(); return ret; } @@ -120,7 +124,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { auto *transfer_job = reinterpret_cast(job); auto *job_index = reinterpret_cast *>(transfer_job->index); - job_index->executeTransferJob(transfer_job->force_); + job_index->executeTransferJob(transfer_job); } delete job; } @@ -160,11 +164,14 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { private: size_t minVectorsInit = 1; + // This ptr is designating the latest transfer job. It is protected by flat buffer lock + volatile RAFTTransferJob* pendingTransferJob = nullptr; + inline auto *getBackendIndex() const { return dynamic_cast *>(this->backendIndex); } - void executeTransferJob(bool force = false) { + void executeTransferJob(RAFTTransferJob *job) { size_t nVectors = this->frontendIndex->indexSize(); // No vectors to transfer if (nVectors == 0) { @@ -173,7 +180,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { // Don't transfer less than nLists * minVectorsInit vectors if the backend index is empty // (for kmeans initialization purposes) - if (!force) { + if (!job->force_) { auto main_nVectors = this->backendIndex->indexSize(); size_t min_nVectors = 1; if (main_nVectors == 0) @@ -184,8 +191,14 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { } } - // Check that there are still vectors to transfer after exclusive lock this->flatIndexGuard.lock(); + // Check that the job has not been cancelled while waiting for the lock + // and that the job is the latest one if there is no force flag + if (!job->isValid || (this->pendingTransferJob != job && !job->force_)) { + this->flatIndexGuard.unlock(); + return; + } + // Check that there are still vectors to transfer after exclusive lock nVectors = this->frontendIndex->indexSize(); if (nVectors == 0) { this->flatIndexGuard.unlock(); diff --git a/tests/benchmark/bm_common.h b/tests/benchmark/bm_common.h index 858bf75fe..aebb6ae26 100644 --- a/tests/benchmark/bm_common.h +++ b/tests/benchmark/bm_common.h @@ -29,19 +29,16 @@ class BM_VecSimCommon : public BM_VecSimIndex { static void TopK_HNSW(benchmark::State &st, unsigned short index_offset = 0); static void TopK_Tiered(benchmark::State &st, unsigned short index_offset = 0); #ifdef USE_CUDA - // Run TopK using Raft IVF Flat tiered and flat index and calculate the recall of the Raft IVF - // Flat algorithm with respect to the results returned by the flat index. - static void TopK_TieredRaftIVFFlat(benchmark::State &st, unsigned short index_offset = 0); - // Run TopK using both Raft IVF PQ Tiered and flat index and calculate the recall of the Raft - // IVF PQ algorithm with respect to the results returned by the flat index. - static void TopK_TieredRaftIVFPQ(benchmark::State &st, unsigned short index_offset = 0); + // Run TopK using Raft IVF tiered and flat index and calculate the recall of the Raft IVF + // algorithm with respect to the results returned by the flat index. + static void TopK_TieredRaftIVF(benchmark::State &st, unsigned short index_offset = 0); #endif // Does nothing but returning the index memory. static void Memory_FLAT(benchmark::State &st, unsigned short index_offset = 0); static void Memory_HNSW(benchmark::State &st, unsigned short index_offset = 0); static void Memory_Tiered(benchmark::State &st, unsigned short index_offset = 0); - static void Memory_TieredRaftIVFFlat(benchmark::State &st, unsigned short index_offset = 0); + static void Memory_TieredRaftIVF(benchmark::State &st, unsigned short index_offset = 0); }; template @@ -95,7 +92,7 @@ void BM_VecSimCommon::Memory_Tiered(benchmark::State &st, (double)VecSimIndex_Info(INDICES[VecSimAlgo_TIERED + index_offset]).commonInfo.memory; } template -void BM_VecSimCommon::Memory_TieredRaftIVFFlat(benchmark::State &st, +void BM_VecSimCommon::Memory_TieredRaftIVF(benchmark::State &st, unsigned short index_offset) { for (auto _ : st) { @@ -181,19 +178,20 @@ void BM_VecSimCommon::TopK_Tiered(benchmark::State &st, unsigned s #ifdef USE_CUDA template -void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, - unsigned short index_offset) { +void BM_VecSimCommon::TopK_TieredRaftIVF(benchmark::State &st, + unsigned short index_offset) { size_t k = st.range(0); size_t n_probes = st.range(1); std::atomic_int correct = 0; std::atomic_int iter = 0; auto *tiered_index = - reinterpret_cast *>(INDICES[VecSimAlgo_RAFT_IVFFLAT]); + reinterpret_cast *>(INDICES[VecSimAlgo_RAFT_IVFFLAT + index_offset]); size_t total_iters = 50; tiered_index->setNProbes(n_probes); VecSimQueryReply *all_results[total_iters]; - auto parallel_knn_search = [](AsyncJob *job) { + // Declare 2 lambda to avoid changing AsyncJob type for the JobMock. + auto parallel_knn_search_flat = [](AsyncJob *job) { auto *search_job = reinterpret_cast(job); VecSimQueryParams query_params{.batchSize = 1}; size_t cur_iter = search_job->iter; @@ -204,11 +202,31 @@ void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, delete job; }; + auto parallel_knn_search_pq = [](AsyncJob *job) { + auto *search_job = reinterpret_cast(job); + VecSimQueryParams query_params{.batchSize = 1}; + size_t cur_iter = search_job->iter; + auto results = VecSimIndex_TopKQuery(INDICES[VecSimAlgo_RAFT_IVFPQ], + QUERIES[cur_iter % N_QUERIES].data(), search_job->k, + &query_params, BY_SCORE); + search_job->all_results[cur_iter] = results; + delete job; + }; + for (auto _ : st) { - auto search_job = new (tiered_index->getAllocator()) - tieredIndexMock::SearchJobMock(tiered_index->getAllocator(), parallel_knn_search, - tiered_index, k, 0, iter++, all_results); - tiered_index->submitSingleJob(search_job); + if (index_offset == 0) // Flat + { + auto search_job = new (tiered_index->getAllocator()) + tieredIndexMock::SearchJobMock(tiered_index->getAllocator(), parallel_knn_search_flat, + tiered_index, k, 0, iter++, all_results); + tiered_index->submitSingleJob(search_job); + } else // PQ + { + auto search_job = new (tiered_index->getAllocator()) + tieredIndexMock::SearchJobMock(tiered_index->getAllocator(), parallel_knn_search_pq, + tiered_index, k, 0, iter++, all_results); + tiered_index->submitSingleJob(search_job); + } if (iter == total_iters) { BM_VecSimGeneral::mock_thread_pool.thread_pool_wait(); } @@ -217,7 +235,7 @@ void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, // Measure recall for (iter = 0; iter < total_iters; iter++) { auto bf_results = - VecSimIndex_TopKQuery(INDICES[VecSimAlgo_BF + index_offset], + VecSimIndex_TopKQuery(INDICES[VecSimAlgo_BF], QUERIES[iter % N_QUERIES].data(), k, nullptr, BY_SCORE); BM_VecSimGeneral::MeasureRecall(all_results[iter], bf_results, correct); @@ -264,18 +282,18 @@ void BM_VecSimCommon::TopK_TieredRaftIVFFlat(benchmark::State &st, ->Unit(benchmark::kMillisecond) #ifdef USE_CUDA -#define REGISTER_TopK_TieredRaftIVF(BM_CLASS, BM_FUNC) \ +#define REGISTER_TopK_TieredRaftIVF(BM_CLASS, BM_FUNC) \ BENCHMARK_REGISTER_F(BM_CLASS, BM_FUNC) \ - ->Args({10, 20}) \ - ->Args({10, 50}) \ - ->Args({10, 150}) \ - ->Args({100, 20}) \ - ->Args({100, 50}) \ - ->Args({100, 150}) \ - ->Args({200, 20}) \ - ->Args({200, 50}) \ - ->Args({200, 150}) \ - ->ArgNames({"k", "n_probes"}) \ + ->Args({10, 200}) \ + ->Args({10, 500}) \ + ->Args({10, 1500}) \ + ->Args({100, 200}) \ + ->Args({100, 500}) \ + ->Args({100, 1500}) \ + ->Args({200, 200}) \ + ->Args({200, 500}) \ + ->Args({200, 1500}) \ + ->ArgNames({"k", "n_probes"}) \ ->Iterations(50) \ ->Unit(benchmark::kMillisecond) #endif diff --git a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h index 20826335d..9a60453c5 100644 --- a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h +++ b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h @@ -20,11 +20,15 @@ BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, Tiered), fp32_ (benchmark::State &st) { Memory_Tiered(st); } BENCHMARK_REGISTER_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, Tiered))->Iterations(1); -// Memory TieredRaftIVFFlat #ifdef USE_CUDA +// Memory TieredRaftIVFFlat BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFFlat), fp32_index_t) -(benchmark::State &st) { Memory_TieredRaftIVFFlat(st); } +(benchmark::State &st) { Memory_TieredRaftIVF(st); } BENCHMARK_REGISTER_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFFlat))->Iterations(1); +// Memory TieredRaftIVFPQ +BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFPQ), fp32_index_t) +(benchmark::State &st) { Memory_TieredRaftIVF(st, 1); } +BENCHMARK_REGISTER_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFPQ))->Iterations(1); #endif // AddLabel @@ -32,7 +36,10 @@ BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimBasics, BM_ADD_LABEL, fp32_index_t) (benchmark::State &st) { AddLabel(st); } REGISTER_AddLabel(BM_ADD_LABEL, VecSimAlgo_BF); REGISTER_AddLabel(BM_ADD_LABEL, VecSimAlgo_HNSWLIB); - +#ifdef USE_CUDA +REGISTER_AddLabel(BM_ADD_LABEL, VecSimAlgo_RAFT_IVFFLAT); +REGISTER_AddLabel(BM_ADD_LABEL, VecSimAlgo_RAFT_IVFPQ); +#endif // DeleteLabel Registration. Definition is placed in the .cpp file. REGISTER_DeleteLabel(BM_FUNC_NAME(DeleteLabel, BF)); REGISTER_DeleteLabel(BM_FUNC_NAME(DeleteLabel, HNSW)); @@ -54,10 +61,14 @@ REGISTER_TopK_Tiered(BM_VecSimCommon, BM_FUNC_NAME(TopK, Tiered)); #ifdef USE_CUDA // TopK Tiered RAFT IVF Flat -BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVF), fp32_index_t) -(benchmark::State &st) { TopK_TieredRaftIVFFlat(st); } -REGISTER_TopK_TieredRaftIVF(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVF)); - +BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVFFLAT), fp32_index_t) +(benchmark::State &st) { TopK_TieredRaftIVF(st, 0); } +REGISTER_TopK_TieredRaftIVF(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVFFLAT)); + +// TopK Tiered RAFT IVF PQ +BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVFPQ), fp32_index_t) +(benchmark::State &st) { TopK_TieredRaftIVF(st, 1); } +REGISTER_TopK_TieredRaftIVF(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVFPQ)); #endif // Range BF @@ -87,3 +98,13 @@ BENCHMARK_REGISTER_F(BM_VecSimBasics, BM_DELETE_LABEL_ASYNC) ->Arg(100) ->Arg(BM_VecSimGeneral::block_size) ->ArgName("SwapJobsThreshold"); + +// Tiered RAFT IVF Flat add_async benchmarks +#ifdef USE_CUDA +BENCHMARK_REGISTER_F(BM_VecSimBasics, BM_ADD_LABEL_ASYNC) + ->UNIT_AND_ITERATIONS->Arg(VecSimAlgo_RAFT_IVFFLAT) + ->ArgName("VecSimAlgo_RAFT_IVFFLAT"); +BENCHMARK_REGISTER_F(BM_VecSimBasics, BM_ADD_LABEL_ASYNC) + ->UNIT_AND_ITERATIONS->Arg(VecSimAlgo_RAFT_IVFPQ) + ->ArgName("VecSimAlgo_RAFT_IVFPQ"); +#endif diff --git a/tests/benchmark/bm_vecsim_general.h b/tests/benchmark/bm_vecsim_general.h index 6544dcf6f..9a4a3bb28 100644 --- a/tests/benchmark/bm_vecsim_general.h +++ b/tests/benchmark/bm_vecsim_general.h @@ -76,7 +76,7 @@ class BM_VecSimGeneral : public benchmark::Fixture { .metric = VecSimMetric_L2, .nLists = nLists, .kmeans_nIters = 20, - .kmeans_trainsetFraction = 0.8, + .kmeans_trainsetFraction = 0.5, .nProbes = nProbes, .usePQ = true, .pqBits = 8, diff --git a/tests/benchmark/bm_vecsim_index.h b/tests/benchmark/bm_vecsim_index.h index 8ae115d47..7291073ad 100644 --- a/tests/benchmark/bm_vecsim_index.h +++ b/tests/benchmark/bm_vecsim_index.h @@ -122,14 +122,31 @@ void BM_VecSimIndex::Initialize() { tiered_params = {.jobQueue = &BM_VecSimGeneral::mock_thread_pool.jobQ, .jobQueueCtx = mock_thread_pool.ctx, .submitCb = tieredIndexMock::submit_callback, - .flatBufferLimit = params_flat.algoParams.raftIvfParams.nLists * 5000, + .flatBufferLimit = n_vectors, .primaryIndexParams = ¶ms_flat, - .specificParams = {.tieredRaftIvfParams = {.minVectorsInit = 100}}}; + .specificParams = {.tieredRaftIvfParams = {.minVectorsInit = + size_t(1000000 / params_pq.algoParams.raftIvfParams.nLists)}}}; auto *tiered_raft_ivf_flat_index = reinterpret_cast *>( TieredRaftIvfFactory::NewIndex(&tiered_params)); indices.push_back(tiered_raft_ivf_flat_index); + + // Create RAFT IVF PQ tiered index. + // Use one unique thread pool for the tiered index by changing the thread pool context. + VecSimParams params_pq = createDefaultRaftIvfPQParams(dim, 5000, 100); + tiered_params = {.jobQueue = &BM_VecSimGeneral::mock_thread_pool.jobQ, + .jobQueueCtx = mock_thread_pool.ctx, + .submitCb = tieredIndexMock::submit_callback, + .flatBufferLimit = n_vectors, + .primaryIndexParams = ¶ms_pq, + .specificParams = {.tieredRaftIvfParams = { + .minVectorsInit = size_t(1000000 / params_pq.algoParams.raftIvfParams.nLists)}}}; + + auto *tiered_raft_ivf_pq_index = reinterpret_cast *>( + TieredRaftIvfFactory::NewIndex(&tiered_params)); + + indices.push_back(tiered_raft_ivf_pq_index); #endif // Add the same vectors to Flat index. @@ -140,6 +157,7 @@ void BM_VecSimIndex::Initialize() { VecSimIndex_AddVector(indices[VecSimAlgo_BF], blob, label); #ifdef USE_CUDA VecSimIndex_AddVector(indices[VecSimAlgo_RAFT_IVFFLAT], blob, label); + VecSimIndex_AddVector(indices[VecSimAlgo_RAFT_IVFPQ], blob, label); #endif } mock_thread_pool.thread_pool_wait(100); diff --git a/tests/unit/test_raft_ivf_tiered.cpp b/tests/unit/test_raft_ivf_tiered.cpp index d50412ec1..b2927a76a 100644 --- a/tests/unit/test_raft_ivf_tiered.cpp +++ b/tests/unit/test_raft_ivf_tiered.cpp @@ -233,9 +233,9 @@ TYPED_TEST(RaftIvfTieredTest, transferJob_inplace) { ASSERT_EQ(tiered_index->backendIndex->indexSize(), flat_buffer_limit * 2); ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 2 * (n - flat_buffer_limit)); - // Run a thread loop iteration. The thread should transfer the rest of the vectors to the + // Run the thread pool. The thread should transfer the rest of the vectors to the // backend index. - mock_thread_pool.thread_iteration(); + mock_thread_pool.thread_pool_wait(100); ASSERT_EQ(tiered_index->indexSize(), 2 * n); ASSERT_EQ(tiered_index->backendIndex->indexSize(), 2 * n); ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); @@ -251,7 +251,7 @@ TYPED_TEST(RaftIvfTieredTest, deleteVector_backend) { size_t k = 1; // Create RaftIvfTiered index instance with a mock queue. - VecSimParams params = createDefaultFlatParams(dim, nLists, 20); + VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); auto mock_thread_pool = tieredIndexMock(); auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); @@ -264,7 +264,7 @@ TYPED_TEST(RaftIvfTieredTest, deleteVector_backend) { GenerateAndAddVector(tiered_index, dim, i, i); } // Use just one thread to transfer all the vectors - mock_thread_pool.thread_iteration(); + mock_thread_pool.thread_pool_wait(100); // Check that the backend index has the first 12 vectors ASSERT_EQ(tiered_index->indexSize(), n); From f8c02efb61d7c50014eff48cdf09110dd28b174e Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 8 Jan 2024 16:57:08 +0100 Subject: [PATCH 26/28] Add test for Cosine and IP --- src/VecSim/algorithms/raft_ivf/ivf.cuh | 13 ++- src/VecSim/algorithms/raft_ivf/ivf_tiered.h | 10 +- tests/benchmark/bm_vecsim_general.h | 4 +- tests/benchmark/bm_vecsim_index.h | 6 +- tests/unit/test_raft_ivf_tiered.cpp | 111 ++++++++++++++++++-- 5 files changed, 123 insertions(+), 21 deletions(-) diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index 35c7ed854..afb45a9b0 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -36,6 +36,7 @@ inline auto constexpr GetRaftDistanceType(VecSimMetric vsm) { result = raft::distance::DistanceType::L2Expanded; break; case VecSimMetric_IP: + case VecSimMetric_Cosine: result = raft::distance::DistanceType::InnerProduct; break; default: @@ -93,7 +94,7 @@ struct RaftIvfIndex : public RaftIvfInterface { using dist_type = DistType; private: - // Allow either IVF-flat or IVFPQ parameters + // Allow either IVF-Flat or IVF-PQ parameters using build_params_t = std::variant; using search_params_t = std::variantmetric == VecSimMetric_Cosine || raftIvfParams->metric == VecSimMetric_IP; } int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override { return addVectorBatch(vector_data, &label, 1, auxiliaryCtx); @@ -293,7 +294,11 @@ public: for (auto i = 0; i < k; ++i) { result_list->results[i].id = idToLabelLookup_[neighbors[i]]; - result_list->results[i].score = distances[i]; + if (cosine_postprocess_) { + result_list->results[i].score = 1.0f - distances[i]; + } else { + result_list->results[i].score = distances[i]; + } } return result_list; @@ -375,4 +380,6 @@ private: vecsim_stl::vector idToLabelLookup_; vecsim_stl::unordered_map labelToIdLookup_; + + bool cosine_postprocess_ = false; }; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h index a930f5863..99d8c14c5 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h @@ -50,15 +50,13 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { // Otherwise, add the vector to the flat index this->flatIndexGuard.lock(); ret = this->frontendIndex->addVector(blob, label); + this->flatIndexGuard.unlock(); // Submit a transfer job AsyncJob *new_insert_job = new (this->allocator) RAFTTransferJob(this->allocator, executeTransferJobWrapper, this); this->submitSingleJob(new_insert_job); - // Update the pointer to the latest transfer job - this->pendingTransferJob = reinterpret_cast(new_insert_job); - this->flatIndexGuard.unlock(); return ret; } @@ -165,7 +163,6 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { size_t minVectorsInit = 1; // This ptr is designating the latest transfer job. It is protected by flat buffer lock - volatile RAFTTransferJob* pendingTransferJob = nullptr; inline auto *getBackendIndex() const { return dynamic_cast *>(this->backendIndex); @@ -193,8 +190,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { this->flatIndexGuard.lock(); // Check that the job has not been cancelled while waiting for the lock - // and that the job is the latest one if there is no force flag - if (!job->isValid || (this->pendingTransferJob != job && !job->force_)) { + if (!job->isValid) { this->flatIndexGuard.unlock(); return; } @@ -244,5 +240,7 @@ struct TieredRaftIvfIndex : public VecSimTieredIndex { INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_transferJobAsync_Test) INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_transferJob_inplace_Test) INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_deleteVector_backend_Test) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_searchMetricCosine_Test) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_searchMetricIP_Test) #endif }; diff --git a/tests/benchmark/bm_vecsim_general.h b/tests/benchmark/bm_vecsim_general.h index 9a4a3bb28..21074101f 100644 --- a/tests/benchmark/bm_vecsim_general.h +++ b/tests/benchmark/bm_vecsim_general.h @@ -73,7 +73,7 @@ class BM_VecSimGeneral : public benchmark::Fixture { static VecSimParams createDefaultRaftIvfPQParams(size_t dim, uint32_t nLists = 1024, uint32_t nProbes = 20) { RaftIvfParams ivfparams = {.dim = dim, - .metric = VecSimMetric_L2, + .metric = VecSimMetric_Cosine, .nLists = nLists, .kmeans_nIters = 20, .kmeans_trainsetFraction = 0.5, @@ -94,7 +94,7 @@ class BM_VecSimGeneral : public benchmark::Fixture { uint32_t nProbes = 20, bool adaptiveCenters = true) { RaftIvfParams ivfparams = {.dim = dim, - .metric = VecSimMetric_L2, + .metric = VecSimMetric_Cosine, .nLists = nLists, .kmeans_nIters = 20, .kmeans_trainsetFraction = 0.5, diff --git a/tests/benchmark/bm_vecsim_index.h b/tests/benchmark/bm_vecsim_index.h index 7291073ad..436b65f0c 100644 --- a/tests/benchmark/bm_vecsim_index.h +++ b/tests/benchmark/bm_vecsim_index.h @@ -125,7 +125,7 @@ void BM_VecSimIndex::Initialize() { .flatBufferLimit = n_vectors, .primaryIndexParams = ¶ms_flat, .specificParams = {.tieredRaftIvfParams = {.minVectorsInit = - size_t(1000000 / params_pq.algoParams.raftIvfParams.nLists)}}}; + size_t(n_vectors / params_flat.algoParams.raftIvfParams.nLists)}}}; auto *tiered_raft_ivf_flat_index = reinterpret_cast *>( TieredRaftIvfFactory::NewIndex(&tiered_params)); @@ -140,8 +140,8 @@ void BM_VecSimIndex::Initialize() { .submitCb = tieredIndexMock::submit_callback, .flatBufferLimit = n_vectors, .primaryIndexParams = ¶ms_pq, - .specificParams = {.tieredRaftIvfParams = { - .minVectorsInit = size_t(1000000 / params_pq.algoParams.raftIvfParams.nLists)}}}; + .specificParams = {.tieredRaftIvfParams = {.minVectorsInit = + size_t(n_vectors / params_pq.algoParams.raftIvfParams.nLists)}}}; auto *tiered_raft_ivf_pq_index = reinterpret_cast *>( TieredRaftIvfFactory::NewIndex(&tiered_params)); diff --git a/tests/unit/test_raft_ivf_tiered.cpp b/tests/unit/test_raft_ivf_tiered.cpp index b2927a76a..511da4929 100644 --- a/tests/unit/test_raft_ivf_tiered.cpp +++ b/tests/unit/test_raft_ivf_tiered.cpp @@ -232,13 +232,6 @@ TYPED_TEST(RaftIvfTieredTest, transferJob_inplace) { ASSERT_EQ(tiered_index->indexSize(), 2 * n); ASSERT_EQ(tiered_index->backendIndex->indexSize(), flat_buffer_limit * 2); ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 2 * (n - flat_buffer_limit)); - - // Run the thread pool. The thread should transfer the rest of the vectors to the - // backend index. - mock_thread_pool.thread_pool_wait(100); - ASSERT_EQ(tiered_index->indexSize(), 2 * n); - ASSERT_EQ(tiered_index->backendIndex->indexSize(), 2 * n); - ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); } TYPED_TEST(RaftIvfTieredTest, deleteVector_backend) { @@ -259,6 +252,8 @@ TYPED_TEST(RaftIvfTieredTest, deleteVector_backend) { // Delete from an empty index. ASSERT_EQ(VecSimIndex_DeleteVector(tiered_index, vec_label), 0); + // Launch the BG threads loop that takes jobs from the queue and executes them. + mock_thread_pool.init_threads(); // Insert vectors for (size_t i = 0; i < n; i++) { GenerateAndAddVector(tiered_index, dim, i, i); @@ -292,3 +287,105 @@ TYPED_TEST(RaftIvfTieredTest, deleteVector_backend) { ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); } +TYPED_TEST(RaftIvfTieredTest, searchMetricCosine) { + size_t dim = 32; + size_t n = 25; + size_t nLists = 5; + size_t flat_buffer_limit = 100; + + size_t k = 10; + + // Create RaftIvfTiered index instance with a mock queue. + VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); + + // Set the metric to cosine. + params.algoParams.raftIvfParams.metric = VecSimMetric_Cosine; + + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + + // Launch the BG threads loop that takes jobs from the queue and executes them. + mock_thread_pool.init_threads(); + std::vector> inserted_vectors; + + for (size_t i = 0; i < n; i++) { + inserted_vectors.push_back(std::vector(dim)); + // Generate vectors + for (size_t j = 0; j < dim; j++) { + inserted_vectors.back()[j] = (TEST_DATA_T)i + j; + } + // Insert vectors + VecSimIndex_AddVector(tiered_index, inserted_vectors.back().data(), i); + } + mock_thread_pool.thread_pool_wait(100); + + // The query is a vector with half of the values equal to 8.1 and the other half equal to 1.1. + TEST_DATA_T query[dim]; + TEST_DATA_T query_norm[dim]; + GenerateVector(query, dim / 2, 8.1f); + GenerateVector(query + dim / 2, dim / 2, 1.1f); + memcpy(query_norm, query, dim * sizeof(TEST_DATA_T)); + VecSim_Normalize(query_norm, dim, VecSimType_FLOAT32); + + auto verify_cb = [&](size_t id, double score, size_t index) { + TEST_DATA_T neighbor_norm[dim]; + memcpy(neighbor_norm, inserted_vectors[id].data(), dim * sizeof(TEST_DATA_T)); + VecSim_Normalize(neighbor_norm, dim, VecSimType_FLOAT32); + + // Use distance function of the bruteforce index to verify the score. + double dist = tiered_index->frontendIndex->getDistFunc()( + query_norm, + neighbor_norm, + dim); + ASSERT_NEAR(score, dist, 1e-5); + }; + + runTopKSearchTest(tiered_index, query, k, verify_cb); +} + +TYPED_TEST(RaftIvfTieredTest, searchMetricIP) { + size_t dim = 4; + size_t n = 25; + size_t nLists = 5; + size_t flat_buffer_limit = 100; + + size_t k = 10; + + // Create RaftIvfTiered index instance with a mock queue. + VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); + + // Set the metric to Inner Product. + params.algoParams.raftIvfParams.metric = VecSimMetric_IP; + + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + + // Launch the BG threads loop that takes jobs from the queue and executes them. + mock_thread_pool.init_threads(); + std::vector> inserted_vectors; + + for (size_t i = 0; i < n; i++) { + inserted_vectors.push_back(std::vector(dim)); + // Generate vectors + for (size_t j = 0; j < dim; j++) { + inserted_vectors.back()[j] = (TEST_DATA_T)i + j; + } + // Insert vectors + VecSimIndex_AddVector(tiered_index, inserted_vectors.back().data(), i); + } + mock_thread_pool.thread_pool_wait(100); + + // The query is a vector with half of the values equal to 1.1 and the other half equal to 0.1. + TEST_DATA_T query[dim] = {1.1f, 1.1f, 0.1f, 0.1f}; + + auto verify_cb = [&](size_t id, double score, size_t index) { + // Use distance function of the bruteforce index to verify the score. + double dist = tiered_index->frontendIndex->getDistFunc()( + query, + inserted_vectors[id].data(), + dim); + ASSERT_NEAR(score, dist, 1e-5); + }; + + runTopKSearchTest(tiered_index, query, k, verify_cb); +} From 45c510c8a84179fb25a933e040542268a39d3c9d Mon Sep 17 00:00:00 2001 From: Micka Date: Tue, 2 Apr 2024 12:54:03 +0200 Subject: [PATCH 27/28] Update src/VecSim/algorithms/raft_ivf/ivf_tiered.h Co-authored-by: GuyAv46 <47632673+GuyAv46@users.noreply.github.com> --- src/VecSim/algorithms/raft_ivf/ivf_tiered.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h index 99d8c14c5..3bc5b241e 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include "VecSim/algorithms/raft_ivf/ivf_interface.h" #include "VecSim/vec_sim_tiered_index.h" From 21a7d209961e66b66132ed5f1e938d74808e0838 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 3 Jun 2024 22:01:21 +0200 Subject: [PATCH 28/28] Separate index Size --- src/VecSim/algorithms/raft_ivf/ivf.cuh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh index afb45a9b0..2c36ffc56 100644 --- a/src/VecSim/algorithms/raft_ivf/ivf.cuh +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -330,7 +330,11 @@ public: size_t indexSize() const override { auto result = size_t{}; if (index_) { - result = std::visit([](auto &&index) { return index.size(); }, *index_); + if (std::holds_alternative(*index_)) { + result = std::get(*index_).size(); + } else { + result = std::get(*index_).size(); + } } return result - this->numDeleted_; }