diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..0cf07159 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 120 +extend-ignore = E203, W503, E266 +exclude = .git,__pycache__,build,dist,src/cpp,cmake-build-debug,cmake-build-release,cmake-build-relwithdebinfo \ No newline at end of file diff --git a/.github/workflows/build_and_publish_docs.yaml b/.github/workflows/build_and_publish_docs.yaml index 9bc9c525..5b70599e 100644 --- a/.github/workflows/build_and_publish_docs.yaml +++ b/.github/workflows/build_and_publish_docs.yaml @@ -3,30 +3,30 @@ name: Build and Publish Docs on: push: branches: [ main ] -# pull_request: -# branches: [ main ] jobs: + setup-container: + uses: ./.github/workflows/container_setup.yaml + build_docs: + needs: setup-container runs-on: ubuntu-latest container: - image: ghcr.io/marius-team/quake/ubuntu-latest:latest + image: ${{ needs.setup-container.outputs.container_image }} steps: - - uses: actions/checkout@v4 - with: - submodules: true - fetch-depth: 0 + - name: Checkout Repository + uses: actions/checkout@v4 - name: Install Quake run: | git config --global --add safe.directory '*' eval "$(conda shell.bash hook)" conda activate quake-env - pip install . + pip install --no-use-pep517 . - name: Build Sphinx Documentation working-directory: docs - run: | + run: | conda run -n quake-env sphinx-build -b html . _build/html - name: Deploy Documentation to GitHub Pages diff --git a/.github/workflows/build_and_test.yaml b/.github/workflows/build_and_test.yaml index d4eb7927..5f81f0d2 100644 --- a/.github/workflows/build_and_test.yaml +++ b/.github/workflows/build_and_test.yaml @@ -2,31 +2,29 @@ name: Build and Test on: push: - branches: - - main + branches: [ main ] pull_request: - branches: - - main - -env: - BUILD_TYPE: Debug + branches: [ main ] jobs: - test: + setup-container: + uses: ./.github/workflows/container_setup.yaml + + build_and_test: + needs: setup-container runs-on: ubuntu-latest container: - image: ghcr.io/marius-team/quake/ubuntu-latest:latest + image: ${{ needs.setup-container.outputs.container_image }} steps: - - uses: actions/checkout@v4 - with: - submodules: true - fetch-depth: 0 + - name: Checkout Repository + uses: actions/checkout@v4 - - name: Build project + - name: Build C++ run: | git config --global --add safe.directory '*' eval "$(conda shell.bash hook)" conda activate quake-env + conda install libarrow-all=19.0.1 -c conda-forge mkdir -p build cd build cmake -DCMAKE_BUILD_TYPE=${{ env.BUILD_TYPE }} \ @@ -38,30 +36,18 @@ jobs: make bindings -j2 make quake_tests -j2 - - name: Run tests - run: | - cd build - test/cpp/quake_tests - - python_build: - runs-on: ubuntu-latest - container: - image: ghcr.io/marius-team/quake/ubuntu-latest:latest - steps: - - uses: actions/checkout@v4 - with: - submodules: true - fetch-depth: 0 - - - name: Debug Checkout + - name: Run C++ Tests + shell: bash run: | - ls -R + ./build/test/cpp/quake_tests - - name: Build project + - name: Run Python Tests + shell: bash run: | git config --global --add safe.directory '*' eval "$(conda shell.bash hook)" conda activate quake-env - pip install . + conda install libarrow-all=19.0.1 -c conda-forge + pip install --no-use-pep517 . pip install pytest python -m pytest test/python \ No newline at end of file diff --git a/.github/workflows/container_setup.yaml b/.github/workflows/container_setup.yaml new file mode 100644 index 00000000..24a20c09 --- /dev/null +++ b/.github/workflows/container_setup.yaml @@ -0,0 +1,106 @@ +name: Container Setup + +on: + workflow_call: + outputs: + container_image: + description: "Container image determined by this workflow" + value: ${{ jobs.determine-and-build.outputs.container_image }} + +jobs: + determine-and-build: + runs-on: ubuntu-latest + outputs: + container_image: ${{ steps.set_image.outputs.container_image }} + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Check for dependency changes + id: diff + shell: bash + run: | + echo "Checking for dependency changes..." + # Fetch the latest state of the default branch. + git fetch origin ${{ github.event.repository.default_branch }} + # Compare the dependency files between the default branch and current commit. + if git diff --name-only origin/${{ github.event.repository.default_branch }}...${{ github.sha }} | grep -E 'Dockerfile|conda.yaml'; then + echo "build_required=true" >> "$GITHUB_OUTPUT" + else + echo "build_required=false" >> "$GITHUB_OUTPUT" + fi + + - name: Use Main Container if No Dependency Changes + id: use_main + if: steps.diff.outputs.build_required == 'false' + shell: bash + run: | + image="ghcr.io/marius-team/quake/ubuntu-latest:latest" + echo "No dependency changes detected. Using main container: $image" + echo "container_image=$image" >> "$GITHUB_OUTPUT" + + - name: Compute Dependency Hash + id: dependency_hash + if: steps.diff.outputs.build_required == 'true' + shell: bash + run: | + echo "Computing hash for dependency files..." + # Concatenate the files to create a unique fingerprint. + hash=$(cat environments/ubuntu-latest/Dockerfile environments/ubuntu-latest/conda.yaml 2>/dev/null | sha256sum | awk '{print $1}') + echo "hash=$hash" >> "$GITHUB_OUTPUT" + echo "Dependency hash: $hash" + + - name: Check if Test Container Exists + id: check_container + if: steps.diff.outputs.build_required == 'true' + shell: bash + run: | + image="ghcr.io/marius-team/quake/test_container:${{ steps.dependency_hash.outputs.hash }}" + echo "Checking if container exists: $image" + if docker pull "$image" > /dev/null 2>&1; then + echo "container_exists=true" >> "$GITHUB_OUTPUT" + echo "container_image=$image" >> "$GITHUB_OUTPUT" + echo "Container already exists: $image" + else + echo "container_exists=false" >> "$GITHUB_OUTPUT" + echo "Container not found: $image" + fi + + - name: Build Test Container if Needed + id: build + if: steps.diff.outputs.build_required == 'true' && steps.check_container.outputs.container_exists == 'false' + shell: bash + run: | + tag="ghcr.io/marius-team/quake/test_container:${{ steps.dependency_hash.outputs.hash }}" + echo "Building test container: $tag" + docker build -f environments/ubuntu-latest/Dockerfile -t "$tag" . + docker push "$tag" + echo "container_image=$tag" >> "$GITHUB_OUTPUT" + echo "Test container built and pushed." + + - name: Set Container Image Output + id: set_image + shell: bash + run: | + if [ "${{ steps.diff.outputs.build_required }}" = "false" ]; then + image="ghcr.io/marius-team/quake/ubuntu-latest:latest" + else + # Use the pre-existing container image if available, otherwise the one just built. + if [ -n "${{ steps.check_container.outputs.container_image }}" ]; then + image="${{ steps.check_container.outputs.container_image }}" + else + image="${{ steps.build.outputs.container_image }}" + fi + fi + echo "Using container image: $image" + echo "container_image=$image" >> "$GITHUB_OUTPUT" \ No newline at end of file diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 00000000..fff6ffe4 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,32 @@ +name: Lint Checks + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install Dependencies + run: | + pip install black==25.1.0 isort==6.0.1 flake8==7.1.2 + + - name: Run Black Check + run: black --check . + + - name: Run isort Check + run: isort --check-only . + + - name: Run flake8 + run: flake8 . \ No newline at end of file diff --git a/.github/workflows/publish_docker_image.yaml b/.github/workflows/publish_docker_image.yaml index ad054aae..a4b822b8 100644 --- a/.github/workflows/publish_docker_image.yaml +++ b/.github/workflows/publish_docker_image.yaml @@ -3,14 +3,13 @@ name: Build and Publish Docker Image on: push: branches: [ main ] -# pull_request: -# branches: [ main ] jobs: build-and-push-image: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - name: Checkout Repository + uses: actions/checkout@v4 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -25,27 +24,10 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Build and push + - name: Build and push main container uses: docker/build-push-action@v4 with: context: . file: environments/ubuntu-latest/Dockerfile tags: ghcr.io/marius-team/quake/ubuntu-latest:latest - push: true - - test: - runs-on: ubuntu-latest - container: - image: ghcr.io/marius-team/quake/ubuntu-latest:latest - needs: [build-and-push-image] - steps: - - uses: actions/checkout@v4 - - - name: Build project - run: | - git config --global --add safe.directory '*' - eval "$(conda shell.bash hook)" - conda activate quake-env - mkdir build && cd build - cmake -DCMAKE_BUILD_TYPE=Release -DQUAKE_USE_NUMA=OFF -DQUAKE_ENABLE_GPU=OFF -DBUILD_TESTS=ON ../ - make bindings -j2 \ No newline at end of file + push: true \ No newline at end of file diff --git a/.github/workflows/regression_baseline.yaml b/.github/workflows/regression_baseline.yaml new file mode 100644 index 00000000..c4889be5 --- /dev/null +++ b/.github/workflows/regression_baseline.yaml @@ -0,0 +1,41 @@ +name: Manual Baseline Generation + +on: + workflow_dispatch: + +jobs: + setup-container: + uses: ./.github/workflows/container_setup.yaml + + generate-baseline: + needs: setup-container + runs-on: ubuntu-latest + container: + image: ${{ needs.setup-container.outputs.container_image }} + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Install Quake + run: | + git config --global --add safe.directory '*' + eval "$(conda shell.bash hook)" + conda activate quake-env + pip install --no-use-pep517 . + + - name: Run Baseline Driver Script + run: | + git config --global --add safe.directory '*' + eval "$(conda shell.bash hook)" + conda activate quake-env + cd test/python/regression + python run_all_workloads.py --overwrite --name baseline + + - name: Upload Baseline Artifacts + uses: actions/upload-artifact@v4 + with: + name: baseline-artifacts + path: | + test/python/regression/workloads + test/python/regression/results + test/python/regression/data \ No newline at end of file diff --git a/.github/workflows/run_regression.yaml b/.github/workflows/run_regression.yaml new file mode 100644 index 00000000..3c60930a --- /dev/null +++ b/.github/workflows/run_regression.yaml @@ -0,0 +1,42 @@ +name: PR Regression Test + +on: + pull_request: + branches: [ main ] + +jobs: + setup-container: + uses: ./.github/workflows/container_setup.yaml + + regression-test: + needs: setup-container + runs-on: ubuntu-latest + container: + image: ${{ needs.setup-container.outputs.container_image }} + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Install Quake + run: | + git config --global --add safe.directory '*' + eval "$(conda shell.bash hook)" + conda activate quake-env + pip install --no-use-pep517 . + + - name: Get Baseline Artifacts + uses: dawidd6/action-download-artifact@v9 + with: + name: baseline-artifacts + workflow: regression_baseline.yaml + + - name: Run Regression Tests + run: | + git config --global --add safe.directory '*' + eval "$(conda shell.bash hook)" + conda activate quake-env + mv workloads test/python/regression + mv results test/python/regression + mv data test/python/regression + cd test/python/regression + python run_all_workloads.py --name "PR-${{ github.event.number }}" \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 32db4ded..f7f2db8e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,10 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() +set(Torch_USE_CUDA OFF CACHE BOOL "Force disable CUDA in Torch") +set(Torch_NO_CUDA ON CACHE BOOL "Force disable CUDA in Torch") +set(USE_CUDA OFF CACHE BOOL "Force disable CUDA globally") + # QUAKE_ENABLE_GPU: Enable GPU support for Faiss # Default: OFF if(QUAKE_ENABLE_GPU) @@ -27,6 +31,9 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) # Compiler flags set(CMAKE_CXX_FLAGS_DEBUG "-g") set(CMAKE_CXX_FLAGS_RELEASE "-O3") +if(NOT DEFINED QUAKE_SET_ABI_MODE) + set(QUAKE_SET_ABI_MODE ON) +endif() # If in a conda environment, favor conda packages if(EXISTS $ENV{CONDA_PREFIX}) @@ -75,7 +82,12 @@ endif() # Compiler options and definitions add_compile_options(-march=native) -add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) + +# Switch ABI mode +if(QUAKE_SET_ABI_MODE) + add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +endif() + # --------------------------------------------------------------- # Find Required Packages diff --git a/README.md b/README.md index 65ce9788..8bab54fd 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ ![image](https://github.com/user-attachments/assets/559fe8da-84a6-4e44-a06a-cd35c5012e9a) -# Query-Adaptive KNN Index +# Query-Adaptive Vector Search -Quake is a C++ library (with Python bindings) for dynamic, high‑performance approximate nearest neighbor (ANN) search. Its core operations—building a dynamic index, adaptive search, real‑time updates, and automatic maintenance—ensure high-throughput updates and low-latency queries without manual tuning. +Quake is a library for dynamic, high‑performance approximate nearest neighbor (ANN) search. --- ## Key Advantages @@ -55,7 +55,7 @@ Quake has the following limitations which will be addressed in future developmen 3. **Install Quake:** ```bash - pip install . + pip install --no-use-pep517 . ``` --- diff --git a/docs/architecture/architecture.rst b/docs/architecture/architecture.rst index 5fd2e73c..de7c4d76 100644 --- a/docs/architecture/architecture.rst +++ b/docs/architecture/architecture.rst @@ -31,6 +31,7 @@ Detailed Components :maxdepth: 2 query_coordinator + maintenance_policy partition_manager dynamic_inverted_list index_partition diff --git a/docs/architecture/maintenance_policy.rst b/docs/architecture/maintenance_policy.rst new file mode 100644 index 00000000..df281b6c --- /dev/null +++ b/docs/architecture/maintenance_policy.rst @@ -0,0 +1,56 @@ +MaintenancePolicy +================================ + +Overview +-------- +The **MaintenancePolicy** is a central component of the Quake indexing system that +automates index upkeep by monitoring query “hit” patterns and deciding when to +perform maintenance actions such as splitting or deleting partitions. It ties together +several key components: + +- **PartitionManager**: Manages the set of index partitions. The maintenance policy + uses it to query partition sizes, delete outdated partitions, and trigger splits. +- **MaintenanceCostEstimator**: Provides latency-based estimates for how much a + particular maintenance action (split or delete) will cost. It uses measured or + extrapolated scan latencies as a function of partition size and search parameters. +- **HitCountTracker**: Records per-query “hit” counts (i.e. which partitions were + scanned) over a sliding window. It computes the average scan fraction and maintains + history of split and delete events for later analysis. + +Key Methods +------------------------- +1. **Record Query Hits:** + Each time a query is processed, the `record_query_hits()` method is called + with the partition IDs that were “hit” during the search. The *HitCountTracker* + accumulates these events to compute the current scan fraction across the sliding window. + +2. **Perform Maintenance:** + When the window is full, `perform_maintenance()` is invoked. This method: + + - Aggregates hit counts from the *HitCountTracker*. + - For each partition, uses the *MaintenanceCostEstimator* to compute a “delta” + (change in cost) for both deletion and splitting. The decision is based on whether + the estimated cost delta exceeds configured thresholds. + - Determines which partitions should be split (to improve query efficiency) + or deleted (if underutilized). + - Triggers the corresponding operations through the *PartitionManager* and then, + if needed, calls local refinement on newly split partitions. + - Returns timing information via a *MaintenanceTimingInfo* structure. + +3. **Reset:** + After maintenance operations complete, the policy can be reset (via `reset()`) + to clear the hit history and start a fresh monitoring window. + +Configuration and Parameters +------------------------------ +Maintenance behavior is governed by a set of parameters (encapsulated in the +**MaintenancePolicyParams** structure): + +- **window_size**: Number of queries over which hits are aggregated. +- **refinement_radius** and **refinement_iterations**: Control local refinement of new partitions. +- **delete_threshold_ns** and **split_threshold_ns**: Latency thresholds (in nanoseconds) + that trigger deletion or splitting. +- **alpha**: Scaling factor applied to cost estimates. +- **enable_split_rejection / enable_delete_rejection**: Flags to allow rejecting an + otherwise triggered action if additional checks (such as vector reassignments) suggest it + may not be beneficial. diff --git a/docs/architecture/query_coordinator.rst b/docs/architecture/query_coordinator.rst new file mode 100644 index 00000000..047f9401 --- /dev/null +++ b/docs/architecture/query_coordinator.rst @@ -0,0 +1,156 @@ +.. _query_coordinator: + +Query Coordinator +===================================== + +Overview +-------- +This document describes the NUMA‑aware design for the QueryCoordinator. In this design, the coordinator distributes query work to worker threads that scan index partitions in parallel. The design focuses on reducing memory latency by using NUMA‑aware allocation and thread affinity, minimizing synchronization overhead via per‑core resource pools and job queues, and uses adaptive partition scanning (APS) to enable early termination. + +The design aims to: + +- Minimize memory latency by pinning data (e.g. top‑K buffers, partitions) to the same cores that process them. +- Reduce synchronization overhead by using per‑core resource pools and per‑core job queues. +- Achieve nearly linear scalability with the number of cores while matching the serial scan performance for a single worker. +- Use APS (adaptive early termination) to terminate scanning early when the recall target is met. + +Coordinator Architecture Diagram +---------------------------------- + +.. mermaid:: + + %%{ + init: { + "theme": "default", + "themeVariables": { + "fontSize": "20px", + "fontFamily": "Arial" + "fontWeight": "bold" + } + } + }%% + + flowchart LR + + %% External inputs/nodes + QIN["Input Query Vectors"] + PSCAN["Partition IDs"] + R["Final Results"] + + %% The main QueryCoordinator subgraph + subgraph QC["QueryCoordinator"] + direction LR + + %% Core 1 + subgraph Core1["Core 1"] + direction TB + JQ1[["Job Queue"]] + QBA1[/Query Buffers/] + IP1["Index Partitions"] + WT1("Scan Thread") + LTK1(("Local TopK")) + end + + %% Core n + subgraph CoreN["Core n"] + direction TB + JQN[["Job Queue"]] + QBAN[/Query Buffers/] + IPN["Index Partitions"] + WTN("Scan Thread") + LTKn(("Local TopK")) + end + + GTopk(("Global TopK")) + + %% APS module that checks the global buffer and signals early termination + APS["Adaptive Partition Scanning (APS)"] + end + + %% Edges from input queries to each core's query buffers + QIN -->|memcpy| QBA1 + QIN -->|memcpy| QBAN + + PSCAN -->|enqueue| JQ1 + PSCAN -->|enqueue| JQN + + %% Inside each core, show the flow + JQ1 -->|dequeue| WT1 + QBA1 -->|read| WT1 + IP1 -->|scan| WT1 + WT1 -->|write| LTK1 + + JQN -->|dequeue| WTN + QBAN -->|read| WTN + IPN -->|scan| WTN + WTN -->|write| LTKn + + %% Merge local topK into global topK + LTK1 -->|merge| GTopk + LTKn -->|merge| GTopk + + %% APS periodically checks global buffer for recall progress + GTopk <-->|check recall| APS + APS -->|signal early termination| GTopk + + %% Finally, global topk to results + GTopk -->|return| R + + %% Optional styling for clarity + style QC fill:#fff7e6,stroke:#666,stroke-width:8px; + style QIN fill:#ccf,stroke:#333,stroke-width:1px; + style R fill:#ffecb3,stroke:#333,stroke-width:1px; + style Core1 fill:#fff,stroke:#999,stroke-width:1px; + style CoreN fill:#fff,stroke:#999,stroke-width:1px; + style LTK1 fill:#eef,stroke:#333,stroke-width:1px,stroke-dasharray:2 2; + style LTKn fill:#eef,stroke:#333,stroke-width:1px,stroke-dasharray:2 2; + style GTopk fill:#eef,stroke:#333,stroke-width:1px,stroke-dasharray:2 2; + style APS fill:#fff,stroke:#999,stroke-width:1px,stroke-dasharray:3 3; + + +Key Components +-------------- +- **QueryCoordinator** + The main class that distributes query work, manages worker threads, and merges local results from all cores into a final search result. + +- **CoreResources** + A per‑core structure that contains: + + - A pool of preallocated Top‑K buffers that are allocated using NUMA‑aware routines and pinned to local memory. + - A local aggregator (query buffer) to collect intermediate results. + - A dedicated job queue that holds scan jobs for that core. + +- **ScanJob Structure** + Each unit of work (a ScanJob) encapsulates: + + - Whether the job is batched or single‑query. + - The partition ID (which is pinned to a specific core). + - The number of neighbors (``k``) to return. + - A pointer to the query vector(s). + - Global query IDs and, for batched jobs, the number of queries. + +- **Global Aggregator** + A coordinator-managed Top‑K buffer that merges per‑core local aggregators to produce the final search result. + +Workflow and Job Distribution +------------------------------- +1. **Distribute Partitions to Cores:** + The PartitionManager assigns partitions to cores based on partition size. Each partition’s memory is allocated on the correct NUMA node using NUMA‑aware routines. + +2. **Per‑Core Job Queues:** + The QueryCoordinator creates a per‑core job queue inside each CoreResources structure. For each partition local to a core, a ScanJob is created (either for single-query or batched queries) and enqueued into that core’s job queue. + +3. **Worker Processing:** + Each worker thread (one per core) executes a stateless worker function that: + + - Sets affinity to the core it belongs to. + - Dequeues jobs from its core’s job queue. + - Processes each job (invoking the appropriate scan function). + - Merges results into the core’s local aggregator. + - Decrements a global atomic counter (or per‑core counter) and signals a condition variable for global coordination. + +4. **Global Aggregation:** + The coordinator periodically merges local aggregators into a global Top‑K buffer. + +5. **APS and Early Termination:** + The APS module periodically checks the global Top‑K buffer to determine if the recall target has been met. If so, it signals worker threads to stop scanning. diff --git a/docs/conf.py b/docs/conf.py index a00d32fb..66cfb841 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,26 +12,27 @@ # Add the path to your bindings (adjust as needed) sys.path.insert(0, quake_path) # -- Project information ----------------------------------------------------- -project = 'Quake' -author = 'Jason Mohoney' -copyright = f'{datetime.now().year}, {author}' -release = '0.0.1' +project = "Quake" +author = "Jason Mohoney" +copyright = f"{datetime.now().year}, {author}" +release = "0.0.1" # -- General configuration --------------------------------------------------- extensions = [ - 'sphinx.ext.autodoc', # Automatically extract docstrings from Python - 'sphinx.ext.napoleon', # Support for Google and NumPy style docstrings - 'sphinx.ext.viewcode', # Add links to highlighted source code - 'sphinxcontrib.mermaid', # Add support for Mermaid diagrams - 'sphinx.ext.graphviz', # Add support for Graphviz diagrams + "sphinx.ext.autodoc", # Automatically extract docstrings from Python + "sphinx.ext.napoleon", # Support for Google and NumPy style docstrings + "sphinx.ext.viewcode", # Add links to highlighted source code + "sphinxcontrib.mermaid", # Add support for Mermaid diagrams + "sphinx.ext.graphviz", # Add support for Graphviz diagrams ] -templates_path = ['_templates'] +templates_path = ["_templates"] exclude_patterns = [] # -- Options for HTML output ------------------------------------------------- html_theme = "sphinx_rtd_theme" -html_static_path = ['_static'] +html_static_path = ["_static"] + def setup(app): - app.add_css_file('css/modify.css') \ No newline at end of file + app.add_css_file("css/modify.css") diff --git a/docs/development_guide.rst b/docs/development_guide.rst index dd6e060c..27c19817 100644 --- a/docs/development_guide.rst +++ b/docs/development_guide.rst @@ -82,11 +82,11 @@ Testing Workflow -------------------------- -1. **Clone and Set Up:** +1. **Fork and Set Up:** .. code-block:: bash - git clone https://github.com/marius-team/quake.git + git fork https://github.com/marius-team/quake.git cd quake git submodule update --init --recursive @@ -111,15 +111,19 @@ C++ Build (optional, if you only want to work on Python code): .. code-block:: bash - mkdir build && cd build - cmake -DCMAKE_BUILD_TYPE=Release .. - make -j$(nproc) bindings + mkdir build && cd build + cmake -DCMAKE_BUILD_TYPE=Release \ + -DQUAKE_ENABLE_GPU=ON \ + -DQUAKE_USE_NUMA=ON \ + -DQUAKE_USE_AVX512=ON \ + -DBUILD_TESTS=ON .. + make -j$(nproc) bindings Python Build .. code-block:: bash - pip install . + pip install --no-use-pep517 . 5. **Run Tests:** @@ -141,10 +145,19 @@ Quake must be installed with pip to run the Python tests. Run them using pytest: pytest test/python/ -6. **Make Changes and submit a PR:** +6. **Run Autoformatting and Linters:** -After making changes, commit them and push to your branch. Then, create a PR on the main branch. +Lint checks need to pass before submitting a PR. + +We use `black`, `isort` and `flake8`. + +Run the following scripts to autoformat and run linters: + +.. code-block:: bash + + source scripts/autoformat.sh + source scripts/lint.sh -Conclusion ----------- -This guide is a living document. As Quake evolves, update it to reflect improvements and new practices. Our goal is to keep the codebase and its documentation clear, correct, and easy to contribute to. +7. **Make Changes and submit a PR:** + +After making changes, commit them and push to your branch. Then, create a PR on the main branch. diff --git a/docs/install.rst b/docs/install.rst index 33b9f8ff..42c7d334 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -21,7 +21,7 @@ Clone the repository and build the extension: git submodule update --init --recursive conda env create -f environments/ubuntu-latest/conda.yaml conda activate quake-env - pip install . + pip install --no-use-pep517 . For advanced build options (e.g. enabling GPU, NUMA, or AVX512), use the cmake build: diff --git a/docs/quickstart.rst b/docs/quickstart.rst index a393d0f5..1059009a 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -37,7 +37,7 @@ Install the Quake package (which includes the Python bindings): .. code-block:: bash - pip install . + pip install --no-use-pep517 . Step 4: Run Example Program ------------------------------- diff --git a/environments/ubuntu-latest/Dockerfile b/environments/ubuntu-latest/Dockerfile index c9134f0c..d99679a8 100644 --- a/environments/ubuntu-latest/Dockerfile +++ b/environments/ubuntu-latest/Dockerfile @@ -36,18 +36,21 @@ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh - bash /tmp/miniconda.sh -b -p $CONDA_DIR && \ rm /tmp/miniconda.sh -# ----------------------------- -# Set up and activate conda environment -# ----------------------------- + +# Copy in your conda environment YAML COPY environments/ubuntu-latest/conda.yaml /tmp/conda.yaml -RUN conda env create -f /tmp/conda.yaml && \ - conda clean -afy -SHELL ["bash", "-c"] -RUN echo "conda activate test-env" >> ~/.bashrc -ENV CONDA_DEFAULT_ENV=test-env +# Create quake-env +RUN conda env create -f /tmp/conda.yaml && conda clean -afy # ----------------------------- -# Entry point +# Install PyTorch (CPU-only) # ----------------------------- -ENTRYPOINT ["/bin/bash", "-ci", "conda activate test-env && exec bash"] \ No newline at end of file +RUN conda run -n quake-env pip install torch --index-url https://download.pytorch.org/whl/cpu + +RUN echo "===== DEBUG: which conda =====" && which conda +RUN echo "===== DEBUG: conda info =====" && conda info +RUN echo "===== DEBUG: conda env list =====" && conda env list +RUN echo "===== DEBUG: quake-env check =====" && conda run -n quake-env python -c "import sys; print('OK in quake-env; python:', sys.executable)" + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/environments/ubuntu-latest/conda.yaml b/environments/ubuntu-latest/conda.yaml index 1ab6789c..b6e94d9a 100644 --- a/environments/ubuntu-latest/conda.yaml +++ b/environments/ubuntu-latest/conda.yaml @@ -6,8 +6,7 @@ channels: dependencies: - python=3.11 - numpy - - pytorch - - cpuonly + - pandas - faiss-cpu - matplotlib - pytest @@ -18,4 +17,4 @@ dependencies: - sphinx_rtd_theme - sphinxcontrib-mermaid - graphviz - + - pyyaml \ No newline at end of file diff --git a/examples/quickstart.py b/examples/quickstart.py index b9310e86..56c1dee6 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -15,13 +15,15 @@ python examples/quickstart.py """ -import torch -import math import time -from quake import QuakeIndex, IndexBuildParams, SearchParams + +import torch + +from quake import IndexBuildParams, QuakeIndex, SearchParams from quake.datasets.ann_datasets import load_dataset from quake.utils import compute_recall + def main(): print("=== Quake Basic Example ===") @@ -40,7 +42,10 @@ def main(): build_params = IndexBuildParams() build_params.nlist = 1024 build_params.metric = "l2" - print("Building index with num_clusters=%d over %d vectors of dimension %d..." % (build_params.nlist, vectors.size(0), vectors.size(1))) + print( + "Building index with num_clusters=%d over %d vectors of dimension %d..." + % (build_params.nlist, vectors.size(0), vectors.size(1)) + ) start_time = time.time() index = QuakeIndex() index.build(vectors, ids, build_params) @@ -55,7 +60,10 @@ def main(): # or set a recall target # search_params.recall_target = 0.9 - print("Performing search of %d queries with k=%d and nprobe=%d..." % (queries.size(0), search_params.k, search_params.nprobe)) + print( + "Performing search of %d queries with k=%d and nprobe=%d..." + % (queries.size(0), search_params.k, search_params.nprobe) + ) start_time = time.time() search_result = index.search(queries, search_params) end_time = time.time() @@ -102,5 +110,6 @@ def main(): # index = QuakeIndex() # index.load("quake_index") -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/examples/workload_generator/basic_usage.py b/examples/workload_generator/basic_usage.py index 128c046a..fe6b8a93 100644 --- a/examples/workload_generator/basic_usage.py +++ b/examples/workload_generator/basic_usage.py @@ -13,8 +13,10 @@ import math from pathlib import Path -from quake.workload_generator import DynamicWorkloadGenerator, WorkloadEvaluator from quake.datasets.ann_datasets import load_dataset +from quake.index_wrappers.quake import QuakeWrapper +from quake.workload_generator import DynamicWorkloadGenerator, WorkloadEvaluator + def main(): # Directories for workload and evaluation output @@ -26,9 +28,9 @@ def main(): base_vectors, queries, _ = load_dataset("sift1m") # Workload generation parameters - insert_ratio = .9 + insert_ratio = 0.9 delete_ratio = 0.0 - query_ratio = .1 + query_ratio = 0.1 update_batch_size = 10000 query_batch_size = 10 number_of_operations = 1000 @@ -41,7 +43,7 @@ def main(): # Search parameters search_k = 10 - recall_target = .9 + recall_target = 0.9 # Create a DynamicWorkloadGenerator instance generator = DynamicWorkloadGenerator( @@ -59,32 +61,35 @@ def main(): cluster_sample_distribution=cluster_sample_distribution, queries=queries, query_cluster_sample_distribution=query_cluster_sample_distribution, - seed=seed + seed=seed, ) # Generate the workload (operations are saved to disk along with a runbook) print("Generating workload...") generator.generate_workload() - # Define an example index configuration. - index_cfg = { - "name": "Quake", - "build_params": { - "nc": int(math.sqrt(initial_size)) * 1, - } - } - # Create a WorkloadEvaluator instance and evaluate the workload - evaluator = WorkloadEvaluator( - workload_dir=workload_dir, - index_cfg=index_cfg, - output_dir=output_dir - ) + evaluator = WorkloadEvaluator(workload_dir=workload_dir, output_dir=output_dir) print("Evaluating workload...") - search_params = {"recall_target": recall_target, "k": search_k} - results = evaluator.evaluate_workload(search_params, do_maintenance=True) + + nc = 1000 + build_params = {"nc": nc, "metric": "l2"} + search_params = {"k": search_k, "recall_target": recall_target} + + index = QuakeWrapper() + + results = evaluator.evaluate_workload( + name="quake_test", + index=index, + build_params=build_params, + search_params=search_params, + do_maintenance=True, + ) + + print("Evaluation results:") + print(results) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..520eddf9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,52 @@ +[project] +name = "quake" +version = "0.0.1" +description = "Dynamic index for vector search" +readme = "README.md" +requires-python = ">=3.9" +authors = [ + { name = "Jason Mohoney", email = "mohoney2@wisc.edu" } +] +license = { file = "LICENSE" } + +dependencies = [ + "torch>=2.0", + "numpy", + "pandas", + "faiss-cpu", + "matplotlib" +] + +[tool.black] +line-length = 120 +target-version = ['py311'] +include = '\.pyi?$' +exclude = ''' +/( + | \.git + | \.mypy_cache + | \.tox + | \.venv + | _build + | build + | dist + | src/cpp +)/ +''' + +[tool.isort] +profile = "black" +skip = ["src/cpp/*", "build/*", "dist/*"] +line_length = 120 + +[project.optional-dependencies] +docs = [ + "sphinx", + "sphinx_rtd_theme", + "sphinxcontrib-mermaid", + "graphviz", + "pyyaml" +] +tests = [ + "pytest" +] \ No newline at end of file diff --git a/scripts/autoformat.sh b/scripts/autoformat.sh new file mode 100644 index 00000000..0898fa3d --- /dev/null +++ b/scripts/autoformat.sh @@ -0,0 +1,6 @@ +# This script will autoformat all the code in the project +# Run this script from the root of the project +echo "Running black..." +black . +echo "Running isort..." +isort . \ No newline at end of file diff --git a/scripts/check_lint.sh b/scripts/check_lint.sh new file mode 100644 index 00000000..5a8bed0b --- /dev/null +++ b/scripts/check_lint.sh @@ -0,0 +1,8 @@ +# This scripts checks the code for linting errors +# Run this script from the root of the project +echo "Running flake8..." +flake8 . +echo "Running black..." +black --check . +echo "Running isort..." +isort --check-only . \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 219a32ff..89aade7a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,16 +5,6 @@ description = Dynamic index for vector search long_description = file: README.md long_description_content_type = text/markdown -# Author information -author = Jason Mohoney -author_email = mohoney2@wisc.edu -maintainer = Jason Mohoney -maintainer_email = mohoney2@wisc.edu - -# License information -license = MIT -license_files = LICENSE - [options] install_requires = torch>=2.0 diff --git a/setup.py b/setup.py index ae808406..981380d8 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ import os import platform -import subprocess import shutil +import subprocess import sys from setuptools import Extension, setup @@ -57,10 +57,11 @@ def build_extension(self, ext): # if the gpu version of torch is installed, add the flag to the cmake args to enable the GPU build if torch.cuda.is_available(): cmake_args += ["-DQUAKE_ENABLE_GPU=ON"] - + else: + cmake_args += ["-DQUAKE_ENABLE_GPU=OFF", "-DTorch_NO_CUDA=ON", "-DTorch_USE_CUDA=OFF", "-DUSE_CUDA=OFF"] # check if numa is available try: - subprocess.check_output(["numactl", "--version"]) + subprocess.check_output(["numactl", "--show"]) cmake_args += ["-DQUAKE_USE_NUMA=ON"] except OSError: cmake_args += ["-DQUAKE_USE_NUMA=OFF"] @@ -92,7 +93,6 @@ def build_extension(self, ext): subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env) subprocess.check_call(["cmake", "--build", ".", "--target", "bindings"] + build_args, cwd=self.build_temp) - def generate_stubs(self): # Get the full path to the built extension. ext = self.extensions[0] @@ -113,11 +113,7 @@ def generate_stubs(self): os.makedirs(stub_output_dir, exist_ok=True) # Run stubgen on your module. - cmd = [ - "pybind11-stubgen", - "-o", stub_output_dir, - "quake._bindings" - ] + cmd = ["pybind11-stubgen", "-o", stub_output_dir, "quake._bindings"] subprocess.check_call(cmd, env=env) # The generated stub should be at /quake/_bindings.pyi. @@ -130,10 +126,11 @@ def generate_stubs(self): shutil.copyfile(generated_stub, dest) print(f"Stub file copied to {dest}") + setup( name="quake", version="0.0.1", ext_modules=[CMakeExtension("quake._bindings")], cmdclass={"build_ext": CMakeBuild}, zip_safe=False, -) \ No newline at end of file +) diff --git a/src/cpp/bindings/wrap.cpp b/src/cpp/bindings/wrap.cpp index 8a987a94..1ef561bd 100644 --- a/src/cpp/bindings/wrap.cpp +++ b/src/cpp/bindings/wrap.cpp @@ -208,16 +208,6 @@ PYBIND11_MODULE(_bindings, m) { (std::string("Delete threshold (ns). default = ") + std::to_string(DEFAULT_DELETE_THRESHOLD_NS)).c_str()) .def_readwrite("split_threshold_ns", &MaintenancePolicyParams::split_threshold_ns, (std::string("Split threshold (ns). default = ") + std::to_string(DEFAULT_SPLIT_THRESHOLD_NS)).c_str()) - .def_readwrite("k_large", &MaintenancePolicyParams::k_large, - (std::string("Large k value for maintenance. default = ") + std::to_string(DEFAULT_K_LARGE)).c_str()) - .def_readwrite("k_small", &MaintenancePolicyParams::k_small, - (std::string("Small k value for maintenance. default = ") + std::to_string(DEFAULT_K_SMALL)).c_str()) - .def_readwrite("modify_centroids", &MaintenancePolicyParams::modify_centroids, - (std::string("Flag to modify centroids during maintenance. default = ") + std::to_string(DEFAULT_MODIFY_CENTROIDS)).c_str()) - .def_readwrite("target_partition_size", &MaintenancePolicyParams::target_partition_size, - (std::string("Target partition size. default = ") + std::to_string(DEFAULT_TARGET_PARTITION_SIZE)).c_str()) - .def_readwrite("max_partition_ratio", &MaintenancePolicyParams::max_partition_ratio, - (std::string("Maximum allowed partition ratio. default = ") + std::to_string(DEFAULT_MAX_PARTITION_RATIO)).c_str()) .def("__repr__", [](const MaintenancePolicyParams &m) { std::ostringstream oss; oss << "{"; @@ -231,11 +221,6 @@ PYBIND11_MODULE(_bindings, m) { oss << "\"enable_delete_rejection\": " << (m.enable_delete_rejection ? "true" : "false") << ", "; oss << "\"delete_threshold_ns\": " << m.delete_threshold_ns << ", "; oss << "\"split_threshold_ns\": " << m.split_threshold_ns << ", "; - oss << "\"k_large\": " << m.k_large << ", "; - oss << "\"k_small\": " << m.k_small << ", "; - oss << "\"modify_centroids\": " << (m.modify_centroids ? "true" : "false") << ", "; - oss << "\"target_partition_size\": " << m.target_partition_size << ", "; - oss << "\"max_partition_ratio\": " << m.max_partition_ratio; oss << "}"; return oss.str(); }); diff --git a/src/cpp/include/clustering.h b/src/cpp/include/clustering.h index 4b2552f3..905fd68e 100644 --- a/src/cpp/include/clustering.h +++ b/src/cpp/include/clustering.h @@ -9,6 +9,20 @@ #include +class IndexPartition; + +/** + * @brief Clusters vectors into partitions using k-means. + * + * Uses the faiss::Clustering class to cluster vectors into n_clusters partitions. + * + * @param vectors The vectors to cluster. + * @param ids The IDs of the vectors. + * @param n_clusters The number of clusters to create. + * @param metric_type The metric type to use for clustering. + * @param niter The number of iterations to run k-means. + * @param initial_centroids The initial centroids to use for k-means. + */ shared_ptr kmeans(Tensor vectors, Tensor ids, int n_clusters, @@ -18,4 +32,24 @@ shared_ptr kmeans(Tensor vectors, Tensor initial_centroids = Tensor() ); + +/** + * @brief Refines partitions using k-means. + * + * Uses batched_scan_list to reassign each vector (from every partition) to its nearest + * centroid and then rebuilds the partitions and centroids. + * + * @param centroids The current centroids as an IndexPartition. + * @param index_partitions The current partitions. + * @param metric The metric type to use for clustering. + * @param refinement_iterations If 0, only reassign; otherwise, update centroids iteratively. + * + * @return A tuple with (updated centroids, new refined partitions) + */ +tuple>> kmeans_refine_partitions( + Tensor centroids, + vector> index_partitions, + MetricType metric, + int refinement_iterations = 0); + #endif //CLUSTERING_H diff --git a/src/cpp/include/common.h b/src/cpp/include/common.h index c17b02a6..e068a836 100644 --- a/src/cpp/include/common.h +++ b/src/cpp/include/common.h @@ -61,6 +61,10 @@ using std::make_shared; using std::size_t; using std::string; using std::chrono::high_resolution_clock; +using std::chrono::duration_cast; +using std::chrono::nanoseconds; +using std::chrono::microseconds; +using std::chrono::milliseconds; using faiss::idx_t; using faiss::MetricType; @@ -79,28 +83,27 @@ constexpr int DEFAULT_K = 1; ///< Default number of constexpr int DEFAULT_NPROBE = 1; ///< Default number of partitions to probe during search. constexpr float DEFAULT_RECALL_TARGET = -1.0f; ///< Default recall target (a negative value means no adaptive search). constexpr bool DEFAULT_BATCHED_SCAN = false; ///< Default flag for batched scanning. -constexpr bool DEFAULT_PRECOMPUTED = false; ///< Default flag to use precomputed incomplete beta fn for APS. -constexpr float DEFAULT_INITIAL_SEARCH_FRACTION = 0.2f; ///< Default initial fraction of partitions to search. -constexpr float DEFAULT_RECOMPUTE_THRESHOLD = 0.01f; ///< Default threshold to trigger recomputation of search parameters. +constexpr bool DEFAULT_PRECOMPUTED = true; ///< Default flag to use precomputed incomplete beta fn for APS. +constexpr float DEFAULT_INITIAL_SEARCH_FRACTION = 0.02f; ///< Default initial fraction of partitions to search. +constexpr float DEFAULT_RECOMPUTE_THRESHOLD = 0.001f; ///< Default threshold to trigger recomputation of search parameters. constexpr int DEFAULT_APS_FLUSH_PERIOD_US = 100; ///< Default period (in microseconds) for flushing the APS buffer. constexpr int DEFAULT_PRICE_THRESHOLD = INT_MAX; // Default constants for maintenance policy parameters constexpr const char* DEFAULT_MAINTENANCE_POLICY = "query_cost"; ///< Default maintenance policy type. constexpr int DEFAULT_WINDOW_SIZE = 1000; ///< Default window size for measuring hit rates. -constexpr int DEFAULT_REFINEMENT_RADIUS = 100; ///< Default radius for local partition refinement. +constexpr int DEFAULT_REFINEMENT_RADIUS = 25; ///< Default radius for local partition refinement. constexpr int DEFAULT_REFINEMENT_ITERATIONS = 3; ///< Default number of iterations for refinement. constexpr int DEFAULT_MIN_PARTITION_SIZE = 32; ///< Default minimum allowed partition size. constexpr float DEFAULT_ALPHA = 0.9f; ///< Default alpha parameter for maintenance. constexpr bool DEFAULT_ENABLE_SPLIT_REJECTION = true; ///< Default flag to enable rejection of splits. constexpr bool DEFAULT_ENABLE_DELETE_REJECTION = true; ///< Default flag to enable rejection of deletions. -constexpr float DEFAULT_DELETE_THRESHOLD_NS = 20.0f; ///< Default threshold in nanoseconds for deletion decisions. -constexpr float DEFAULT_SPLIT_THRESHOLD_NS = 20.0f; ///< Default threshold in nanoseconds for split decisions. -constexpr int DEFAULT_K_LARGE = 50; ///< Default "large" k value (used in de-drift maintenance). -constexpr int DEFAULT_K_SMALL = 50; ///< Default "small" k value (used in de-drift maintenance). -constexpr bool DEFAULT_MODIFY_CENTROIDS = true; ///< Default flag to modify centroids during maintenance. -constexpr int DEFAULT_TARGET_PARTITION_SIZE = 1000; ///< Default target partition size. -constexpr float DEFAULT_MAX_PARTITION_RATIO = 2.0f; ///< Default maximum allowed partition ratio. +constexpr float DEFAULT_DELETE_THRESHOLD_NS = 10.0f; ///< Default threshold in nanoseconds for deletion decisions. +constexpr float DEFAULT_SPLIT_THRESHOLD_NS = 10.0f; ///< Default threshold in nanoseconds for split decisions. + +const vector DEFAULT_LATENCY_ESTIMATOR_RANGE_N = {1, 2, 4, 16, 64, 256, 1024, 4096, 16384, 65536}; ///< Default range of n values for latency estimator. +const vector DEFAULT_LATENCY_ESTIMATOR_RANGE_K = {1, 4, 16, 64, 256}; ///< Default range of k values for latency estimator. +constexpr int DEFAULT_LATENCY_ESTIMATOR_NTRIALS = 5; ///< Default number of trials for latency estimator. // macros #define DEBUG_PRINT(x) std::cout << #x << " = " << x << std::endl; @@ -118,15 +121,6 @@ struct MaintenancePolicyParams { float delete_threshold_ns = DEFAULT_DELETE_THRESHOLD_NS; float split_threshold_ns = DEFAULT_SPLIT_THRESHOLD_NS; - // de-drift parameters - int k_large = DEFAULT_K_LARGE; - int k_small = DEFAULT_K_SMALL; - bool modify_centroids = DEFAULT_MODIFY_CENTROIDS; - - // lire parameters - int target_partition_size = DEFAULT_TARGET_PARTITION_SIZE; - float max_partition_ratio = DEFAULT_MAX_PARTITION_RATIO; - MaintenancePolicyParams() = default; }; @@ -190,7 +184,7 @@ struct SearchParams { int nprobe = DEFAULT_NPROBE; int k = DEFAULT_K; float recall_target = DEFAULT_RECALL_TARGET; - int num_threads = -1; // number of threads to use for search within a single worker + int num_threads = 1; // number of threads to use for search within a single worker float k_factor = 1.0f; bool use_precomputed = DEFAULT_PRECOMPUTED; bool batched_scan = DEFAULT_BATCHED_SCAN; diff --git a/src/cpp/include/dynamic_inverted_list.h b/src/cpp/include/dynamic_inverted_list.h index 2490cc68..a94d34c5 100644 --- a/src/cpp/include/dynamic_inverted_list.h +++ b/src/cpp/include/dynamic_inverted_list.h @@ -28,9 +28,9 @@ namespace faiss { int curr_list_id_ = 0; ///< Next available partition ID. int total_numa_nodes_ = 0; ///< Total NUMA nodes available. int next_numa_node_ = 0; ///< Next NUMA node to use (for round-robin allocation). - unordered_map> partitions_; ///< Map of partition ID to IndexPartition. int d_; ///< Dimensionality of the vectors (derived from code_size). int code_size_; ///< Size in bytes of each vector code. + unordered_map> partitions_; ///< Map of partition ID to IndexPartition. /** * @brief Constructor for DynamicInvertedLists. @@ -238,6 +238,14 @@ namespace faiss { */ bool get_vector_for_id(idx_t id, float* vector_values); + /** + * @brief No-copy retrieve vectors by their IDs. + * + * @param ids Vector of IDs to retrieve. + * @return Vector of pointers to the encoded vectors. + */ + vector get_vectors_by_id(vector ids); + /** * @brief Generate and return a new partition ID. * diff --git a/src/cpp/include/geometry.h b/src/cpp/include/geometry.h index f65268a5..e16eafc6 100644 --- a/src/cpp/include/geometry.h +++ b/src/cpp/include/geometry.h @@ -54,43 +54,51 @@ inline void print_array(const float *array, int dimension) { std::cout << std::endl << std::endl; } -inline Tensor compute_boundary_distances(const Tensor &query, const Tensor ¢roids, bool euclidean = true) { - Tensor nearest_centroid = centroids[0]; +inline vector compute_boundary_distances(const Tensor &query, vector centroids, bool euclidean = true) { + + auto start = std::chrono::high_resolution_clock::now(); int dimension = query.size(0); - std::vector boundary_distances(centroids.size(0), -1.0f); + std::vector boundary_distances(centroids.size(), -1.0f); const float *query_ptr = query.data_ptr(); - const float *nearest_centroid_ptr = nearest_centroid.data_ptr(); - const float *centroids_ptr = centroids.data_ptr(); + const float *nearest_centroid_ptr = centroids[0]; - std::vector line_vector(dimension); - std::vector midpoint(dimension); - std::vector projection(dimension); + vector line_vector(dimension); + vector midpoint(dimension); + vector residual(dimension); + + auto end = std::chrono::high_resolution_clock::now(); // used for euclidean distance if (euclidean) { - Tensor residual = query - nearest_centroid; - const float *residual_ptr = residual.data_ptr(); - for (int j = 1; j < centroids.size(0); j++) { - subtract_arrays(centroids_ptr + (dimension * j), nearest_centroid_ptr, line_vector.data(), dimension); - divide_array_by_constant(line_vector.data(), 2.0f, midpoint.data(), dimension); - - float norm = std::sqrt(faiss::fvec_inner_product(line_vector.data(), line_vector.data(), dimension)); - divide_array_by_constant(line_vector.data(), norm, line_vector.data(), dimension); - - float projected_distance = faiss::fvec_inner_product(residual_ptr, line_vector.data(), dimension); - multiply_array_by_constant(line_vector.data(), projected_distance, projection.data(), dimension); - - float distance_to_boundary_squared = faiss::fvec_L2sqr(midpoint.data(), projection.data(), dimension); - boundary_distances[j] = std::sqrt(distance_to_boundary_squared); + // Compute residual: r = q - c0. + faiss::fvec_sub(dimension, query_ptr, nearest_centroid_ptr, residual.data()); + + // For each centroid j (starting at index 1). + for (int j = 1; j < centroids.size(); j++) { + // Compute v = c_j - c0. + const float* c_j = centroids[j]; + faiss::fvec_sub(dimension, c_j, nearest_centroid_ptr, line_vector.data()); + + // Compute squared norm: A2 = ||v||^2. + float A2 = faiss::fvec_inner_product(line_vector.data(), line_vector.data(), dimension); + float A = std::sqrt(A2); // Guaranteed nonzero. + + // Compute dot product: dot = . + float dot_val = faiss::fvec_inner_product(residual.data(), line_vector.data(), dimension); + + // Instead of computing dot_val/A and 0.5*A separately, + // we compute: d = |dot_val - 0.5 * A2| / A. + float d = std::fabs(dot_val - 0.5f * A2) / A; + boundary_distances[j] = d; } } else { // for dot product distance float residual_angle = faiss::fvec_inner_product(query_ptr, nearest_centroid_ptr, dimension); - for (int j = 1; j < centroids.size(0); j++) { + for (int j = 1; j < centroids.size(); j++) { // get angle of the bisector using dot product - subtract_arrays(centroids_ptr + (dimension * j), nearest_centroid_ptr, line_vector.data(), dimension); + subtract_arrays(centroids[j], nearest_centroid_ptr, line_vector.data(), dimension); divide_array_by_constant(line_vector.data(), 2.0f, midpoint.data(), dimension); add_arrays(nearest_centroid_ptr, midpoint.data(), midpoint.data(), dimension); float norm = faiss::fvec_inner_product(midpoint.data(), midpoint.data(), dimension); @@ -101,7 +109,7 @@ inline Tensor compute_boundary_distances(const Tensor &query, const Tensor ¢ } } - return torch::tensor(boundary_distances).clone(); + return boundary_distances; } inline double incomplete_beta(double a, double b, double x) { @@ -334,24 +342,23 @@ inline Tensor compute_variance_in_direction_of_query(Tensor query, Tensor centro return torch::tensor(variances).clone(); } -inline Tensor compute_recall_profile(const Tensor &boundary_distances, float query_radius, int dimension, - const Tensor &partition_sizes = {}, bool use_precomputed = true, +inline vector compute_recall_profile(vector boundary_distances, float query_radius, int dimension, + vector partition_sizes = {}, bool use_precomputed = true, bool euclidean = true) { // boundary_distances shape is (num_partitions,) and num_partitions must be greater than 1 - if (boundary_distances.size(0) < 2) { + if (boundary_distances.size() < 2) { throw std::runtime_error("Boundary distances must have at least 2 partitions to create an estimate."); } - auto boundary_distances_ptr = boundary_distances.data_ptr(); - int num_partitions = boundary_distances.size(0); - std::vector partition_probabilities(num_partitions, 0.0f); + int num_partitions = boundary_distances.size(); + vector partition_probabilities(num_partitions, 0.0f); double total_volume = 0.0; - bool weigh_using_partition_sizes = partition_sizes.defined(); + bool weigh_using_partition_sizes = partition_sizes.size() == num_partitions; for (int j = 1; j < num_partitions; j++) { - float boundary_distance = boundary_distances_ptr[j]; + float boundary_distance = boundary_distances[j]; if (boundary_distance >= query_radius) { partition_probabilities[j] = 0.0; @@ -372,28 +379,31 @@ inline Tensor compute_recall_profile(const Tensor &boundary_distances, float que partition_probabilities[0] = 2.0 * partition_probabilities[1]; // partition_probabilities[0] = 1 - partition_probabilities[1]; - Tensor probabilities_tensor = torch::from_blob(partition_probabilities.data(), - {num_partitions}, - torch::kDouble). - clone(); - - if (weigh_using_partition_sizes) { - probabilities_tensor *= partition_sizes; - } + // if (weigh_using_partition_sizes) { + // for (int j = 0; j < num_partitions; j++) { + // partition_probabilities[j] *= partition_sizes[j]; + // } + // } // Ensure the probabilities sum to 1 - double sum_probabilities = probabilities_tensor.sum().item(); + double sum_probabilities = 0.0; + for (int j = 0; j < num_partitions; j++) { + sum_probabilities += partition_probabilities[j]; + } if (sum_probabilities > 0.0f) { - probabilities_tensor /= sum_probabilities; + for (int j = 0; j < num_partitions; j++) { + partition_probabilities[j] /= sum_probabilities; + } } else { - probabilities_tensor.fill_(0.0); - probabilities_tensor[0] = 1.0; + for (int j = 0; j < num_partitions; j++) { + partition_probabilities[j] = 1.0 / num_partitions; + } } // Compute the recall profile // Tensor recall_profile = torch::cumsum(probabilities_tensor, 0); - return probabilities_tensor; + return partition_probabilities; } inline float compute_intersection_volume_one(float boundary_distance, float query_radius, int dimension) { diff --git a/src/cpp/include/hit_count_tracker.h b/src/cpp/include/hit_count_tracker.h new file mode 100644 index 00000000..c12bd62a --- /dev/null +++ b/src/cpp/include/hit_count_tracker.h @@ -0,0 +1,114 @@ +// +// Created by Jason on 3/13/25. +// Prompt for GitHub Copilot: +// - Conform to the google style guide +// - Use descriptive variable names + +#ifndef HIT_COUNT_TRACKER_H +#define HIT_COUNT_TRACKER_H + +#include +#include +#include +#include +#include +#include + + +/** + * @brief HitCountTracker maintains per-query hit counts and scanned partition sizes in a sliding window. + **/ +class HitCountTracker { +public: + /** + * @brief Constructs a HitCountTracker. + * + * @param window_size Number of queries to maintain in the sliding window. + * @param total_vectors Total number of vectors in the index (used for computing the scan fraction). + */ + HitCountTracker(int window_size, int total_vectors); + + /** + * @brief Resets the tracker by clearing all recorded query data. + */ + void reset(); + + /** + * @brief Sets the total number of vectors in the index. + * + * @param total_vectors The new total vector count. + */ + void set_total_vectors(int total_vectors); + + /** + * @brief Adds per-query data. + * + * Records the partition IDs that were hit and their corresponding scanned sizes for a query. + * Both vectors must have the same length. + * + * @param hit_partition_ids Vector of partition IDs hit during the query. + * @param scanned_sizes Vector of scanned sizes corresponding to each partition hit. + */ + void add_query_data(const vector& hit_partition_ids, const vector& scanned_sizes); + + /** + * @brief Retrieves the current scan fraction averaged over the sliding window. + * + * @return The current scan fraction. + */ + float get_current_scan_fraction() const; + + /** + * @brief Retrieves the stored per-query hit counts. + * + * @return A constant reference to the vector containing per-query hit counts. + */ + const vector>& get_per_query_hits() const; + + /** + * @brief Retrieves the stored per-query scanned partition sizes. + * + * @return A constant reference to the vector containing per-query scanned sizes. + */ + const vector>& get_per_query_scanned_sizes() const; + + /** + * @brief Returns the sliding window size. + * + * @return The window size. + */ + int get_window_size() const; + + + /** + * @brief Returns the total number of queries recorded so far. + * + * @return The number of queries recorded. + */ + int64_t get_num_queries_recorded() const; + +private: + int window_size_; + int64_t total_vectors_; + int64_t curr_query_index_; // Points to the next slot to overwrite in the circular window. + int64_t num_queries_recorded_; // Total queries recorded so far (up to window_size_) + + vector> per_query_hits_; + vector> per_query_scanned_sizes_; + + // Running sum of the scan fractions in the current window. + float running_sum_scan_fraction_; + float current_scan_fraction_; + + /** + * @brief Computes the scan fraction for a query. + * + * The scan fraction is calculated as the sum of the scanned sizes divided by the total number of vectors. + * + * @param scanned_sizes Vector of scanned partition sizes for one query. + * @return The computed scan fraction. + */ + float compute_scan_fraction(const vector& scanned_sizes) const; +}; + +#endif // HIT_COUNT_TRACKER_H diff --git a/src/cpp/include/index_partition.h b/src/cpp/include/index_partition.h index e67988a6..98e79c65 100644 --- a/src/cpp/include/index_partition.h +++ b/src/cpp/include/index_partition.h @@ -19,7 +19,7 @@ class IndexPartition { public: int numa_node_ = -1; ///< Assigned NUMA node (-1 if not set) - int thread_id_ = -1; ///< Mapped thread ID for processing + int core_id_ = -1; ///< Mapped thread ID for processing int64_t buffer_size_ = 0; ///< Allocated capacity (in number of vectors) int64_t num_vectors_ = 0; ///< Current number of stored vectors @@ -29,6 +29,8 @@ class IndexPartition { idx_t* ids_ = nullptr; ///< Pointer to the vector IDs std::shared_ptr attributes_table_ = {}; + std::unordered_map id_to_index_; ///< Map of vector ID to index + /// Default constructor. IndexPartition() = default; @@ -155,6 +157,8 @@ class IndexPartition { */ void reallocate_memory(int64_t new_capacity); + void set_core_id(int core_id); + #ifdef QUAKE_USE_NUMA /** * @brief Set the NUMA node for the partition. diff --git a/src/cpp/include/latency_estimation.h b/src/cpp/include/latency_estimation.h deleted file mode 100644 index 2ed9b487..00000000 --- a/src/cpp/include/latency_estimation.h +++ /dev/null @@ -1,77 +0,0 @@ -// -// Created by Jason on 12/16/24. -// Prompt for GitHub Copilot: -// - Conform to the google style guide -// - Use descriptive variable names - -#ifndef LATENCY_ESTIMATION_H -#define LATENCY_ESTIMATION_H - -#include -#include - -// 2D function that estimates the scan latency of a list given it's size and the number of elements to retrieve -// l(n, k) = latency -// function is a linear interpolation of measured scan latency for different list sizes and k -// for points beyond the measured range, the function is extrapolated using the slope of the last two points -// default grid: n = [16, 64, 256, 1024, 4096, 16384, 65536, 262144, 1048576], k = [1, 4, 16, 64, 256, 1024] -class ListScanLatencyEstimator { -public: - // Constructor attempts to load latency profile from disk if the file path is - // provided. If loading fails (file not found or grid mismatch), it performs - // the profile_scan_latency() and saves to file. - ListScanLatencyEstimator(int d, - const std::vector &n_values, - const std::vector &k_values, - int n_trials = 100, - bool adaptive_nprobe = false, - const std::string &profile_filename = ""); - - // Profiles the scan latency and populates scan_latency_model_. - // This is expensive and should typically be called only once. - void profile_scan_latency(); - - // Estimates the scan latency for given n and k. - float estimate_scan_latency(int n, int k) const; - - // Setter for n_trials_. - void set_n_trials(int n_trials) { - n_trials_ = n_trials; - } - - // Saves the internally profiled latency model to a CSV file. - // Returns true on success, false otherwise. - bool save_latency_profile(const std::string &filename) const; - - // Loads an existing latency profile from a CSV file. - // Returns true on success, false otherwise. - bool load_latency_profile(const std::string &filename); - - // Public members for convenience/access. - int d_; - std::vector n_values_; - std::vector k_values_; - std::vector > scan_latency_model_; - int n_trials_; - -private: - // Helper function for interpolation (not used directly in this code, but - // shown as an example of how you might handle it). - bool get_interpolation_info(const std::vector &values, - int target, - int &lower, - int &upper, - float &frac) const; - - // Helper function to do linear extrapolation in 1D. - inline float linear_extrapolate(float f1, float f2, float fraction) const { - float slope = f2 - f1; - return f2 + slope * fraction; - } - - // The name of the CSV file to load/save from. Empty means "don't load/save." - std::string profile_filename_; -}; - - -#endif //LATENCY_ESTIMATION_H diff --git a/src/cpp/include/list_scanning.h b/src/cpp/include/list_scanning.h index 53d54dfe..2e9b130e 100644 --- a/src/cpp/include/list_scanning.h +++ b/src/cpp/include/list_scanning.h @@ -111,6 +111,7 @@ class TypedTopKBuffer { topk_[i] = { std::numeric_limits::max(), -1 }; } } + partitions_scanned_.store(0, std::memory_order_relaxed); } void add(DistanceType distance, IdType index) { @@ -236,7 +237,7 @@ inline std::tuple buffers_to_tensor(vector> create_buffers(int n, int k, bool is_descending) { vector> buffers(n); for (int i = 0; i < n; i++) { - buffers[i] = make_shared(k, is_descending, 4 * k); + buffers[i] = make_shared(k, is_descending, 10 * k); } return buffers; } diff --git a/src/cpp/include/maintenance_cost_estimator.h b/src/cpp/include/maintenance_cost_estimator.h new file mode 100644 index 00000000..abebdd53 --- /dev/null +++ b/src/cpp/include/maintenance_cost_estimator.h @@ -0,0 +1,219 @@ +// +// Created by Jason on 3/13/25. +// Prompt for GitHub Copilot: +// - Conform to the google style guide +// - Use descriptive variable names + +#ifndef MAINTENANCE_COST_ESTIMATOR_H +#define MAINTENANCE_COST_ESTIMATOR_H + +#include +#include +#include + +using std::vector; +using std::shared_ptr; + +/** + * @brief Estimates the scan latency for a list based on its size and the number of elements to retrieve. + * + * The latency function l(n, k) is determined via linear interpolation of measured scan latencies for + * different list sizes (n) and retrieval counts (k). For points outside the measured grid, the function + * extrapolates using the slope between the last two points. + */ +class ListScanLatencyEstimator { +public: + /** + * @brief Constructor. + * + * Attempts to load a latency profile from disk if a file path is provided. If loading fails + * (e.g. file not found or grid mismatch), the profile_scan_latency() method will be invoked + * to compute and save the latency profile. + * + * @param d Dimension of the vectors. + * @param n_values Vector of n-values for the grid. + * @param k_values Vector of k-values for the grid. + * @param n_trials Number of trials used for profiling (default: 100). + * @param adaptive_nprobe Flag indicating whether to use adaptive nprobe (default: false). + * @param profile_filename Optional CSV file name for loading/saving the profile (default: empty string). + */ + ListScanLatencyEstimator(int d, + const std::vector &n_values, + const std::vector &k_values, + int n_trials = 100, + bool adaptive_nprobe = false, + const std::string &profile_filename = ""); + + /** + * @brief Profiles the scan latency and populates the latency model. + * + * This operation is expensive and should typically be executed only once. + */ + void profile_scan_latency(); + + /** + * @brief Estimates the scan latency for a given list size and retrieval count. + * + * @param n List size. + * @param k Number of elements to retrieve. + * @return Estimated latency as a float. + */ + float estimate_scan_latency(int n, int k) const; + + + /** + * @brief Sets the number of trials to use for latency estimation. + * + * @param n_trials New number of trials. + */ + void set_n_trials(int n_trials) { + n_trials_ = n_trials; + } + + /** + * @brief Saves the internally profiled latency model to a CSV file. + * + * @param filename File path for saving the latency profile. + * @return True if saving is successful; false otherwise. + */ + bool save_latency_profile(const std::string &filename) const; + + /** + * @brief Loads an existing latency profile from a CSV file. + * + * @param filename File path to load the latency profile from. + * @return True if loading is successful; false otherwise. + */ + bool load_latency_profile(const std::string &filename); + + // Public members for convenience/access. + int d_; + std::vector n_values_; + std::vector k_values_; + std::vector > scan_latency_model_; + int n_trials_; + +private: + /** + * @brief Helper function for interpolation. + * + * This function is provided as an example of how you might implement interpolation logic. + * + * @param values The grid values. + * @param target The target value. + * @param lower (Output) Lower index. + * @param upper (Output) Upper index. + * @param frac (Output) Fractional interpolation value. + * @return True if successful; false otherwise. + */ + bool get_interpolation_info(const std::vector &values, + int target, + int &lower, + int &upper, + float &frac) const; + + /** + * @brief Performs linear extrapolation between two float values. + * + * @param f1 First value. + * @param f2 Second value. + * @param fraction Fractional distance between f1 and f2. + * @return Extrapolated value. + */ + inline float linear_extrapolate(float f1, float f2, float fraction) const { + float slope = f2 - f1; + return f2 + slope * fraction; + } + + /// @brief CSV file name for loading/saving the latency profile. + /// An empty string means no file I/O will be attempted. + std::string profile_filename_; +}; + + +/** + * @brief Computes cost deltas for maintenance actions (e.g., splitting or deleting partitions) + * using a latency estimation model. + * + * The MaintenanceCostEstimator uses a ListScanLatencyEstimator along with parameters such as alpha + * and k to compute the difference in latency (cost) that would result from a maintenance operation. + */ +class MaintenanceCostEstimator { +public: + /** + * @brief Constructor. + * + * @param d Dimension of the vectors. + * @param alpha Alpha parameter used to scale the cost for splitting. + * @param k Parameter used in latency estimation. + * @throws std::invalid_argument if k is non-positive or alpha is non-positive. + */ + MaintenanceCostEstimator(int d, float alpha, int k); + + /** + * @brief Computes the delta cost for splitting a partition. + * + * The computed delta represents the difference between the new cost after splitting + * (assuming an even split) and the original cost, plus the structural overhead of adding one partition. + * + * @param partition_size Size of the partition to split. + * @param hit_rate Hit rate (fraction) for the partition. + * @param total_partitions Total number of partitions before the split. + * @return The computed split delta. + */ + float compute_split_delta(int partition_size, float hit_rate, int total_partitions) const; + + /** + * @brief Computes the delta cost for deleting a partition. + * + * This function estimates the change in latency if a partition were deleted and its + * vectors redistributed across the remaining partitions. + * + * @param partition_size Size of the partition to delete. + * @param hit_rate Hit rate (fraction) for the partition. + * @param total_partitions Total number of partitions before deletion. + * @param avg_partition_hit_rate Average hit rate across all partitions. + * @param avg_partition_size Average partition size. + * @return The computed delete delta. + */ + float compute_delete_delta(int partition_size, float hit_rate, int total_partitions, float avg_partition_hit_rate, float avg_partition_size) const; + + /** + * @brief Computes the delete delta using reassignment information. + * + * This version considers the cost impact of reassigning vectors from the deleted partition + * to other partitions. + * + * @param partition_size Size of the partition to delete. + * @param hit_rate Hit rate (fraction) for the partition. + * @param total_partitions Total number of partitions before deletion. + * @param reassign_counts Vector containing the number of vectors reassigned to each partition. + * @param reassign_sizes Vector containing the sizes of the partitions to which vectors are reassigned. + * @param reassign_hit_rates Vector containing the hit rates of the partitions to which vectors are reassigned. + * @return The computed delete delta. + */ + float compute_delete_delta_w_reassign(int partition_size, float hit_rate, int total_partitions, const vector &reassign_counts, const vector &reassign_sizes, const vector &reassign_hit_rates) const; + + /** + * @brief Returns the latency estimator. + * + * @return A shared pointer to the ListScanLatencyEstimator. + */ + shared_ptr get_latency_estimator() const; + + + /** + * @brief Returns the parameter k used in latency estimation. + * + * @return The k value. + */ + int get_k() const; + +private: + float alpha_; + int k_; + int d_; + shared_ptr latency_estimator_; +}; + +#endif // MAINTENANCE_COST_ESTIMATOR_H diff --git a/src/cpp/include/maintenance_policies.h b/src/cpp/include/maintenance_policies.h index 447b7d75..6ee68075 100644 --- a/src/cpp/include/maintenance_policies.h +++ b/src/cpp/include/maintenance_policies.h @@ -1,146 +1,63 @@ -// -// Created by Jason on 9/20/24. -// Prompt for GitHub Copilot: -// - Conform to the google style guide -// - Use descriptive variable names - -#ifndef MAINTENANCE_POLICIES_H -#define MAINTENANCE_POLICIES_H - -#include -#include -#include - -class PartitionManager; - -struct PartitionState { - vector partition_ids; - vector partition_sizes; - vector partition_hit_rate; - float current_scan_fraction; - float alpha_estimate; -}; - +#ifndef MAINTENANCE_POLICY_REFACTORED_H +#define MAINTENANCE_POLICY_REFACTORED_H + +#include +#include +#include +#include + +#include "partition_manager.h" +#include "hit_count_tracker.h" +#include "maintenance_cost_estimator.h" + +/** + * @brief Maintenance policy that manages partition hit counts and + * performs maintenance operations (such as deletion and splitting) in a single pass. + * + */ class MaintenancePolicy { -public: - int curr_query_id_; - vector > per_query_hits_; - vector > per_query_scanned_partitions_sizes_; - float running_sum_scan_fraction_; - float current_scan_fraction_; - std::string maintenance_policy_name_; - - std::shared_ptr partition_manager_; - std::unordered_map per_partition_hits_; - std::unordered_map ancestor_partition_hits_; - std::unordered_map> split_records_; - - std::unordered_map deleted_partition_hit_rate_; - - std::unordered_set modified_partitions_; - - // parameters - int window_size_ = 2500; - int refinement_radius_ = 25; - int refinement_iterations_ = 3; - int min_partition_size_ = 32; - float alpha_ = .9; - bool enable_split_rejection_ = true; - bool enable_delete_rejection_ = true; - float delete_threshold_ns_ = 20.0; - float split_threshold_ns_ = 20.0; - bool debug_ = false; - - // latency estimator - std::vector latency_grid_n_values_ = {1, 2, 4, 16, 64, 256, 1024, 4096, 16384, 65536}; - std::vector latency_grid_k_values_ = {1, 4, 16, 64, 256}; - int n_trials_ = 5; - - std::shared_ptr latency_estimator_ = nullptr; - std::shared_ptr latency_estimator_adaptive_nprobe = nullptr; - - vector> get_split_history(); - - shared_ptr get_partition_state(bool only_modified = true); - - void set_partition_modified(int64_t partition_id); - - void set_partition_unmodified(int64_t partition_id); - - vector estimate_split_delta(shared_ptr state); - - vector estimate_delete_delta(shared_ptr state); - - float estimate_add_level_delta(); - - float estimate_remove_level_delta(); - - void decrement_hit_count(int64_t partition_id); - - void increment_hit_count(vector hit_partition_ids); - - void refine_partitions(Tensor partition_ids, int refinement_iterations); - - virtual void local_refinement(Tensor partition_ids, int refinement_radius) {} - - virtual Tensor check_and_delete_partitions() { return {}; } - - virtual std::tuple check_and_split_partitions() { return {}; } - - virtual shared_ptr maintenance(); - - void add_split(int64_t old_partition_id, int64_t left_partition_id, int64_t right_partition_id); - - void add_partition(int64_t partition_id, int64_t hits); - - void remove_partition(int64_t partition_id); -}; - -class QueryCostMaintenance : public MaintenancePolicy { -public: - QueryCostMaintenance(std::shared_ptr partition_manager, shared_ptr params = nullptr); - - float compute_alpha_for_window(); - - void local_refinement(Tensor partition_ids, int refinement_radius) override; - - Tensor check_and_delete_partitions() override; - - std::tuple check_and_split_partitions() override; + public: + /** + * @brief Construct a new MaintenancePolicy object. + * + * @param partition_manager Shared pointer to the partition manager. + * @param params Configuration parameters for the maintenance policy. + */ + MaintenancePolicy( + shared_ptr partition_manager, + shared_ptr params); + + /** + * @brief Perform maintenance operations including deletion and splitting. + * + * @return MaintenanceTimingInfo with timing details. + */ + shared_ptr perform_maintenance(); + + /** + * @brief Record a hit event for a given partition. + * + * @param partition_id Identifier of the partition. + */ + void record_query_hits(vector partition_ids); + + /** + * @brief Reset the internal maintenance state. + */ + void reset(); + + private: + shared_ptr partition_manager_; ///< Manages partition state. + shared_ptr params_; ///< Maintenance parameters. + shared_ptr cost_estimator_; ///< Cost estimator for maintenance actions. + shared_ptr hit_count_tracker_; ///< Hit count tracker for partition hit rates. + + /** + * @brief Perform local refinement on a set of partition IDs. + * + * @param partition_ids Tensor of partition IDs. + */ + void local_refinement(const Tensor& partition_ids); }; -class LireMaintenance : public MaintenancePolicy { -public: - int target_partition_size_; - float max_partition_ratio_; - int min_partition_size_; - - LireMaintenance(std::shared_ptr partition_manager, int target_partition_size, float max_partition_ratio, int min_partition_size) : target_partition_size_(target_partition_size), max_partition_ratio_(max_partition_ratio), min_partition_size_(min_partition_size) { - maintenance_policy_name_ = "lire"; - partition_manager_ = partition_manager; - refinement_iterations_ = 0; - } - - Tensor check_and_delete_partitions() override; - - std::tuple check_and_split_partitions() override; -}; - - -class DeDriftMaintenance : public MaintenancePolicy { -public: - int k_large_; - int k_small_; - bool modify_centroids_; - - DeDriftMaintenance(std::shared_ptr partition_manager, int k_large, int k_small, bool modify_centroids) : k_large_(k_large), k_small_(k_small), modify_centroids_(modify_centroids) { - maintenance_policy_name_ = "dedrift"; - partition_manager_ = partition_manager; - } - - shared_ptr maintenance() override; -}; - - - -#endif //MAINTENANCE_POLICIES_H +#endif // MAINTENANCE_POLICY_REFACTORED_H \ No newline at end of file diff --git a/src/cpp/include/parallel.h b/src/cpp/include/parallel.h index 2ab6a7f6..15f9c30d 100644 --- a/src/cpp/include/parallel.h +++ b/src/cpp/include/parallel.h @@ -9,7 +9,34 @@ #include #include -#include + +#ifdef __linux__ +#include +#include +#include +#include + + +inline bool set_affinity_linux(int core_id) { + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(core_id, &cpuset); + int ret = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); + return ret == 0; +} + +#endif + +inline bool set_thread_affinity(int core_id) { +#ifdef __APPLE__ + return false; // Not supported on macOS +#elif defined(__linux__) + return set_affinity_linux(core_id); +#else + std::cerr << "Platform not supported for setting thread affinity" << std::endl; + return false; +#endif +} template void parallel_for(IndexType start, IndexType end, Function func, int num_threads = -1) { diff --git a/src/cpp/include/partition_manager.h b/src/cpp/include/partition_manager.h index f4ba0d80..9764aeb8 100644 --- a/src/cpp/include/partition_manager.h +++ b/src/cpp/include/partition_manager.h @@ -25,7 +25,7 @@ class QuakeIndex; class PartitionManager { public: shared_ptr parent_ = nullptr; ///< Pointer to a higher-level parent index. - std::shared_ptr partitions_ = nullptr; ///< Pointer to the inverted lists. + std::shared_ptr partition_store_ = nullptr; ///< Pointer to the inverted lists. int64_t curr_partition_id_ = 0; ///< Current partition ID. bool debug_ = false; ///< If true, print debug information. @@ -79,9 +79,15 @@ class PartitionManager { */ Tensor get(const Tensor &ids); + /** + * @brief No copy version of get + * @param ids Vector of IDs. + */ + vector get_vectors(vector ids); + /** * @brief Split a given partition into multiple smaller ones. - * @param partition_id The ID/index of the partition to split. + * @param partition_ids The partition IDs to split. */ shared_ptr split_partitions(const Tensor &partition_ids); @@ -90,7 +96,7 @@ class PartitionManager { * @param partition_ids Tensor of shape [num_partitions] containing partition IDs. If empty, refines all partitions. * @param refinement_iterations Number of refinement iterations. If 0, then only reassigns vectors. */ - void refine_partitions(const Tensor &partition_ids = Tensor(), int refinement_iterations = 0); + void refine_partitions(Tensor partition_ids = Tensor(), int refinement_iterations = 0); /** * @brief Delete multiple partitions and reassign vectors @@ -112,18 +118,24 @@ class PartitionManager { */ shared_ptr select_partitions(const Tensor &partition_ids, bool copy = false); - /** - * @brief Randomly breaks up the single partition into multiple partitions and distributes the partitions. Only applicable for flat indexes. - * @param n_partitions The number of partitions to split the single partition into. - */ - void distribute_flat(int n_partitions); - /** * @brief Distribute the partitions across multiple workers. * @param num_workers The number of workers to distribute the partitions across. */ void distribute_partitions(int num_workers); + /** + * @brief Set the core ID for a given partition. + * @param partition_id The ID of the partition. + */ + void set_partition_core_id(int64_t partition_id, int core_id); + + /** + * @brief Return the core ID for a given partition. + * @param partition_id The ID of the partition. + */ + int get_partition_core_id(int64_t partition_id); + /** * @brief Return total number of vectors across all partitions. */ @@ -145,6 +157,18 @@ class PartitionManager { */ Tensor get_partition_sizes(Tensor partition_ids = Tensor()); + /** + * @brief Get the partition size. + * @param partition_ids Vector of partition IDs. + */ + vector get_partition_sizes(vector partition_ids); + + /** + * @brief Get the partition size. + * @param partition_id The ID of the partition. + */ + int64_t get_partition_size(int64_t partition_id); + /** * @brief Get the partition IDs. */ diff --git a/src/cpp/include/query_coordinator.h b/src/cpp/include/query_coordinator.h index ebe6ee02..a3ae0dd5 100644 --- a/src/cpp/include/query_coordinator.h +++ b/src/cpp/include/query_coordinator.h @@ -15,62 +15,181 @@ class QuakeIndex; class PartitionManager; +/** + * @brief Structure representing a scan job. + * + * A ScanJob encapsulates all parameters required to perform a scan on a given index partition. + */ struct ScanJob { - int64_t partition_id; - int k; - const float* query_vector; // Pointer to the query vector - vector query_ids; - - bool is_batched = false; // false = single-query job, true = multi-query job - int64_t num_queries = 0; // number of queries in batched mode + int64_t partition_id; ///< The identifier of the partition to be scanned. + int k; ///< The number of neighbors (Top-K) to return. + const float* query_vector; ///< Pointer to the query vector. + vector query_ids; ///< Global query IDs; used in batched mode. + bool is_batched = false; ///< Indicates whether this is a batched query job. + int64_t num_queries = 0; ///< The number of queries in batched mode. + int rank = 0; ///< Rank of the partition }; +/** + * @brief The QueryCoordinator class. + * + * Distributes query scanning work across worker threads, aggregates results, + * and supports both parallel and serial scan modes. + */ class QueryCoordinator { public: - shared_ptr partition_manager_; - shared_ptr maintenance_policy_; - shared_ptr parent_; - MetricType metric_; - - vector worker_threads_; - int num_workers_; - bool workers_initialized_ = false; - vector> jobs_queue_; - std::unordered_map jobs_; - - // Top-K Buffers - vector> query_topk_buffers_; - - // Synchronization - std::mutex result_mutex_; - std::atomic stop_workers_; - - bool debug_ = false; + // Public member variables (for internal use) + shared_ptr partition_manager_; ///< Manager for partition assignments. + shared_ptr maintenance_policy_; ///< Policy for index maintenance. + shared_ptr parent_; ///< Pointer to the parent index. + MetricType metric_; ///< Distance metric for search queries. + + /** + * @brief Structure representing per-core resources. + * + * Each core maintains its own pool of Top‑K buffers, a local query buffer, and a dedicated job queue. + */ + struct CoreResources { + int core_id; ///< Logical identifier of the core. + vector> topk_buffer_pool; ///< Preallocated Top‑K buffers. + vector local_query_buffer; ///< Local aggregator for query results. + moodycamel::BlockingConcurrentQueue job_queue; ///< Job queue for scan jobs. + }; + vector core_resources_; ///< Per‑core resources for worker threads. + bool workers_initialized_ = false; ///< Flag indicating if worker threads are initialized. + int num_workers_; ///< Total number of worker threads. + vector worker_threads_; ///< Container for worker threads. + vector worker_job_counter_; ///< Job counters for each worker. + vector> global_topk_buffer_pool_; ///< Global aggregator buffers. + std::mutex global_mutex_; ///< Mutex for global synchronization. + std::condition_variable global_cv_; ///< Condition variable for thread coordination. + std::atomic stop_workers_; ///< Flag to signal workers to terminate. + bool debug_ = false; ///< Debug mode flag. + + vector>> job_flags_; ///< Flags to track job completion + std::atomic job_pull_time_ns = 0; ///< Time spent pulling jobs from the queue. + std::atomic job_process_time_ns = 0; ///< Time spent processing jobs. + + + /** + * @brief Constructs a QueryCoordinator. + * + * @param parent Shared pointer to the parent QuakeIndex. + * @param partition_manager Shared pointer to the PartitionManager. + * @param maintenance_policy Shared pointer to the MaintenancePolicy. + * @param metric Distance metric used in search operations. + * @param num_workers Number of worker threads to initialize (default is 0, where 0 means no parallelism). + */ QueryCoordinator(shared_ptr parent, shared_ptr partition_manager, shared_ptr maintenance_policy, MetricType metric, int num_workers=0); + /** + * @brief Destructor for QueryCoordinator. + * + * Cleans up resources and shuts down worker threads. + */ ~QueryCoordinator(); + /** + * @brief Initiates a search operation. + * + * Searches the parent first to determine the partitions to scan. Then calls scan_partitions to perform the scan. + * + * @param x Tensor containing the query vector(s). + * @param search_params Shared pointer to search parameters. + * @return Shared pointer to the final SearchResult. + */ shared_ptr search(Tensor x, shared_ptr search_params); + /** + * @brief Performs a scan on the specified partitions. + * + * Selects the appropriate scan method based on the search parameters and coordinator configuration. + * + * @param x Tensor containing the query vector(s). + * @param partition_ids Tensor with the list of partition IDs to scan. + * @param search_params Shared pointer to search parameters. + * @return Shared pointer to the aggregated SearchResult. + */ shared_ptr scan_partitions(Tensor x, Tensor partition_ids, shared_ptr search_params); + /** + * @brief Executes a serial scan over the provided partitions. + * + * Performs a non-parallel scan, processing partitions sequentially. + * + * @param x Tensor containing the query vector(s). + * @param partition_ids Tensor with the list of partition IDs to scan. + * @param search_params Shared pointer to search parameters. + * @return Shared pointer to the SearchResult. + */ shared_ptr serial_scan(Tensor x, Tensor partition_ids, shared_ptr search_params); + /** + * @brief Executes a batched serial scan for multiple queries. + * + * Groups queries by the partitions they need to scan and processes them in batches. + * + * @param x Tensor containing the query vector(s). + * @param partition_ids Tensor with the list of partition IDs to scan. + * @param search_params Shared pointer to search parameters. + * @return Shared pointer to the SearchResult. + */ shared_ptr batched_serial_scan(Tensor x, Tensor partition_ids, shared_ptr search_params); + /** + * @brief Initializes worker threads for parallel scanning. + * + * Spawns worker threads and allocates per-core resources for processing scan jobs. + * + * @param num_workers Number of worker threads to initialize. + */ void initialize_workers(int num_workers); + /** + * @brief Shuts down all worker threads. + * + * Signals each worker to terminate and waits for their completion. + */ void shutdown_workers(); + /** + * @brief Function executed by each worker thread. + * + * Processes scan jobs from the worker's job queue + * + * @param worker_id Identifier for the worker thread. + */ void partition_scan_worker_fn(int worker_id); + /** + * @brief Worker thread function to perform partition scanning. + * + * Processes scan jobs and returns the aggregated search result. + * + * @param x Tensor containing the query vector(s). + * @param partition_ids Tensor with the list of partition IDs to scan. + * @param search_params Shared pointer to search parameters. + * @return Shared pointer to the SearchResult. + */ shared_ptr worker_scan(Tensor x, Tensor partition_ids, shared_ptr search_params); +private: + /** + * @brief Allocates per-core resources. + * + * Sets up necessary buffers and job queues for a specific core. + * + * @param core_idx The index of the core. + * @param num_queries Number of queries to support. + * @param k Number of nearest neighbors (Top-K) to retrieve. + * @param d Dimensionality of the query vectors. + */ + void allocate_core_resources(int core_idx, int num_queries, int k, int d); }; #endif //QUERY_COORDINATOR_H diff --git a/src/cpp/src/clustering.cpp b/src/cpp/src/clustering.cpp index f4407c40..94142589 100644 --- a/src/cpp/src/clustering.cpp +++ b/src/cpp/src/clustering.cpp @@ -5,6 +5,8 @@ // - Use descriptive variable names #include "clustering.h" +#include "index_partition.h" +#include #include #include "faiss/Clustering.h" #include @@ -30,7 +32,7 @@ shared_ptr kmeans(Tensor vectors, int d = vectors.size(1); // Create a flat index appropriate to the metric. - faiss::IndexFlat* index_ptr = nullptr; + faiss::IndexFlat *index_ptr = nullptr; if (metric_type == faiss::METRIC_INNER_PRODUCT) index_ptr = new faiss::IndexFlatIP(d); else @@ -120,4 +122,89 @@ shared_ptr kmeans(Tensor vectors, delete index_ptr; return clustering; -} \ No newline at end of file +} + +tuple >> kmeans_refine_partitions( + Tensor centroids, + vector> partitions, + MetricType metric, + int refinement_iterations) { + + // Determine number of clusters and dimension. + int n_clusters = centroids.size(0); + int d = centroids.size(1); + + // Run for the desired number of iterations (if refinement_iterations==0, do one pass). + int iterations = (refinement_iterations > 0) ? refinement_iterations : 1; + + Tensor centroid_sums = torch::zeros_like(centroids); + Tensor centroid_counts = torch::zeros({n_clusters}, torch::kInt64); + auto centroid_sums_accessor = centroid_sums.accessor(); + auto centroid_counts_accessor = centroid_counts.accessor(); + + vector> prev_partitions = partitions; + vector> new_partitions; + + for (int iter = 0; iter < iterations; iter++) { + + if (iter > 0) { + centroids = centroid_sums / centroid_counts.unsqueeze(1).to(torch::kFloat32); + } + + // Reset accumulators. + centroid_sums.zero_(); + centroid_counts.zero_(); + new_partitions.clear(); + new_partitions.resize(n_clusters); + + for (int i = 0; i < n_clusters; i++) { + new_partitions[i] = make_shared(); + new_partitions[i]->set_code_size(partitions[0]->code_size_); + new_partitions[i]->resize(10); + } + + float *centroids_ptr = centroids.data_ptr(); + + // Process each existing partition. + for (auto &part: partitions) { + int64_t nvec = part->num_vectors_; + if (nvec <= 0) continue; + + float *part_vecs = (float *) part->codes_; + int64_t *part_vec_ids = part->ids_; + + // Create batched TopK buffers (k=1 for nearest centroid). + vector > buffers = create_buffers(nvec, 1, false); + + // Use batched_scan_list to get nearest centroid for each vector. + batched_scan_list(part_vecs, + centroids_ptr, + nullptr, + nvec, + n_clusters, + d, + buffers, + metric); + + // For each vector in this partition, determine its assignment. + for (int i = 0; i < nvec; i++) { + vector assign = buffers[i]->get_topk_indices(); // top1 assignment + int assigned_cluster = assign[0]; + + // Update accumulators. + float *vec_ptr = part_vecs + i * d; + int64_t *vec_id = part_vec_ids + i; + + for (int j = 0; j < d; j++) { + centroid_sums_accessor[assigned_cluster][j] += vec_ptr[j]; + } + centroid_counts_accessor[assigned_cluster]++; + + new_partitions[assigned_cluster]->append(1, vec_id, (uint8_t *) vec_ptr); + } + } // end for each partition + std::move(new_partitions.begin(), new_partitions.end(), partitions.begin()); + } // end iterations + + return std::make_tuple(centroids, partitions); +} diff --git a/src/cpp/src/dynamic_inverted_list.cpp b/src/cpp/src/dynamic_inverted_list.cpp index fc157f67..26f67f3f 100644 --- a/src/cpp/src/dynamic_inverted_list.cpp +++ b/src/cpp/src/dynamic_inverted_list.cpp @@ -255,11 +255,6 @@ namespace faiss { shared_ptr old_part = old_it->second; // remove vectors that moved for (auto &kv: vectors_for_new_partition) { - // kv.first = new partition, kv.second = vector indices - // We must identify vector IDs from old_partition and remove them. - // However, the provided arrays (new_vectors, new_vector_ids) presumably correspond - // to vectors that were part of old_vector_partition originally. - // We can remove by IDs directly: for (int idx: kv.second) { idx_t old_id = (idx_t) new_vector_ids[idx]; int64_t pos = old_part->find_id(old_id); @@ -277,6 +272,7 @@ namespace faiss { // Already doesn't exist return; } + partitions_.erase(it); nlist--; } @@ -301,10 +297,6 @@ namespace faiss { } bool DynamicInvertedLists::get_vector_for_id(idx_t id, float *vector_values) { - // This assumes that the stored codes are actually floats or float-like data. - // If codes_ are PQ codes or compressed, you'd need decompression. - // If they are raw float vectors, this works. - for (auto &kv: partitions_) { shared_ptr part = kv.second; int64_t pos = part->find_id(id); @@ -318,6 +310,27 @@ namespace faiss { return false; } + vector DynamicInvertedLists::get_vectors_by_id(vector ids) { + + vector ret; + for (int64_t id : ids) { + bool found = false; + for (auto &kv: partitions_) { + shared_ptr part = kv.second; + int64_t pos = part->find_id(id); + if (pos != -1) { + ret.push_back(reinterpret_cast(part->codes_ + pos * part->code_size_)); + found = true; + break; + } + } + if (!found) { + throw std::runtime_error("ID not found in any partition"); + } + } + return ret; + } + size_t DynamicInvertedLists::get_new_list_id() { return curr_list_id_++; } @@ -568,7 +581,7 @@ int DynamicInvertedLists::get_thread(size_t list_no) { if (it == partitions_.end()) { throw std::runtime_error("List does not exist in get_thread"); } - return it->second->thread_id_; + return it->second->core_id_; } void DynamicInvertedLists::set_thread(size_t list_no, int new_thread_id) { @@ -576,7 +589,7 @@ void DynamicInvertedLists::set_thread(size_t list_no, int new_thread_id) { if (it == partitions_.end()) { throw std::runtime_error("List does not exist in set_thread"); } - it->second->thread_id_ = new_thread_id; + it->second->core_id_ = new_thread_id; } #endif diff --git a/src/cpp/src/hit_count_tracker.cpp b/src/cpp/src/hit_count_tracker.cpp new file mode 100644 index 00000000..fe4e598f --- /dev/null +++ b/src/cpp/src/hit_count_tracker.cpp @@ -0,0 +1,86 @@ +#include "hit_count_tracker.h" + +HitCountTracker::HitCountTracker(int window_size, int total_vectors) + : window_size_(window_size), + total_vectors_(total_vectors), + curr_query_index_(0), + num_queries_recorded_(0), + running_sum_scan_fraction_(0.0f), + current_scan_fraction_(1.0f) { + if (window_size_ <= 0) { + throw std::invalid_argument("Window size must be positive"); + } + if (total_vectors_ <= 0) { + throw std::invalid_argument("Total vectors must be positive"); + } + per_query_hits_.resize(window_size_); + per_query_scanned_sizes_.resize(window_size_); +} + +void HitCountTracker::reset() { + curr_query_index_ = 0; + num_queries_recorded_ = 0; + running_sum_scan_fraction_ = 0.0f; + current_scan_fraction_ = 1.0f; + per_query_hits_.clear(); + per_query_hits_.resize(window_size_); + per_query_scanned_sizes_.clear(); + per_query_scanned_sizes_.resize(window_size_); +} + +void HitCountTracker::set_total_vectors(int total_vectors) { + if (total_vectors <= 0) { + throw std::invalid_argument("Total vectors must be positive"); + } + total_vectors_ = total_vectors; +} + +float HitCountTracker::compute_scan_fraction(const vector& scanned_sizes) const { + int sum = std::accumulate(scanned_sizes.begin(), scanned_sizes.end(), 0); + return static_cast(sum) / static_cast(total_vectors_); +} + +void HitCountTracker::add_query_data(const vector& hit_partition_ids, + const vector& scanned_sizes) { + if (hit_partition_ids.size() != scanned_sizes.size()) { + throw std::invalid_argument("hit_partition_ids and scanned_sizes must be of equal length"); + } + float query_fraction = compute_scan_fraction(scanned_sizes); + + // If we haven't filled the window yet, simply add new data. + if (num_queries_recorded_ < window_size_) { + per_query_hits_[num_queries_recorded_] = hit_partition_ids; + per_query_scanned_sizes_[num_queries_recorded_] = scanned_sizes; + running_sum_scan_fraction_ += query_fraction; + num_queries_recorded_++; + } else { + // Window is full; subtract oldest query data and replace it. + running_sum_scan_fraction_ -= compute_scan_fraction(per_query_scanned_sizes_[curr_query_index_]); + per_query_hits_[curr_query_index_] = hit_partition_ids; + per_query_scanned_sizes_[curr_query_index_] = scanned_sizes; + running_sum_scan_fraction_ += query_fraction; + curr_query_index_ = (curr_query_index_ + 1) % window_size_; + } + int effective_window = (num_queries_recorded_ < window_size_) ? num_queries_recorded_ : window_size_; + current_scan_fraction_ = running_sum_scan_fraction_ / static_cast(effective_window); +} + +float HitCountTracker::get_current_scan_fraction() const { + return current_scan_fraction_; +} + +const vector>& HitCountTracker::get_per_query_hits() const { + return per_query_hits_; +} + +const vector>& HitCountTracker::get_per_query_scanned_sizes() const { + return per_query_scanned_sizes_; +} + +int HitCountTracker::get_window_size() const { + return window_size_; +} + +int64_t HitCountTracker::get_num_queries_recorded() const { + return num_queries_recorded_; +} \ No newline at end of file diff --git a/src/cpp/src/index_partition.cpp b/src/cpp/src/index_partition.cpp index bcbd02c7..fa11b1e5 100644 --- a/src/cpp/src/index_partition.cpp +++ b/src/cpp/src/index_partition.cpp @@ -19,7 +19,7 @@ IndexPartition::IndexPartition(int64_t num_vectors, codes_ = nullptr; ids_ = nullptr; numa_node_ = -1; - thread_id_ = -1; + core_id_ = -1; ensure_capacity(num_vectors); append(num_vectors, ids, codes); } @@ -67,6 +67,12 @@ void IndexPartition::append(int64_t n_entry, const idx_t* new_ids, const uint8_t attributes_table_ = concatenated_table.ValueOrDie(); } num_vectors_ += n_entry; + + // + // // insert new ids into id_to_index_ + // for (int64_t i = 0; i < n_entry; i++) { + // id_to_index_[new_ids[i]] = num_vectors_ - n_entry + i; + // } } void IndexPartition::update(int64_t offset, int64_t n_entry, const idx_t* new_ids, const uint8_t* new_codes) { @@ -92,6 +98,14 @@ void IndexPartition::remove(int64_t index) { int64_t last_idx = num_vectors_ - 1; const size_t code_bytes = static_cast(code_size_); + + // // Update id_to_index_ + // idx_t last_id = ids_[last_idx]; + // idx_t removed_id = ids_[index]; + // + // id_to_index_[last_id] = index; + // id_to_index_.erase(removed_id); + std::memcpy(codes_ + index * code_bytes, codes_ + last_idx * code_bytes, code_bytes); ids_[index] = ids_[last_idx]; @@ -161,7 +175,7 @@ void IndexPartition::resize(int64_t new_capacity) { void IndexPartition::clear() { free_memory(); numa_node_ = -1; - thread_id_ = -1; + core_id_ = -1; buffer_size_ = 0; num_vectors_ = 0; code_size_ = 0; @@ -170,12 +184,25 @@ void IndexPartition::clear() { } int64_t IndexPartition::find_id(idx_t id) const { + + // use map + // auto it = id_to_index_.find(id); + // if (it == id_to_index_.end()) { + // return -1; + // } + // return it->second; + + // use linear search for (int64_t i = 0; i < num_vectors_; i++) { if (ids_[i] == id) { return i; } } - return -1; // not found + return -1; +} + +void IndexPartition::set_core_id(int core_id) { + core_id_ = core_id; } #ifdef QUAKE_USE_NUMA @@ -216,7 +243,7 @@ void IndexPartition::set_numa_node(int new_numa_node) { void IndexPartition::move_from(IndexPartition&& other) { numa_node_ = other.numa_node_; - thread_id_ = other.thread_id_; + core_id_ = other.core_id_; buffer_size_ = other.buffer_size_; num_vectors_ = other.num_vectors_; code_size_ = other.code_size_; diff --git a/src/cpp/src/latency_estimation.cpp b/src/cpp/src/maintenance_cost_estimator.cpp similarity index 65% rename from src/cpp/src/latency_estimation.cpp rename to src/cpp/src/maintenance_cost_estimator.cpp index d6d951dc..8fd2a9c1 100644 --- a/src/cpp/src/latency_estimation.cpp +++ b/src/cpp/src/maintenance_cost_estimator.cpp @@ -1,17 +1,13 @@ -// -// Created by Jason on 12/16/24. -// Prompt for GitHub Copilot: -// - Conform to the google style guide -// - Use descriptive variable names -// - -#include "latency_estimation.h" -#include "list_scanning.h" +#include "maintenance_cost_estimator.h" +#include +#include +#include +#include // A simple helper to split a string by delimiter. // You can replace this with any library function if you wish. static std::vector split_string(const std::string &str, - char delim) { + char delim) { std::vector tokens; std::stringstream ss(str); std::string item; @@ -367,3 +363,140 @@ bool ListScanLatencyEstimator::load_latency_profile(const std::string &filename) scan_latency_model_ = std::move(file_latency_model); return true; } + + +MaintenanceCostEstimator::MaintenanceCostEstimator(int d, float alpha, int k) + : d_(d), alpha_(alpha), k_(k) { + if (k_ <= 0) { + throw std::invalid_argument("k must be positive"); + } + if (alpha_ <= 0.0f) { + throw std::invalid_argument("alpha must be positive"); + } + + latency_estimator_ = make_shared( + d_, + DEFAULT_LATENCY_ESTIMATOR_RANGE_N, + DEFAULT_LATENCY_ESTIMATOR_RANGE_K, + DEFAULT_LATENCY_ESTIMATOR_NTRIALS); +} + +float MaintenanceCostEstimator::compute_split_delta(int partition_size, float hit_rate, int total_partitions) const { + // Compute overhead incurred by adding one more partition. + float delta_overhead = latency_estimator_->estimate_scan_latency(total_partitions + 1, k_) - + latency_estimator_->estimate_scan_latency(total_partitions, k_); + // Cost before splitting. + float old_cost = latency_estimator_->estimate_scan_latency(partition_size, k_) * hit_rate; + // Cost after splitting: assume the partition is split in half and cost doubles due to two partitions, + // scaled by the alpha factor. + float new_cost = latency_estimator_->estimate_scan_latency(partition_size / 2, k_) * hit_rate * (2.0f * alpha_); + return delta_overhead + new_cost - old_cost; +} + + +float MaintenanceCostEstimator::compute_delete_delta( + int partition_size, // size of the candidate (deleted) partition + float hit_rate, // scan fraction ("hit rate") for the candidate partition + int total_partitions, + float avg_partition_hit_rate, // average scan fraction for each partition + float avg_partition_size // average size of partitions +) const { + if (total_partitions <= 1) { + // Can't delete if there's only 1 partition + return 0.0f; + } + + // ---------------------------------------------------- + // 1) Structural overhead difference: + // = L(T-1, k) - L(T, k). + // ---------------------------------------------------- + float latency_T = latency_estimator_->estimate_scan_latency(total_partitions, k_); + float latency_T_minus_1 = latency_estimator_->estimate_scan_latency(total_partitions - 1, k_); + float delta_overhead = latency_T_minus_1 - latency_T; + + // ---------------------------------------------------- + // 2) Scanning cost difference: + // + // Old scanning cost (approx): + // cost_old = (T-1)*(\bar{p}) * L(\bar{n}, k) + // + p_d * L(n_d, k) + // + // New scanning cost (approx): + // each of (T-1) partitions has new size \bar{n}' = \bar{n} + n_d/(T-1) + // and new scan fraction \bar{p}' = \bar{p} + p_d/(T-1) + // cost_new = (T-1)*(\bar{p}') * L(\bar{n}', k) + // ---------------------------------------------------- + float cost_old = (total_partitions - 1) * avg_partition_hit_rate + * latency_estimator_->estimate_scan_latency(avg_partition_size, k_) + + hit_rate + * latency_estimator_->estimate_scan_latency(partition_size, k_); + + // Compute the "new" size and scan fraction after merging + float merged_size = avg_partition_size + static_cast(partition_size) / (total_partitions - 1); + float merged_hit_rate = avg_partition_hit_rate + hit_rate / static_cast(total_partitions - 1); + + float cost_new; + if (partition_size < total_partitions) { + // assume at most partition_size partitions get the extra vectors + cost_new = partition_size * merged_hit_rate * latency_estimator_->estimate_scan_latency(avg_partition_size + 1, k_) + + (total_partitions - partition_size - 1) * merged_hit_rate * latency_estimator_->estimate_scan_latency(avg_partition_size, k_); + } else { + cost_new = (total_partitions - 1) * merged_hit_rate * latency_estimator_->estimate_scan_latency(ceil(merged_size), k_); + } + + float delta_scanning = cost_new - cost_old; + + // ---------------------------------------------------- + // 3) Final delete delta = structural overhead + scanning delta + // ---------------------------------------------------- + float delta = delta_overhead + delta_scanning; + return delta; +} + +float MaintenanceCostEstimator::compute_delete_delta_w_reassign(int partition_size, float hit_rate, int total_partitions, const vector &reassign_counts, const vector &reassign_sizes, const vector &reassign_hit_rates) const { + + if (total_partitions <= 1) { + // Can't delete if there's only 1 partition + return 0.0f; + } + + int n_reassign = reassign_counts.size(); + assert(reassign_sizes.size() == n_reassign); + assert(reassign_hit_rates.size() == n_reassign); + + // ---------------------------------------------------- + // 1) Structural overhead difference: + // = L(T-1, k) - L(T, k). + // ---------------------------------------------------- + float latency_T = latency_estimator_->estimate_scan_latency(total_partitions, k_); + float latency_T_minus_1 = latency_estimator_->estimate_scan_latency(total_partitions - 1, k_); + float delta_overhead = latency_T_minus_1 - latency_T; + + // ---------------------------------------------------- + // 2) Compute cost delta using reassignments + // Delta = Removal of old + increase of existing + // ---------------------------------------------------- + float removal_delta = hit_rate * latency_estimator_->estimate_scan_latency(partition_size, k_); + float reassign_delta = 0.0; + for (int i = 0; i < n_reassign; i++) { + float old = reassign_hit_rates[i] * latency_estimator_->estimate_scan_latency(reassign_sizes[i], k_); + float new_size = reassign_sizes[i] + partition_size; + float new_hit_rate = reassign_hit_rates[i] + hit_rate; + reassign_delta += new_hit_rate * latency_estimator_->estimate_scan_latency(new_size, k_) - old; + } + + // ---------------------------------------------------- + // 3) Final delete delta = structural overhead + scanning delta + // ---------------------------------------------------- + float delta = delta_overhead + removal_delta + reassign_delta; + return delta; +} + + +shared_ptr MaintenanceCostEstimator::get_latency_estimator() const { + return latency_estimator_; +} + +int MaintenanceCostEstimator::get_k() const { + return k_; +} diff --git a/src/cpp/src/maintenance_policies.cpp b/src/cpp/src/maintenance_policies.cpp index b3d675fb..21ced2bb 100644 --- a/src/cpp/src/maintenance_policies.cpp +++ b/src/cpp/src/maintenance_policies.cpp @@ -1,613 +1,202 @@ -// -// Created by Jason on 10/7/24. -// Prompt for GitHub Copilot: -// - Conform to the google style guide -// - Use descriptive variable names - -#include -#include -#include -#include - -vector > -MaintenancePolicy::get_split_history() { - if (debug_) { - std::cout << "[MaintenancePolicy] get_split_history: Entered." << std::endl; - } - vector > split_history; - - // iterate over all splits in the split records, record the hit count of the deleted partition and the two split partitions - for (auto tup : split_records_) { - int64_t parent_id = std::get<0>(tup); - int64_t left_id = std::get<1>(tup).first; - int64_t right_id = std::get<1>(tup).second; - - // get hits for parent - int64_t parent_hits = deleted_partition_hit_rate_[parent_id]; - int64_t left_hits; - int64_t right_hits; - - // get the hits for the left_id and right_id - if (deleted_partition_hit_rate_.find(left_id) == deleted_partition_hit_rate_.end()) { - // partition has not been deleted - left_hits = per_partition_hits_[left_id]; - } else { - left_hits = deleted_partition_hit_rate_[left_id]; - } - - if (deleted_partition_hit_rate_.find(right_id) == deleted_partition_hit_rate_.end()) { - right_hits = per_partition_hits_[right_id]; - } else { - right_hits = deleted_partition_hit_rate_[right_id]; - } - - if (debug_) { - std::cout << "[MaintenancePolicy] get_split_history: Parent " << parent_id - << " (hits=" << parent_hits << ") split into (" - << left_id << " with hits=" << left_hits << ", " - << right_id << " with hits=" << right_hits << ")." << std::endl; - } - split_history.push_back(std::make_tuple(parent_id, parent_hits, left_id, left_hits, right_id, right_hits)); - } - - if (debug_) { - std::cout << "[MaintenancePolicy] get_split_history: Returning " << split_history.size() << " records." << std::endl; - } - return split_history; -} - -shared_ptr MaintenancePolicy::get_partition_state(bool only_modified) { - if (debug_) { - std::cout << "[MaintenancePolicy] get_partition_state: Entered with only_modified=" - << (only_modified ? "true" : "false") << std::endl; - } - vector partition_ids; - vector partition_sizes; - vector partition_hit_rate; - - if (only_modified) { - partition_ids = vector(modified_partitions_.begin(), modified_partitions_.end()); - } else { - Tensor p_ids = partition_manager_->partitions_->get_partition_ids(); - auto p_ids_accessor = p_ids.accessor(); - partition_ids = vector(p_ids.size(0)); - for (int i = 0; i < p_ids.size(0); i++) { - partition_ids[i] = p_ids_accessor[i]; - } - } - - // For each partition, get the size and compute the hit rate. - for (int64_t partition_id : partition_ids) { - int64_t partition_size = partition_manager_->partitions_->list_size(partition_id); - int64_t hits = per_partition_hits_[partition_id]; - int curr_window_size = std::max(std::min(window_size_, curr_query_id_), 1); - float hit_rate = hits / (float)curr_window_size; - partition_sizes.push_back(partition_size); - partition_hit_rate.push_back(hit_rate); - if (debug_) { - std::cout << "[MaintenancePolicy] Partition " << partition_id << ": size=" << partition_size - << ", hits=" << hits << ", window=" << curr_window_size - << ", hit_rate=" << hit_rate << std::endl; - } - } - - shared_ptr state = std::make_shared(); - state->partition_ids = partition_ids; - state->partition_sizes = partition_sizes; - state->partition_hit_rate = partition_hit_rate; - - if (debug_) { - std::cout << "[MaintenancePolicy] get_partition_state: Returning state with " - << partition_ids.size() << " partitions." << std::endl; - } - return state; -} - -void MaintenancePolicy::set_partition_modified(int64_t partition_id) { - modified_partitions_.insert(partition_id); - if (debug_) { - std::cout << "[MaintenancePolicy] set_partition_modified: Marked partition " - << partition_id << " as modified." << std::endl; - } -} - -void MaintenancePolicy::set_partition_unmodified(int64_t partition_id) { - modified_partitions_.erase(partition_id); - if (debug_) { - std::cout << "[MaintenancePolicy] set_partition_unmodified: Unmarked partition " - << partition_id << " as modified." << std::endl; - } -} +#include "maintenance_policies.h" + +#include +#include +#include +#include + +#include "quake_index.h" + +using std::chrono::steady_clock; +using std::chrono::microseconds; +using std::chrono::duration_cast; +using std::vector; +using std::unordered_map; +using std::shared_ptr; + + +MaintenancePolicy::MaintenancePolicy( + shared_ptr partition_manager, + shared_ptr params) + : partition_manager_(partition_manager), + params_(params) { + // Initialize the cost estimator. + cost_estimator_ = std::make_shared( + partition_manager_->d(), // Assumes PartitionManager::get_dimension() exists. + params_->alpha, + 10); + // Initialize the hit count tracker using the window size and total vector count. + hit_count_tracker_ = std::make_shared( + params_->window_size, partition_manager_->ntotal()); +} + +shared_ptr MaintenancePolicy::perform_maintenance() { + // only consider split/deletion once the window is full + + int64_t num_queries = hit_count_tracker_->get_num_queries_recorded(); + if (hit_count_tracker_->get_num_queries_recorded() < params_->window_size) { + std::cout << "Window not full yet. " << num_queries << " queries recorded and " << params_->window_size + << " queries required." << std::endl; + return std::make_shared(); + } + + auto start_total = steady_clock::now(); + // STEP 1: Aggregate hit counts from the HitCountTracker. + vector > per_query_hits = hit_count_tracker_->get_per_query_hits(); + unordered_map aggregated_hits; + for (const auto &query_hits: per_query_hits) { + for (int64_t pid: query_hits) { + aggregated_hits[pid]++; + } + } + + Tensor all_partition_ids_tens = partition_manager_->get_partition_ids(); + vector all_partition_ids = vector(all_partition_ids_tens.data_ptr(), + all_partition_ids_tens.data_ptr() + + all_partition_ids_tens.size(0)); + + // STEP 2: Use cost estimation to decide which partitions to delete or split. + int total_partitions = partition_manager_->nlist(); + float current_scan_fraction = hit_count_tracker_->get_current_scan_fraction(); + vector partitions_to_delete; + vector partitions_to_split; -void MaintenancePolicy::decrement_hit_count(int64_t partition_id) { - if (debug_) { - std::cout << "[MaintenancePolicy] decrement_hit_count: Called for partition " - << partition_id << std::endl; - } - // check if id is in the partition hits - if (per_partition_hits_.find(partition_id) == per_partition_hits_.end()) { - // find it in the removed partition hits - if (ancestor_partition_hits_.find(partition_id) == ancestor_partition_hits_.end()) { - throw std::runtime_error("Partition not found in decrement_hit_count"); - } - if (ancestor_partition_hits_[partition_id] > 0) { - int64_t left_partition_id, right_partition_id; - std::tie(left_partition_id, right_partition_id) = split_records_[partition_id]; - if (debug_) { - std::cout << "[MaintenancePolicy] decrement_hit_count: Partition " << partition_id - << " not in per_partition_hits_. Delegating decrement to children: " - << left_partition_id << ", " << right_partition_id << std::endl; + int avg_partition_size = partition_manager_->ntotal() / total_partitions; + for (const auto &partition_id: all_partition_ids) { + // Get hit count and hit rate for the partition. + int hit_count = aggregated_hits[partition_id]; + float hit_rate = static_cast(hit_count) / static_cast(params_->window_size); + int partition_size = partition_manager_->get_partition_size(partition_id); + + // Deletion decision. + float delete_delta = cost_estimator_->compute_delete_delta( + partition_size, hit_rate, total_partitions, current_scan_fraction, avg_partition_size); + + if (delete_delta < -params_->delete_threshold_ns) { + + if (params_->enable_delete_rejection && partition_size > params_->min_partition_size) { + // check the assignments of the partitions to be deleted. + auto search_params = make_shared(); + search_params->k = 2; // get the top 2 partitions, ignore the first one as it is the partition itself + search_params->batched_scan = true; + float *partition_vectors = (float *) partition_manager_->partition_store_->partitions_[partition_id]->codes_; + Tensor part_vecs = torch::from_blob(partition_vectors, {(int64_t) partition_manager_->partition_store_->list_size(partition_id), + partition_manager_->d()}, torch::kFloat32); + auto res = partition_manager_->parent_->search(part_vecs, search_params); + + Tensor reassign_ids = res->ids.flatten(); + + // remove the partition itself + reassign_ids = reassign_ids.masked_select(reassign_ids != partition_id); + + // Get A) the unique partitions, B) the number reassigned, C) the size of the partitions, D) hit rates of the partitions + Tensor uniques; + Tensor counts; + std::tie(uniques, std::ignore, counts) = torch::_unique2(reassign_ids, true, false, true); + Tensor part_sizes = partition_manager_->get_partition_sizes(uniques); + + // convert to vectors + vector reassign_id_vec = vector(uniques.data_ptr(), uniques.data_ptr() + uniques.size(0)); + + vector reassign_sizes = vector(part_sizes.data_ptr(), + part_sizes.data_ptr() + part_sizes.size(0)); + vector reassign_counts = vector(counts.data_ptr(), + counts.data_ptr() + counts.size(0)); + vector hit_rates; + for (int64_t reassign_id: reassign_id_vec) { + hit_rates.push_back(static_cast(aggregated_hits[reassign_id]) / static_cast(params_->window_size)); + } + + float delta = cost_estimator_->compute_delete_delta_w_reassign(partition_manager_->get_partition_size(partition_id), + static_cast(aggregated_hits[partition_id]) / static_cast(params_->window_size), + total_partitions, + reassign_counts, + reassign_sizes, + hit_rates); + + if (delta < -params_->delete_threshold_ns) { + partitions_to_delete.push_back(partition_id); + } + } else { + partitions_to_delete.push_back(partition_id); } - decrement_hit_count(left_partition_id); - decrement_hit_count(right_partition_id); - ancestor_partition_hits_[partition_id]--; - } - } else { - if (per_partition_hits_[partition_id] > 0) { - per_partition_hits_[partition_id]--; - if (debug_) { - std::cout << "[MaintenancePolicy] decrement_hit_count: Decremented partition " - << partition_id << " to " << per_partition_hits_[partition_id] << std::endl; + } else { + if (partition_size > params_->min_partition_size) { + float split_delta = cost_estimator_->compute_split_delta( + partition_size, hit_rate, total_partitions); + if (split_delta < -params_->split_threshold_ns) { + partitions_to_split.push_back(partition_id); + } } } } -} -void MaintenancePolicy::increment_hit_count(vector hit_partition_ids) { - if (debug_) { - std::cout << "[MaintenancePolicy] increment_hit_count: Processing partitions:"; - for (auto id : hit_partition_ids) { - std::cout << " " << id; - } - std::cout << std::endl; - } + // Convert partition ID vectors to Torch tensors. + Tensor partitions_to_delete_tens = torch::from_blob( + partitions_to_delete.data(), {static_cast(partitions_to_delete.size())}, + torch::kInt64).clone(); + Tensor partitions_to_split_tens = torch::from_blob( + partitions_to_split.data(), {static_cast(partitions_to_split.size())}, + torch::kInt64).clone(); - const int64_t total_vectors = partition_manager_->ntotal(); - if (total_vectors == 0) { - throw std::runtime_error("Error: index_->ntotal() is zero in increment_hit_count."); + // STEP 3: Process deletions. + auto start_delete = steady_clock::now(); + if (partitions_to_delete_tens.numel() > 0) { + partition_manager_->delete_partitions(partitions_to_delete_tens); } + auto end_delete = steady_clock::now(); - int vectors_scanned_new = 0; - for (const auto &hit : hit_partition_ids) { - per_partition_hits_[hit]++; - int size = partition_manager_->partitions_->list_size(hit); - vectors_scanned_new += size; - modified_partitions_.insert(hit); - if (debug_) { - std::cout << "[MaintenancePolicy] increment_hit_count: Partition " << hit - << " new hit count: " << per_partition_hits_[hit] << ", size: " << size << std::endl; - } - } - float new_scan_fraction = static_cast(vectors_scanned_new) / total_vectors; + // STEP 4: Process splits. + auto start_split = steady_clock::now(); + shared_ptr split_partitions; + if (partitions_to_split_tens.numel() > 0) { - int current_query_index = curr_query_id_ % window_size_; - if (curr_query_id_ >= window_size_) { - const auto &oldest_hits = per_query_hits_[current_query_index]; - const auto &oldest_sizes = per_query_scanned_partitions_sizes_[current_query_index]; + // split the partitions into two + split_partitions = partition_manager_->split_partitions(partitions_to_split_tens); - if (debug_) { - std::cout << "[MaintenancePolicy] increment_hit_count: Removing oldest query data at index " - << current_query_index << std::endl; - } - for (const auto &hit : oldest_hits) { - decrement_hit_count(hit); - } - int vectors_scanned_old = 0; - for (const auto &size : oldest_sizes) { - vectors_scanned_old += size; - } - float oldest_scan_fraction = static_cast(vectors_scanned_old) / total_vectors; - running_sum_scan_fraction_ -= oldest_scan_fraction; - if (debug_) { - std::cout << "[MaintenancePolicy] increment_hit_count: Old scan fraction " - << oldest_scan_fraction << " removed." << std::endl; - } - } - per_query_hits_[current_query_index] = hit_partition_ids; - std::vector hits_sizes; - hits_sizes.reserve(hit_partition_ids.size()); - for (const auto &hit : hit_partition_ids) { - hits_sizes.push_back(partition_manager_->partitions_->list_size(hit)); - } - per_query_scanned_partitions_sizes_[current_query_index] = hits_sizes; - running_sum_scan_fraction_ += new_scan_fraction; - int current_window_size = std::min(static_cast(curr_query_id_) + 1, window_size_); - current_scan_fraction_ = running_sum_scan_fraction_ / static_cast(current_window_size); - if (current_scan_fraction_ == 0.0) { - current_scan_fraction_ = 1.0; - } - if (debug_) { - std::cout << "[MaintenancePolicy] increment_hit_count: Query " << curr_query_id_ - << " added with scan fraction " << new_scan_fraction - << ", new current_scan_fraction: " << current_scan_fraction_ << std::endl; - } - curr_query_id_++; -} + // remove old partitions + partition_manager_->delete_partitions(partitions_to_split_tens, false); -vector MaintenancePolicy::estimate_split_delta(shared_ptr state) { - if (debug_) { - std::cout << "[MaintenancePolicy] estimate_split_delta: Calculating split deltas." << std::endl; + // add new partitions + partition_manager_->add_partitions(split_partitions); } - vector deltas; - int n_partitions = partition_manager_->nlist(); - int k = 10; - float alpha = alpha_; - float delta_overhead = (latency_estimator_->estimate_scan_latency(n_partitions + 1, k) - - latency_estimator_->estimate_scan_latency(n_partitions, k)); - if (debug_) { - std::cout << "[MaintenancePolicy] estimate_split_delta: delta_overhead = " << delta_overhead << std::endl; + auto end_split = steady_clock::now(); + // STEP 5: Perform local refinement on newly split partitions. + if (split_partitions && split_partitions->partition_ids.numel() > 0) { + local_refinement(split_partitions->partition_ids); } - for (int i = 0; i < state->partition_ids.size(); i++) { - int64_t partition_id = state->partition_ids[i]; - int64_t partition_size = state->partition_sizes[i]; - float hit_rate = state->partition_hit_rate[i]; - - float old_cost = latency_estimator_->estimate_scan_latency(partition_size, k) * hit_rate; - float new_cost = latency_estimator_->estimate_scan_latency(partition_size / 2, k) * hit_rate * (2 * alpha); - float delta = delta_overhead + new_cost - old_cost; - deltas.push_back(delta); - if (debug_) { - std::cout << "[MaintenancePolicy] estimate_split_delta: Partition " << partition_id - << " (size=" << partition_size << ", hit_rate=" << hit_rate - << ") => old_cost=" << old_cost << ", new_cost=" << new_cost - << ", delta=" << delta << std::endl; - } - } - return deltas; -} - -vector MaintenancePolicy::estimate_delete_delta(shared_ptr state) { - if (debug_) { - std::cout << "[MaintenancePolicy] estimate_delete_delta: Calculating delete deltas." << std::endl; - } - vector deltas; - int n_partitions = partition_manager_->nlist(); - int k = 10; - float alpha = alpha_; - float delta_overhead = (latency_estimator_->estimate_scan_latency(n_partitions - 1, k) - - latency_estimator_->estimate_scan_latency(n_partitions, k)); - if (debug_) { - std::cout << "[MaintenancePolicy] estimate_delete_delta: delta_overhead = " << delta_overhead << std::endl; - } - for (int i = 0; i < state->partition_ids.size(); i++) { - int64_t partition_id = state->partition_ids[i]; - int64_t partition_size = state->partition_sizes[i]; - float hit_rate = state->partition_hit_rate[i]; - float delta_reassign = current_scan_fraction_ * latency_estimator_->estimate_scan_latency(partition_size, k); - float delta = delta_overhead + delta_reassign; - deltas.push_back(delta); - if (debug_) { - std::cout << "[MaintenancePolicy] estimate_delete_delta: Partition " << partition_id - << " (size=" << partition_size << ", hit_rate=" << hit_rate - << ") => delta_reassign=" << delta_reassign << ", delta=" << delta << std::endl; - } - } - return deltas; -} + auto end_total = steady_clock::now(); -float MaintenancePolicy::estimate_add_level_delta() { - if (debug_) { - std::cout << "[MaintenancePolicy] estimate_add_level_delta: Returning 0.0" << std::endl; - } - return 0.0; -} - -float MaintenancePolicy::estimate_remove_level_delta() { - if (debug_) { - std::cout << "[MaintenancePolicy] estimate_remove_level_delta: Returning 0.0" << std::endl; - } - return 0.0; -} - -shared_ptr MaintenancePolicy::maintenance() { - if (debug_) { - std::cout << "[MaintenancePolicy] maintenance: Starting maintenance." << std::endl; - } + // STEP 6: Fill in timing details. shared_ptr timing_info = std::make_shared(); + timing_info->delete_time_us = duration_cast(end_delete - start_delete).count(); + timing_info->split_time_us = duration_cast(end_split - start_split).count(); + timing_info->total_time_us = duration_cast(end_total - start_total).count(); - auto total_start = std::chrono::high_resolution_clock::now(); - auto start = std::chrono::high_resolution_clock::now(); - Tensor delete_ids = check_and_delete_partitions(); - auto end = std::chrono::high_resolution_clock::now(); - timing_info->delete_time_us = std::chrono::duration_cast(end - start).count(); - timing_info->n_deletes = delete_ids.size(0); - - if (debug_) { - std::cout << "[MaintenancePolicy] maintenance: Deleted " << timing_info->n_deletes - << " partitions." << std::endl; - } - - start = std::chrono::high_resolution_clock::now(); - Tensor split_ids; - Tensor old_centroids; - Tensor old_ids; - std::tie(split_ids, old_centroids, old_ids) = check_and_split_partitions(); - end = std::chrono::high_resolution_clock::now(); - timing_info->split_time_us = std::chrono::duration_cast(end - start).count(); - timing_info->n_splits = old_ids.size(0); - - if (debug_) { - std::cout << "[MaintenancePolicy] maintenance: " << timing_info->n_splits - << " splits detected." << std::endl; - } - - start = std::chrono::high_resolution_clock::now(); - if (split_ids.size(0) > 0) { - local_refinement(split_ids, refinement_radius_); - if (debug_) { - std::cout << "[MaintenancePolicy] maintenance: Performed local refinement on splits." << std::endl; - } - } - end = std::chrono::high_resolution_clock::now(); - timing_info->split_refine_time_us = std::chrono::duration_cast(end - start).count(); - auto total_end = std::chrono::high_resolution_clock::now(); - timing_info->total_time_us = std::chrono::duration_cast(total_end - total_start).count(); - - if (debug_) { - std::cout << "[MaintenancePolicy] maintenance: Completed in " - << timing_info->total_time_us << " microseconds." << std::endl; - } return timing_info; } -void MaintenancePolicy::add_split(int64_t old_partition_id, int64_t left_partition_id, int64_t right_partition_id) { - int64_t num_queries = std::max(std::min(curr_query_id_, window_size_), 1); - split_records_[old_partition_id] = std::make_pair(left_partition_id, right_partition_id); - per_partition_hits_[left_partition_id] = per_partition_hits_[old_partition_id]; - per_partition_hits_[right_partition_id] = per_partition_hits_[old_partition_id]; - ancestor_partition_hits_[old_partition_id] = per_partition_hits_[old_partition_id]; - deleted_partition_hit_rate_[old_partition_id] = (float) per_partition_hits_[old_partition_id] / num_queries; - per_partition_hits_.erase(old_partition_id); - if (debug_) { - std::cout << "[MaintenancePolicy] add_split: Partition " << old_partition_id - << " split into " << left_partition_id << " and " << right_partition_id - << " with original hits " << ancestor_partition_hits_[old_partition_id] << std::endl; - } -} - -void MaintenancePolicy::add_partition(int64_t partition_id, int64_t hits) { - per_partition_hits_[partition_id] = hits; - if (debug_) { - std::cout << "[MaintenancePolicy] add_partition: Added partition " << partition_id - << " with hit count " << hits << std::endl; - } -} - -void MaintenancePolicy::remove_partition(int64_t partition_id) { - int64_t num_queries = std::max(std::min(curr_query_id_, window_size_), 1); - ancestor_partition_hits_[partition_id] = per_partition_hits_[partition_id]; - deleted_partition_hit_rate_[partition_id] = per_partition_hits_[partition_id] / num_queries; - per_partition_hits_.erase(partition_id); - if (debug_) { - std::cout << "[MaintenancePolicy] remove_partition: Removed partition " << partition_id - << ". Final hit rate: " << deleted_partition_hit_rate_[partition_id] << std::endl; - } -} - -void MaintenancePolicy::refine_partitions(Tensor partition_ids, int refinement_iterations) { - if (debug_) { - std::cout << "[MaintenancePolicy] refine_partitions: Called with " - << (partition_ids.defined() ? std::to_string(partition_ids.numel()) : "all partitions") - << " and iterations=" << refinement_iterations << std::endl; - } - // TODO: Add detailed logging when refinement is implemented. +void MaintenancePolicy::record_query_hits(vector partition_ids) { + vector scanned_sizes = partition_manager_->get_partition_sizes(partition_ids); + hit_count_tracker_->add_query_data(partition_ids, scanned_sizes); } -QueryCostMaintenance::QueryCostMaintenance(std::shared_ptr partition_manager, - shared_ptr params) { - maintenance_policy_name_ = "query_cost"; - window_size_ = params->window_size; - refinement_radius_ = params->refinement_radius; - min_partition_size_ = params->min_partition_size; - alpha_ = params->alpha; - enable_split_rejection_ = params->enable_split_rejection; - enable_delete_rejection_ = params->enable_delete_rejection; - - curr_query_id_ = 0; - per_query_hits_ = vector >(window_size_); - per_query_scanned_partitions_sizes_ = vector >(window_size_); - partition_manager_ = partition_manager; - current_scan_fraction_ = 1.0; - - std::string profile_filename = "latency_profile.csv"; - - latency_estimator_ = std::make_shared( - partition_manager_->d(), - latency_grid_n_values_, - latency_grid_k_values_, - n_trials_, - false, - profile_filename); - - if (debug_) { - std::cout << "[QueryCostMaintenance] Constructor: Initialized with window_size=" << window_size_ - << ", refinement_radius=" << refinement_radius_ - << ", min_partition_size=" << min_partition_size_ - << ", alpha=" << alpha_ << std::endl; - } -} - -float QueryCostMaintenance::compute_alpha_for_window() { - if (debug_) { - std::cout << "[QueryCostMaintenance] compute_alpha_for_window: Computing alpha from split history." << std::endl; - } - if (split_records_.empty()) { - if (debug_) { - std::cout << "[QueryCostMaintenance] compute_alpha_for_window: No split records found." << std::endl; - } - return 0; - } - - float total_alpha = 0.0; - for (const auto &split: split_records_) { - int64_t parent_id = std::get<0>(split); - int64_t left_id = std::get<1>(split).first; - int64_t right_id = std::get<1>(split).second; - - float left_hit_rate; - float right_hit_rate; - if (deleted_partition_hit_rate_.find(left_id) != deleted_partition_hit_rate_.end()) { - left_hit_rate = deleted_partition_hit_rate_[left_id]; - } else { - left_hit_rate = (float) per_partition_hits_[left_id] / window_size_; - } - - if (deleted_partition_hit_rate_.find(right_id) != deleted_partition_hit_rate_.end()) { - right_hit_rate = deleted_partition_hit_rate_[right_id]; - } else { - right_hit_rate = (float) per_partition_hits_[right_id] / window_size_; - } - - float parent_hit_rate = deleted_partition_hit_rate_[parent_id]; - float curr_alpha = (left_hit_rate + right_hit_rate) / (2 * parent_hit_rate); - if (debug_) { - std::cout << "[QueryCostMaintenance] compute_alpha_for_window: For split from partition " << parent_id - << ", left hit_rate=" << left_hit_rate << ", right hit_rate=" << right_hit_rate - << ", parent hit_rate=" << parent_hit_rate << ", alpha=" << curr_alpha << std::endl; - } - total_alpha += curr_alpha; - } - float computed_alpha = total_alpha / split_records_.size(); - if (debug_) { - std::cout << "[QueryCostMaintenance] compute_alpha_for_window: Computed alpha = " << computed_alpha << std::endl; - } - return computed_alpha; +void MaintenancePolicy::reset() { + hit_count_tracker_->reset(); } -void QueryCostMaintenance::local_refinement(Tensor partition_ids, int refinement_radius) { - if (debug_) { - std::cout << "[QueryCostMaintenance] local_refinement: Refining partitions: " << std::to_string(partition_ids.numel()) - << " with radius " << refinement_radius << std::endl; - } +void MaintenancePolicy::local_refinement(const torch::Tensor &partition_ids) { Tensor split_centroids = partition_manager_->parent_->get(partition_ids); auto search_params = std::make_shared(); search_params->nprobe = 1000; - search_params->k = refinement_radius; + search_params->k = params_->refinement_radius; + + if (params_->refinement_radius == 0) { + return; + } auto result = partition_manager_->parent_->search(split_centroids, search_params); Tensor refine_ids = std::get<0>(torch::_unique(result->ids)); refine_ids = refine_ids.masked_select(refine_ids != -1); - refine_partitions(refine_ids, refinement_iterations_); - if (debug_) { - std::cout << "[QueryCostMaintenance] local_refinement: Completed refinement." << std::endl; - } + partition_manager_->refine_partitions(refine_ids, params_->refinement_iterations); } - -Tensor QueryCostMaintenance::check_and_delete_partitions() { - if (debug_) { - std::cout << "[QueryCostMaintenance] check_and_delete_partitions: Starting deletion check." << std::endl; - } - if (partition_manager_->parent_ == nullptr) { - if (debug_) { - std::cout << "[QueryCostMaintenance] check_and_delete_partitions: No parent index; skipping deletion." << std::endl; - } - return {}; - } - int64_t n_partitions = partition_manager_->parent_->ntotal(); - int64_t num_queries = std::min(curr_query_id_, window_size_); - shared_ptr state = get_partition_state(false); - vector delete_delta = estimate_delete_delta(state); - Tensor delete_delta_tensor = - torch::from_blob(delete_delta.data(), {(int64_t) delete_delta.size()}, torch::kFloat32).clone(); - - vector partitions_to_delete; - vector partition_to_delete_delta; - vector partition_to_delete_hit_rate; - for (int i = 0; i < delete_delta.size(); i++) { - if (delete_delta[i] < -delete_threshold_ns_) { - int64_t curr_partition_id = state->partition_ids[i]; - partitions_to_delete.push_back(curr_partition_id); - partition_to_delete_delta.push_back(delete_delta[i]); - partition_to_delete_hit_rate.push_back(state->partition_hit_rate[i]); - if (debug_) { - std::cout << "[QueryCostMaintenance] check_and_delete_partitions: Marking partition " << curr_partition_id - << " for deletion with delta " << delete_delta[i] - << " and hit rate " << state->partition_hit_rate[i] << std::endl; - } - remove_partition(curr_partition_id); - } - } - - Tensor partition_ids_tensor = torch::from_blob(partitions_to_delete.data(), - {(int64_t) partitions_to_delete.size()}, torch::kInt64).clone(); - shared_ptr clustering = partition_manager_->select_partitions(partition_ids_tensor, true); - partition_manager_->delete_partitions(partition_ids_tensor, true); - return partition_ids_tensor; -} - -std::tuple QueryCostMaintenance::check_and_split_partitions() { - if (debug_) { - std::cout << "[QueryCostMaintenance] check_and_split_partitions: Starting split check." << std::endl; - } - if (partition_manager_->parent_ == nullptr) { - if (debug_) { - std::cout << "[QueryCostMaintenance] check_and_split_partitions: No parent index; skipping splits." << std::endl; - } - return {}; - } - int64_t n_partitions = partition_manager_->nlist(); - shared_ptr state = get_partition_state(false); - vector split_deltas = estimate_split_delta(state); - - vector partitions_to_split; - for (int i = 0; i < split_deltas.size(); i++) { - int64_t partition_size = state->partition_sizes[i]; - if (split_deltas[i] < -split_threshold_ns_ && partition_size > 2 * min_partition_size_) { - partitions_to_split.push_back(state->partition_ids[i]); - if (debug_) { - std::cout << "[QueryCostMaintenance] check_and_split_partitions: Partition " - << state->partition_ids[i] << " marked for splitting (size=" << partition_size - << ", delta=" << split_deltas[i] << ")." << std::endl; - } - } - } - Tensor partitions_to_split_tensor = torch::from_blob(partitions_to_split.data(), - {(int64_t) partitions_to_split.size()}, torch::kInt64); - shared_ptr split_partitions = partition_manager_->split_partitions(partitions_to_split_tensor); - - Tensor old_centroids = partition_manager_->parent_->get(partitions_to_split_tensor); - partition_manager_->delete_partitions(partitions_to_split_tensor, false); - partition_manager_->add_partitions(split_partitions); - Tensor new_ids = split_partitions->partition_ids; - if (debug_) { - std::cout << "[QueryCostMaintenance] check_and_split_partitions: Added " << new_ids.size(0) - << " new partitions." << std::endl; - } - auto new_ids_accessor = split_partitions->partition_ids.accessor(); - auto split_ids_accessor = partitions_to_split_tensor.accessor(); - for (int i = 0; i < partitions_to_split_tensor.size(0); i++) { - int64_t old_partition_id = split_ids_accessor[i]; - int64_t left_partition_id = new_ids_accessor[i * 2]; - int64_t right_partition_id = new_ids_accessor[i * 2 + 1]; - add_split(old_partition_id, left_partition_id, right_partition_id); - } - - return {split_partitions->partition_ids, old_centroids, partitions_to_split_tensor}; -} - -shared_ptr DeDriftMaintenance::maintenance() { - if (debug_) { - std::cout << "[DeDriftMaintenance] maintenance: Starting dedrift maintenance." << std::endl; - } - shared_ptr timing_info; - auto total_start = std::chrono::high_resolution_clock::now(); - - // Perform dedrift maintenance tasks (e.g., reassign centroids) - // For now, we just call local refinement on selected partitions. - Tensor partition_ids = partition_manager_->get_partition_ids(); - Tensor partition_sizes = partition_manager_->get_partition_sizes(partition_ids); - - Tensor sort_args = partition_sizes.argsort(0, true); - Tensor small_partition_ids = partition_ids.index_select(0, sort_args.narrow(0, 0, k_small_)); - Tensor large_partition_ids = partition_ids.index_select(0, sort_args.narrow(0, partition_ids.size(0) - k_large_, k_large_)); - Tensor all_partition_ids = torch::cat({small_partition_ids, large_partition_ids}, 0); - refine_partitions(all_partition_ids, refinement_iterations_); - - auto total_end = std::chrono::high_resolution_clock::now(); - timing_info->total_time_us = std::chrono::duration_cast(total_end - total_start).count(); - if (debug_) { - std::cout << "[DeDriftMaintenance] maintenance: Completed dedrift maintenance in " - << timing_info->total_time_us << " us." << std::endl; - } - return timing_info; -} \ No newline at end of file diff --git a/src/cpp/src/partition_manager.cpp b/src/cpp/src/partition_manager.cpp index e99deb2d..8659f3f2 100644 --- a/src/cpp/src/partition_manager.cpp +++ b/src/cpp/src/partition_manager.cpp @@ -26,7 +26,7 @@ static inline const uint8_t *as_uint8_ptr(const Tensor &float_tensor) { PartitionManager::PartitionManager() { parent_ = nullptr; - partitions_ = nullptr; + partition_store_ = nullptr; } PartitionManager::~PartitionManager() { @@ -56,9 +56,9 @@ void PartitionManager::init_partitions( "[PartitionManager] init_partitions: parent's ntotal does not match partition_ids.size(0)."); } - // Create the local partitions_: + // Create the local partition_store_: size_t code_size_bytes = static_cast(dim * sizeof(float)); - partitions_ = std::make_shared( + partition_store_ = std::make_shared( 0, code_size_bytes ); @@ -70,7 +70,7 @@ void PartitionManager::init_partitions( // Add an empty list for each partition ID auto partition_ids_accessor = clustering->partition_ids.accessor(); for (int64_t i = 0; i < nlist; i++) { - partitions_->add_list(partition_ids_accessor[i]); + partition_store_->add_list(partition_ids_accessor[i]); if (debug_) { std::cout << "[PartitionManager] init_partitions: Added empty list for partition " << i << std::endl; } @@ -106,7 +106,7 @@ void PartitionManager::init_partitions( resident_ids_.insert(id_val); } } - partitions_->add_entries( + partition_store_->add_entries( partition_ids_accessor[i], count, id.data_ptr(), @@ -182,8 +182,8 @@ shared_ptr PartitionManager::add( /// Input validation ////////////////////////////////////////// auto s1 = std::chrono::high_resolution_clock::now(); - if (!partitions_) { - throw runtime_error("[PartitionManager] add: partitions_ is null. Did you call init_partitions?"); + if (!partition_store_) { + throw runtime_error("[PartitionManager] add: partition_store_ is null. Did you call init_partitions?"); } if (!vectors.defined() || !vector_ids.defined()) { @@ -237,9 +237,9 @@ shared_ptr PartitionManager::add( } } - // checks assignments are less than partitions_->curr_list_id_ + // checks assignments are less than partition_store_->curr_list_id_ if (assignments.defined() && (assignments >= curr_partition_id_).any().item()) { - throw runtime_error("[PartitionManager] add: assignments must be less than partitions_->curr_list_id_."); + throw runtime_error("[PartitionManager] add: assignments must be less than partition_store_->curr_list_id_."); } auto e1 = std::chrono::high_resolution_clock::now(); timing_info->input_validation_time_us = std::chrono::duration_cast(e1 - s1).count(); @@ -291,7 +291,7 @@ shared_ptr PartitionManager::add( /// Add vectors to partitions ////////////////////////////////////////// auto s3 = std::chrono::high_resolution_clock::now(); - size_t code_size_bytes = partitions_->code_size; + size_t code_size_bytes = partition_store_->code_size; auto id_ptr = vector_ids.data_ptr(); auto id_accessor = vector_ids.accessor(); const uint8_t *code_ptr = as_uint8_ptr(vectors); @@ -303,9 +303,9 @@ shared_ptr PartitionManager::add( << " into partition " << pid << std::endl; } - std::shared_ptr filtered_table_result = filterRowById(attributes_table, id_accessor[i]); - partitions_->add_entries( + std::shared_ptr filtered_table_result = filterRowById(attributes_table, id_accessor[i]); + partition_store_->add_entries( pid, /*n_entry=*/1, id_ptr + i, @@ -326,8 +326,8 @@ shared_ptr PartitionManager::remove(const Tensor &ids) { if (debug_) { std::cout << "[PartitionManager] remove: Removing " << ids.size(0) << " ids." << std::endl; } - if (!partitions_) { - throw runtime_error("[PartitionManager] remove: partitions_ is null."); + if (!partition_store_) { + throw runtime_error("[PartitionManager] remove: partition_store_ is null."); } if (!ids.defined() || ids.size(0) == 0) { if (debug_) { @@ -367,8 +367,8 @@ shared_ptr PartitionManager::remove(const Tensor &ids) { timing_info->find_partition_time_us = std::chrono::duration_cast(e2 - s2).count(); auto s3 = std::chrono::high_resolution_clock::now(); - partitions_->remove_vectors(to_remove); - // TODO: Remove associated attribute data as well??? + + partition_store_->remove_vectors(to_remove); if (debug_) { std::cout << "[PartitionManager] remove: Completed removal." << std::endl; } @@ -383,11 +383,11 @@ Tensor PartitionManager::get(const Tensor &ids) { std::cout << "[PartitionManager] get: Retrieving vectors for " << ids.size(0) << " ids." << std::endl; } auto ids_accessor = ids.accessor(); - Tensor vectors = torch::empty({ids.size(0), partitions_->d_}, torch::kFloat32); + Tensor vectors = torch::empty({ids.size(0), partition_store_->d_}, torch::kFloat32); auto vectors_ptr = vectors.data_ptr(); for (int64_t i = 0; i < ids.size(0); i++) { - partitions_->get_vector_for_id(ids_accessor[i], vectors_ptr + i * partitions_->d_); + partition_store_->get_vector_for_id(ids_accessor[i], vectors_ptr + i * partition_store_->d_); } if (debug_) { std::cout << "[PartitionManager] get: Retrieval complete." << std::endl; @@ -395,6 +395,11 @@ Tensor PartitionManager::get(const Tensor &ids) { return vectors; } +vector PartitionManager::get_vectors(vector ids) { + return partition_store_->get_vectors_by_id(ids); +} + + shared_ptr PartitionManager::select_partitions(const Tensor &select_ids, bool copy) { if (debug_) { std::cout << "[PartitionManager] select_partitions: Selecting partitions from provided ids." << std::endl; @@ -402,12 +407,12 @@ shared_ptr PartitionManager::select_partitions(const Tensor &select_ Tensor centroids = parent_->get(select_ids); vector cluster_vectors; vector cluster_ids; - int d = (int) partitions_->d_; + int d = (int) partition_store_->d_; auto selected_ids_accessor = select_ids.accessor(); for (int i = 0; i < select_ids.size(0); i++) { int64_t list_no = selected_ids_accessor[i]; - int64_t list_size = partitions_->list_size(list_no); + int64_t list_size = partition_store_->list_size(list_no); if (list_size == 0) { cluster_vectors.push_back(torch::empty({0, d}, torch::kFloat32)); cluster_ids.push_back(torch::empty({0}, torch::kInt64)); @@ -416,8 +421,8 @@ shared_ptr PartitionManager::select_partitions(const Tensor &select_ } continue; } - auto codes = partitions_->get_codes(list_no); - auto ids = partitions_->get_ids(list_no); + auto codes = partition_store_->get_codes(list_no); + auto ids = partition_store_->get_ids(list_no); Tensor cluster_vectors_i = torch::from_blob((void *) codes, {list_size, d}, torch::kFloat32); Tensor cluster_ids_i = torch::from_blob((void *) ids, {list_size}, torch::kInt64); if (copy) { @@ -452,7 +457,7 @@ shared_ptr PartitionManager::split_partitions(const Tensor &partitio int64_t num_partitions_to_split = partition_ids.size(0); int64_t num_splits = 2; int64_t total_new_partitions = num_partitions_to_split * num_splits; - int d = partitions_->d_; + int d = partition_store_->d_; Tensor split_centroids = torch::empty({total_new_partitions, d}, torch::kFloat32); vector split_vectors; @@ -498,120 +503,44 @@ shared_ptr PartitionManager::split_partitions(const Tensor &partitio return split_clustering; } -void PartitionManager::refine_partitions(const Tensor &partition_ids, int iterations) { +void PartitionManager::refine_partitions(Tensor partition_ids, int iterations) { if (debug_) { std::cout << "[PartitionManager] refine_partitions: Refining partitions with iterations = " << iterations << std::endl; } - Tensor pids = partition_ids.defined() && partition_ids.size(0) > 0 - ? partition_ids - : get_partition_ids(); - if (!pids.size(0) || !partitions_ || !parent_) { - throw runtime_error("[PartitionManager] refine_partitions: no partitions to refine."); - } - - faiss::MetricType mt = parent_->metric_; - shared_ptr selected_parts = select_partitions(pids); - Tensor centroids = selected_parts->centroids; - int64_t nclusters = pids.size(0); - int64_t d = centroids.size(1); - bool isIP = (mt == faiss::METRIC_INNER_PRODUCT); - - auto pids_accessor = pids.accessor(); - for (int iter = 0; iter < iterations; iter++) { - Tensor new_centroids = torch::zeros_like(centroids); - Tensor counts = torch::zeros({nclusters}, torch::kLong); - - // #pragma omp parallel - { - Tensor local_centroids = torch::zeros_like(centroids); - Tensor local_counts = torch::zeros({nclusters}, torch::kLong); - - // #pragma omp for nowait - for (int64_t i = 0; i < nclusters; i++) { - Tensor vecs = selected_parts->vectors[i]; - if (!vecs.defined() || !vecs.size(0)) continue; - - Tensor dist = isIP - ? -torch::mm(vecs, centroids.t()) - : torch::cdist(vecs, centroids); - - auto min_res = dist.min(/*dim=*/1, /*keepdim=*/false); - Tensor labels = std::get<1>(min_res); - auto lbl_acc = labels.accessor(); + if (!partition_ids.defined()) { + partition_ids = parent_->get_ids(); + } - for (int64_t row = 0; row < vecs.size(0); row++) { - int64_t c = lbl_acc[row]; - local_centroids[c] += vecs[row]; - local_counts[c] += 1; - } - } - // #pragma omp critical - { - new_centroids += local_centroids; - counts += local_counts; - } + if (partition_ids.size(0) == 0) { + if (debug_) { + std::cout << "[PartitionManager] refine_partitions: No partitions to refine. Exiting." << std::endl; } + return; + } - if (iter < iterations) { - auto counts_acc = counts.accessor(); - for (int64_t c = 0; c < nclusters; c++) { - int64_t n = counts_acc[c]; - if (n > 0) { - centroids[c] = new_centroids[c] / (float)n; - if (isIP) { - float norm = centroids[c].norm().item(); - if (norm > 1e-12f) { - centroids[c] /= norm; - } - } - } - } - } + auto pids = partition_ids.accessor(); + + Tensor current_centroids = parent_->get(partition_ids); + vector> index_partitions(partition_ids.size(0)); + for (int i = 0; i < partition_ids.size(0); i++) { + index_partitions[i] = partition_store_->partitions_[pids[i]]; } - // Final assignment - for (int64_t i = 0; i < nclusters; i++) { - int64_t pid = pids_accessor[i]; - Tensor vecs = selected_parts->vectors[i]; - Tensor ids = selected_parts->vector_ids[i]; - if (!vecs.defined() || !vecs.size(0)) continue; - - Tensor dist = isIP - ? -torch::mm(vecs, centroids.t()) - : torch::cdist(vecs, centroids); - auto min_res = dist.min(/*dim=*/1, /*keepdim=*/false); - Tensor labels = std::get<1>(min_res); - - std::vector to_remove; - to_remove.reserve(ids.size(0)); - std::vector new_pids_array(ids.size(0)); - - auto lbl_acc = labels.accessor(); - auto ids_acc = ids.accessor(); - for (int64_t row = 0; row < vecs.size(0); row++) { - int64_t c = lbl_acc[row]; - if (c != i) { - to_remove.push_back((idx_t)ids_acc[row]); - } - new_pids_array[row] = pids_accessor[c]; - } + std::tie(current_centroids, index_partitions) = kmeans_refine_partitions(current_centroids, + index_partitions, + parent_->metric_, + iterations); - partitions_->batch_update_entries( - pid, - new_pids_array.data(), - (u_int8_t *) vecs.data_ptr(), - ids_acc.data(), - vecs.size(0) - ); - if (debug_) { - std::cout << "[PartitionManager] refine_partitions: After updating partition " - << pid << ", new size: " << partitions_->list_size(pid) << std::endl; - } + // modify centroids + parent_->modify(partition_ids, current_centroids); + + // replace partitions + for (int i = 0; i < partition_ids.size(0); i++) { + partition_store_->partitions_[pids[i]] = index_partitions[i]; } - parent_->modify(pids, centroids); if (debug_) { std::cout << "[PartitionManager] refine_partitions: Completed refinement." << std::endl; } @@ -632,8 +561,8 @@ void PartitionManager::add_partitions(shared_ptr partitions) { auto p_ids_accessor = partitions->partition_ids.accessor(); for (int64_t i = 0; i < nlist; i++) { int64_t list_no = p_ids_accessor[i]; - partitions_->add_list(list_no); - partitions_->add_entries( + partition_store_->add_list(list_no); + partition_store_->add_entries( list_no, partitions->vectors[i].size(0), partitions->vector_ids[i].data_ptr(), @@ -659,7 +588,7 @@ void PartitionManager::delete_partitions(const Tensor &partition_ids, bool reass auto partition_ids_accessor = partition_ids.accessor(); for (int i = 0; i < partition_ids.size(0); i++) { int64_t list_no = partition_ids_accessor[i]; - partitions_->remove_list(list_no); + partition_store_->remove_list(list_no); if (debug_) { std::cout << "[PartitionManager] delete_partitions: Removed partition " << list_no << std::endl; } @@ -683,26 +612,27 @@ void PartitionManager::delete_partitions(const Tensor &partition_ids, bool reass } } -void PartitionManager::distribute_flat(int n_partitions) { + +void PartitionManager::distribute_partitions(int num_workers) { if (debug_) { - std::cout << "[PartitionManager] distribute_flat: Distributing flat index into " << n_partitions << " partitions." << std::endl; + std::cout << "[PartitionManager] distribute_partitions: Attempting to distribute partitions across " + << num_workers << " workers." << std::endl; } - if (parent_ != nullptr) { - throw runtime_error("Index is not flat"); - } else { - auto codes = (float *) partitions_->get_codes(0); - auto ids = (int64_t *) partitions_->get_ids(0); - int64_t ntotal = partitions_->list_size(0); + + if (parent_ == nullptr) { + auto codes = (float *) partition_store_->get_codes(0); + auto ids = (int64_t *) partition_store_->get_ids(0); + int64_t ntotal = partition_store_->list_size(0); Tensor vectors = torch::from_blob(codes, {ntotal, d()}, torch::kFloat32); Tensor vector_ids = torch::from_blob(ids, {ntotal}, torch::kInt64); - Tensor partition_assignments = torch::randint(n_partitions, {vectors.size(0)}, torch::kInt64); - Tensor partition_ids = torch::arange(n_partitions, torch::kInt64); - Tensor centroids = torch::empty({n_partitions, d()}, torch::kFloat32); - vector new_vectors(n_partitions); - vector new_ids(n_partitions); + Tensor partition_assignments = torch::randint(num_workers, {vectors.size(0)}, torch::kInt64); + Tensor partition_ids = torch::arange(num_workers, torch::kInt64); + Tensor centroids = torch::empty({num_workers, d()}, torch::kFloat32); + vector new_vectors(num_workers); + vector new_ids(num_workers); - for (int i = 0; i < n_partitions; i++) { + for (int i = 0; i < num_workers; i++) { Tensor ids = torch::nonzero(partition_assignments == i).squeeze(1); new_vectors[i] = vectors.index_select(0, ids); new_ids[i] = vector_ids.index_select(0, ids); @@ -724,52 +654,47 @@ void PartitionManager::distribute_flat(int n_partitions) { std::cout << "[PartitionManager] distribute_flat: Distribution complete." << std::endl; } } -} -void PartitionManager::distribute_partitions(int num_workers) { - if (debug_) { - std::cout << "[PartitionManager] distribute_partitions: Attempting to distribute partitions across " - << num_workers << " workers." << std::endl; - } - if (parent_ == nullptr) { - if (debug_) { - std::cout << "[PartitionManager] distribute_partitions: Index is flat." << std::endl; - } - throw runtime_error("Index is not partitioned"); - } else { - // TODO: Implement distribute_partitions with logging as needed. - if (debug_) { - std::cout << "[PartitionManager] distribute_partitions: (Not yet implemented)" << std::endl; - } + Tensor partition_ids = get_partition_ids(); + for (int i = 0; i < partition_store_->nlist; i++) { + set_partition_core_id(partition_ids[i].item(), i % num_workers); } } +void PartitionManager::set_partition_core_id(int64_t partition_id, int core_id) { + partition_store_->partitions_[partition_id]->core_id_ = core_id; +} + +int PartitionManager::get_partition_core_id(int64_t partition_id) { + return partition_store_->partitions_[partition_id]->core_id_; +} + int64_t PartitionManager::ntotal() const { - if (!partitions_) { + if (!partition_store_) { return 0; } - return partitions_->ntotal(); + return partition_store_->ntotal(); } int64_t PartitionManager::nlist() const { - if (!partitions_) { + if (!partition_store_) { return 0; } - return partitions_->nlist; + return partition_store_->nlist; } int PartitionManager::d() const { - if (!partitions_) { + if (!partition_store_) { return 0; } - return partitions_->d_; + return partition_store_->d_; } Tensor PartitionManager::get_partition_ids() { if (debug_) { std::cout << "[PartitionManager] get_partition_ids: Retrieving partition ids." << std::endl; } - return partitions_->get_partition_ids(); + return partition_store_->get_partition_ids(); } Tensor PartitionManager::get_ids() { @@ -779,20 +704,28 @@ Tensor PartitionManager::get_ids() { for (int i = 0; i < partition_ids.size(0); i++) { int64_t list_no = partition_ids_accessor[i]; - Tensor curr_ids = torch::from_blob((void *) partitions_->get_ids(list_no), - {(int64_t) partitions_->list_size(list_no)}, torch::kInt64); + Tensor curr_ids = torch::from_blob((void *) partition_store_->get_ids(list_no), + {(int64_t) partition_store_->list_size(list_no)}, torch::kInt64); ids.push_back(curr_ids); } return torch::cat(ids, 0); } +vector PartitionManager::get_partition_sizes(vector partition_ids) { + vector partition_sizes; + for (int64_t partition_id : partition_ids) { + partition_sizes.push_back(partition_store_->list_size(partition_id)); + } + return partition_sizes; +} + Tensor PartitionManager::get_partition_sizes(Tensor partition_ids) { if (debug_) { std::cout << "[PartitionManager] get_partition_sizes: Getting sizes for partitions." << std::endl; } - if (!partitions_) { - throw runtime_error("[PartitionManager] get_partition_sizes: partitions_ is null."); + if (!partition_store_) { + throw runtime_error("[PartitionManager] get_partition_sizes: partition_store_ is null."); } if (!partition_ids.defined() || partition_ids.size(0) == 0) { partition_ids = get_partition_ids(); @@ -803,7 +736,7 @@ Tensor PartitionManager::get_partition_sizes(Tensor partition_ids) { auto partition_sizes_accessor = partition_sizes.accessor(); for (int i = 0; i < partition_ids.size(0); i++) { int64_t list_no = partition_ids_accessor[i]; - partition_sizes_accessor[i] = partitions_->list_size(list_no); + partition_sizes_accessor[i] = partition_store_->list_size(list_no); if (debug_) { std::cout << "[PartitionManager] get_partition_sizes: Partition " << list_no << " size: " << partition_sizes_accessor[i] << std::endl; @@ -812,12 +745,17 @@ Tensor PartitionManager::get_partition_sizes(Tensor partition_ids) { return partition_sizes; } +int64_t PartitionManager::get_partition_size(int64_t partition_id) { + return partition_store_->list_size(partition_id); +} + + bool PartitionManager::validate() { if (debug_) { std::cout << "[PartitionManager] validate: Validating partitions." << std::endl; } - if (!partitions_) { - throw runtime_error("[PartitionManager] validate: partitions_ is null."); + if (!partition_store_) { + throw runtime_error("[PartitionManager] validate: partition_store_ is null."); } return true; } @@ -827,10 +765,10 @@ void PartitionManager::save(const string &path) { if (debug_) { std::cout << "[PartitionManagerPartitionManager] save: Saving partitions to " << path << std::endl; } - if (!partitions_) { + if (!partition_store_) { throw runtime_error("No partitions to save"); } - partitions_->save(path); + partition_store_->save(path); if (debug_) { std::cout << "[PartitionManager] save: Save complete." << std::endl; } @@ -840,11 +778,11 @@ void PartitionManager::load(const string &path) { if (debug_) { std::cout << "[PartitionManager] load: Loading partitions from " << path << std::endl; } - if (!partitions_) { - partitions_ = std::make_shared(0, 0); + if (!partition_store_) { + partition_store_ = std::make_shared(0, 0); } - partitions_->load(path); - curr_partition_id_ = partitions_->nlist; + partition_store_->load(path); + curr_partition_id_ = partition_store_->nlist; if (check_uniques_) { // add ids into resident set diff --git a/src/cpp/src/quake_index.cpp b/src/cpp/src/quake_index.cpp index 45b10099..3959911d 100644 --- a/src/cpp/src/quake_index.cpp +++ b/src/cpp/src/quake_index.cpp @@ -6,6 +6,7 @@ #include #include +#include QuakeIndex::QuakeIndex(int current_level) { // Initialize the QuakeIndex @@ -56,6 +57,7 @@ shared_ptr QuakeIndex::build(Tensor x, Tensor ids, shared_ptr(current_level_ + 1); auto parent_build_params = make_shared(); parent_build_params->metric = build_params_->metric; + parent_build_params->num_workers = build_params_->num_workers; parent_->build(clustering->centroids, clustering->partition_ids, parent_build_params); // initialize the partition manager @@ -146,7 +148,11 @@ shared_ptr QuakeIndex::modify(Tensor ids, Tensor x) { void QuakeIndex::initialize_maintenance_policy(shared_ptr maintenance_policy_params) { maintenance_policy_params_ = maintenance_policy_params; - maintenance_policy_ = make_shared(partition_manager_, maintenance_policy_params); + maintenance_policy_ = make_shared(partition_manager_, maintenance_policy_params); + + if (query_coordinator_ != nullptr) { + query_coordinator_->maintenance_policy_ = maintenance_policy_; + } } shared_ptr QuakeIndex::maintenance() { @@ -154,7 +160,7 @@ shared_ptr QuakeIndex::maintenance() { throw std::runtime_error("[QuakeIndex::maintenance()] No maintenance policy set."); } - return maintenance_policy_->maintenance(); + return maintenance_policy_->perform_maintenance(); } bool QuakeIndex::validate() { @@ -245,7 +251,7 @@ void QuakeIndex::load(const std::string& dir_path, int n_workers) { std::string parent_dir = (fs::path(dir_path) / "parent").string(); if (fs::exists(parent_dir) && fs::is_directory(parent_dir)) { parent_ = std::make_shared(); - parent_->load(parent_dir); + parent_->load(parent_dir, n_workers); partition_manager_->parent_ = parent_; } else { parent_ = nullptr; diff --git a/src/cpp/src/query_coordinator.cpp b/src/cpp/src/query_coordinator.cpp index 1580bd65..9c4fde0a 100644 --- a/src/cpp/src/query_coordinator.cpp +++ b/src/cpp/src/query_coordinator.cpp @@ -26,7 +26,8 @@ QueryCoordinator::QueryCoordinator(shared_ptr parent, maintenance_policy_(maintenance_policy), metric_(metric), num_workers_(num_workers), - stop_workers_(false) { + workers_initialized_(false) { + if (num_workers_ > 0) { initialize_workers(num_workers_); } @@ -37,38 +38,41 @@ QueryCoordinator::~QueryCoordinator() { shutdown_workers(); } +void QueryCoordinator::allocate_core_resources(int core_idx, int num_queries, int k, int d) { + CoreResources &res = core_resources_[core_idx]; + res.core_id = core_idx; + res.local_query_buffer.resize(num_queries * d * sizeof(float)); + res.topk_buffer_pool.resize(num_queries); + for (int q = 0; q < num_queries; ++q) { + res.topk_buffer_pool[q] = make_shared(k, metric_ == faiss::METRIC_INNER_PRODUCT); + res.job_queue = moodycamel::BlockingConcurrentQueue(); + } + +} + // Initialize Worker Threads -void QueryCoordinator::initialize_workers(int num_workers) { +void QueryCoordinator::initialize_workers(int num_cores) { if (workers_initialized_) { std::cerr << "[QueryCoordinator::initialize_workers] Workers already initialized." << std::endl; return; } - std::cout << "[QueryCoordinator::initialize_workers] Initializing " << num_workers << " worker threads." << + std::cout << "[QueryCoordinator::initialize_workers] Initializing " << num_cores << " worker threads." << std::endl; - // Reserve space for worker threads and job queues - worker_threads_.reserve(num_workers); - jobs_queue_.resize(num_workers); - - std::cout << "[QueryCoordinator::initialize_workers] Creating " << num_workers << " worker threads." << std::endl; - - // Spawn worker threads - for (int worker_id = 0; worker_id < num_workers; worker_id++) { - worker_threads_.emplace_back(&QueryCoordinator::partition_scan_worker_fn, this, worker_id); - } - - std::cout << "[QueryCoordinator::initialize_workers] Worker threads created." << std::endl; + partition_manager_->distribute_partitions(num_cores); - // if the index is flat we should modify the partition manager to have a single partition - if (parent_ == nullptr) { - partition_manager_->distribute_flat(num_workers); - } else { - partition_manager_->distribute_partitions(num_workers); + core_resources_.resize(num_cores); + worker_threads_.resize(num_cores); + worker_job_counter_.reserve(num_cores); + for (int i = 0; i < num_cores; i++) { + if (!set_thread_affinity(i)) { + std::cout << "[QueryCoordinator::initialize_workers] Failed to set thread affinity on core " << i << std::endl; + } + allocate_core_resources(i, 1, 10, partition_manager_->d()); + worker_threads_[i] = std::thread(&QueryCoordinator::partition_scan_worker_fn, this, i); + worker_job_counter_[i] = 0; } - - std::cout << "[QueryCoordinator::initialize_workers] Partitions distributed." << std::endl; - workers_initialized_ = true; } @@ -78,101 +82,99 @@ void QueryCoordinator::shutdown_workers() { return; } - // Signal all workers to stop by enqueueing a special shutdown job ID (-1) - for (int worker_id = 0; worker_id < num_workers_; ++worker_id) { - jobs_queue_[worker_id].enqueue(-1); // -1 is reserved for shutdown + stop_workers_.store(true); + // Enqueue a special shutdown job for each core. + for (auto &res : core_resources_) { + ScanJob termination_job; + termination_job.partition_id = -1; + res.job_queue.enqueue(termination_job); } - - // Join all worker threads - for (auto &thread: worker_threads_) { - if (thread.joinable()) { - thread.join(); - } + // Join all worker threads. + for (auto &thr : worker_threads_) { + if (thr.joinable()) + thr.join(); } - - // Clear worker-related data worker_threads_.clear(); - jobs_queue_.clear(); - stop_workers_.store(false); workers_initialized_ = false; - - // Clear any remaining jobs to prevent memory leaks - { - std::lock_guard lock(result_mutex_); - jobs_.clear(); - query_topk_buffers_.clear(); - } } // Worker Thread Function -void QueryCoordinator::partition_scan_worker_fn(int worker_id) { - // For non-batched jobs, we reuse a single buffer. - shared_ptr local_topk_buffer; - // For batched jobs, we use a thread-local vector to avoid repeated allocation. - thread_local std::vector> local_buffers; +void QueryCoordinator::partition_scan_worker_fn(int core_index) { + + CoreResources &res = core_resources_[core_index]; + + if (!set_thread_affinity(core_index)) { + std::cout << "[QueryCoordinator::partition_scan_worker_fn] Failed to set thread affinity on core " << core_index << std::endl; + } + + while (true) { - int job_id; - jobs_queue_[worker_id].wait_dequeue(job_id); + ScanJob job; + + auto job_wait_start = std::chrono::high_resolution_clock::now(); + res.job_queue.wait_dequeue(job); + auto job_wait_end = std::chrono::high_resolution_clock::now(); + + job_pull_time_ns += std::chrono::duration_cast(job_wait_end - job_wait_start).count(); + + auto job_process_start = std::chrono::high_resolution_clock::now(); + shared_ptr local_topk_buffer = res.topk_buffer_pool[0]; // Shutdown signal: -1 indicates the worker should exit. - if (job_id == -1) { + if (job.partition_id == -1) { break; } - // Look up the job. - ScanJob job; - { - std::lock_guard lock(result_mutex_); - auto it = jobs_.find(job_id); - if (it == jobs_.end()) { - std::cerr << "[partition_scan_worker_fn] Invalid job_id " << job_id << std::endl; - continue; - } - job = it->second; - } - - // If the query has already been processed, skip this job. - if (!query_topk_buffers_[job.query_ids[0]]->currently_processing_query()) { - { - std::lock_guard lock(result_mutex_); - jobs_.erase(job_id); - } + // Ignore this job if the global buffer is not processing queries. + if (!global_topk_buffer_pool_[job.query_ids[0]]->currently_processing_query()) { + // decrement the job counter + global_topk_buffer_pool_[job.query_ids[0]]->record_empty_job(); continue; } + worker_job_counter_[core_index]++; + // Retrieve partition data. - const float *partition_codes = (float *) partition_manager_->partitions_->get_codes(job.partition_id); - const int64_t *partition_ids = (int64_t *) partition_manager_->partitions_->get_ids(job.partition_id); - int64_t partition_size = partition_manager_->partitions_->list_size(job.partition_id); + const float *partition_codes = (float *) partition_manager_->partition_store_->get_codes(job.partition_id); + const int64_t *partition_ids = (int64_t *) partition_manager_->partition_store_->get_ids(job.partition_id); + int64_t partition_size = partition_manager_->partition_store_->list_size(job.partition_id); // Branch for non-batched jobs. if (!job.is_batched) { - const float *query_vector = job.query_vector; - if (!query_vector) { - throw std::runtime_error("[partition_scan_worker_fn] query_vector is null."); + + // Allocate a thread-local query buffer if needed. + if (res.local_query_buffer.size() < partition_manager_->d() * sizeof(float)) { + res.local_query_buffer.resize(partition_manager_->d() * sizeof(float)); + } + + // Copy the contents of the query vector to the local buffer using memcpy. + if (memcpy(res.local_query_buffer.data(), job.query_vector, partition_manager_->d() * sizeof(float)) == nullptr) { + throw std::runtime_error("[partition_scan_worker_fn] memcpy failed."); } + if (local_topk_buffer == nullptr) { - local_topk_buffer = std::make_shared(job.k, metric_ == faiss::METRIC_INNER_PRODUCT); + throw std::runtime_error("[partition_scan_worker_fn] local_topk_buffer is null."); } else { local_topk_buffer->set_k(job.k); local_topk_buffer->reset(); } // Perform the scan on the partition. - scan_list(query_vector, partition_codes, partition_ids, - partition_size, partition_manager_->d(), *local_topk_buffer, metric_); + scan_list((float *) res.local_query_buffer.data(), + partition_codes, + partition_ids, + partition_size, + partition_manager_->d(), + *local_topk_buffer, + metric_); + vector topk = local_topk_buffer->get_topk(); vector topk_indices = local_topk_buffer->get_topk_indices(); int64_t n_results = topk_indices.size(); // Merge local results into the global query buffer. - { - std::lock_guard lock(result_mutex_); - if (query_topk_buffers_[job.query_ids[0]]) { - query_topk_buffers_[job.query_ids[0]]->batch_add(topk.data(), topk_indices.data(), n_results); - } - jobs_.erase(job_id); - } + global_topk_buffer_pool_[job.query_ids[0]]->batch_add(topk.data(), topk_indices.data(), n_results); + job_flags_[job.query_ids[0]][job.rank] = true; } // Batched job branch. else { @@ -180,36 +182,64 @@ void QueryCoordinator::partition_scan_worker_fn(int worker_id) { throw std::runtime_error("[partition_scan_worker_fn] Invalid batched job."); } + // Allocate a thread-local query buffer if needed. + if (res.local_query_buffer.size() < partition_manager_->d() * sizeof(float) * job.num_queries) { + res.local_query_buffer.resize(partition_manager_->d() * sizeof(float) * job.num_queries); + } + + int64_t d = partition_manager_->d(); + std::vector query_subset(job.num_queries * d); + for (int i = 0; i < job.num_queries; i++) { + int64_t global_q = job.query_ids[i]; + memcpy(&query_subset[i * d], + job.query_vector + global_q * d, + d * sizeof(float)); + } + // Then copy query_subset into your local buffer if needed: + if(memcpy(res.local_query_buffer.data(), + query_subset.data(), + query_subset.size() * sizeof(float)) == nullptr) { + throw std::runtime_error("[partition_scan_worker_fn] memcpy failed."); + } + // Use a thread_local vector to hold one buffer per query. - if (local_buffers.size() < static_cast(job.num_queries)) { - local_buffers.resize(job.num_queries); + if (res.topk_buffer_pool.size() < static_cast(job.num_queries)) { + res.topk_buffer_pool.resize(job.num_queries); for (int64_t q = 0; q < job.num_queries; ++q) { - local_buffers[q] = std::make_shared(job.k, metric_ == faiss::METRIC_INNER_PRODUCT); + res.topk_buffer_pool[q] = std::make_shared(job.k, metric_ == faiss::METRIC_INNER_PRODUCT); } } else { for (int64_t q = 0; q < job.num_queries; ++q) { - local_buffers[q]->set_k(job.k); - local_buffers[q]->reset(); + res.topk_buffer_pool[q]->set_k(job.k); + res.topk_buffer_pool[q]->reset(); } } + // Process the batched job. - batched_scan_list(job.query_vector, partition_codes, partition_ids, - job.num_queries, partition_size, - partition_manager_->d(), local_buffers, metric_); - { - std::lock_guard lock(result_mutex_); - for (int64_t q = 0; q < job.num_queries; q++) { - int64_t global_q = job.query_ids[q]; - vector topk = local_buffers[q]->get_topk(); - vector topk_indices = local_buffers[q]->get_topk_indices(); - int n_results = topk_indices.size(); - if (query_topk_buffers_[global_q]) { - query_topk_buffers_[global_q]->batch_add(topk.data(), topk_indices.data(), n_results); - } - } - jobs_.erase(job_id); + batched_scan_list((float *) res.local_query_buffer.data(), + partition_codes, + partition_ids, + job.num_queries, + partition_size, + partition_manager_->d(), + res.topk_buffer_pool, + metric_); + + vector> topk_list(job.num_queries); + vector> topk_indices_list(job.num_queries); + for (int64_t q = 0; q < job.num_queries; q++) { + topk_list[q] = res.topk_buffer_pool[q]->get_topk(); + topk_indices_list[q] = res.topk_buffer_pool[q]->get_topk_indices(); + } + + for (int64_t q = 0; q < job.num_queries; q++) { + int64_t global_q = job.query_ids[q]; + int n_results = topk_indices_list[q].size(); + global_topk_buffer_pool_[global_q]->batch_add(topk_list[q].data(), topk_indices_list[q].data(), n_results); } } + auto job_process_end = std::chrono::high_resolution_clock::now(); + job_process_time_ns += std::chrono::duration_cast(job_process_end - job_process_start).count(); } } @@ -222,7 +252,6 @@ shared_ptr QueryCoordinator::worker_scan( throw std::runtime_error("[QueryCoordinator::worker_scan] partition_manager_ is null."); } - // Handle trivial case: no input queries if (!x.defined() || x.size(0) == 0) { auto empty_result = std::make_shared(); empty_result->ids = torch::empty({0}, torch::kInt64); @@ -231,252 +260,219 @@ shared_ptr QueryCoordinator::worker_scan( return empty_result; } - - // Basic parameters int64_t num_queries = x.size(0); int64_t dimension = x.size(1); int k = search_params->k; int64_t nlist = partition_manager_->nlist(); - bool use_aps = search_params->recall_target > 0.0 && !search_params->batched_scan; - shared_ptr timing_info = make_shared(); + bool use_aps = (search_params->recall_target > 0.0 && !search_params->batched_scan); + auto timing_info = make_shared(); timing_info->n_queries = num_queries; timing_info->n_clusters = nlist; timing_info->search_params = search_params; - auto start_time = std::chrono::high_resolution_clock::now(); - // Initialize a Top-K buffer for each query using a global buffer pool. + float *x_ptr = x.data_ptr(); + + auto start_time = high_resolution_clock::now(); + + if (partition_ids.dim() == 1) { + partition_ids = partition_ids.unsqueeze(0).expand({num_queries, partition_ids.size(0)}); + } + + job_flags_.clear(); + job_flags_.resize(num_queries); + for (int64_t q = 0; q < num_queries; q++) { + job_flags_[q] = vector>(partition_ids.size(1)); + for (int64_t p = 0; p < partition_ids.size(1); p++) { + job_flags_[q][p] = false; + } + } + + job_pull_time_ns = 0; + job_process_time_ns = 0; + { - std::lock_guard lock(result_mutex_); - // If our global buffer pool is smaller than needed, enlarge it. - if (query_topk_buffers_.size() < static_cast(num_queries)) { - int old_size = query_topk_buffers_.size(); - query_topk_buffers_.resize(num_queries); + std::lock_guard lock(global_mutex_); + if (global_topk_buffer_pool_.size() < static_cast(num_queries)) { + std::cout << "Resizing query_topk_buffers_ from " << global_topk_buffer_pool_.size() + << " to " << num_queries << std::endl; + int old_size = global_topk_buffer_pool_.size(); + global_topk_buffer_pool_.resize(num_queries); for (int64_t q = old_size; q < num_queries; q++) { - query_topk_buffers_[q] = std::make_shared(k, metric_ == faiss::METRIC_INNER_PRODUCT); + global_topk_buffer_pool_[q] = std::make_shared(k, metric_ == faiss::METRIC_INNER_PRODUCT); } } else { - // Otherwise, reset and update k for each existing buffer. for (int64_t q = 0; q < num_queries; q++) { - query_topk_buffers_[q]->set_k(k); - query_topk_buffers_[q]->reset(); + global_topk_buffer_pool_[q]->set_k(k); + global_topk_buffer_pool_[q]->reset(); + global_topk_buffer_pool_[q]->set_processing_query(true); } } - // Set the job count per query based on partition_ids shape. - if (partition_ids.dim() == 1) { - for (int64_t q = 0; q < num_queries; q++) { - query_topk_buffers_[q]->set_jobs_left(partition_ids.size(0)); - } - } else { - for (int64_t q = 0; q < num_queries; q++) { - query_topk_buffers_[q]->set_jobs_left(partition_ids.size(1)); - } + for (int64_t q = 0; q < num_queries; q++) { + global_topk_buffer_pool_[q]->set_jobs_left(partition_ids.size(1)); } } - auto end_time = std::chrono::high_resolution_clock::now(); - timing_info->buffer_init_time_ns = std::chrono::duration_cast(end_time - start_time). - count(); + auto end_time = high_resolution_clock::now(); + timing_info->buffer_init_time_ns = + duration_cast(end_time - start_time).count(); - - start_time = std::chrono::high_resolution_clock::now(); - // ============================ - // 1) BATCHED-SCAN BRANCH - // ============================ + start_time = high_resolution_clock::now(); if (search_params->batched_scan) { - // Force partition_ids to shape [num_queries, num_partitions] - if (partition_ids.dim() == 1) { - partition_ids = partition_ids.unsqueeze(0).expand({num_queries, partition_ids.size(0)}); - } auto partition_ids_accessor = partition_ids.accessor(); - // Collect all unique partitions we need to scan - std::unordered_map > per_partition_query_ids; + std::unordered_map> per_partition_query_ids; for (int64_t q = 0; q < num_queries; q++) { for (int64_t p = 0; p < partition_ids.size(1); p++) { int64_t pid = partition_ids_accessor[q][p]; - if (pid < 0) { - continue; - } + if (pid < 0) continue; per_partition_query_ids[pid].push_back(q); } } - - vector unique_partitions; - for (const auto &kv: per_partition_query_ids) { - unique_partitions.push_back(kv.first); - } - - // Enqueue exactly one job per unique partition - int job_counter = 0; - for (auto partition_id: unique_partitions) { - int job_id = job_counter++; - - int nq = per_partition_query_ids[partition_id].size(); - - ScanJob scan_job; - scan_job.is_batched = true; // <=== Key difference - scan_job.partition_id = partition_id; - scan_job.k = k; - scan_job.query_vector = x.data_ptr(); - scan_job.num_queries = nq; - scan_job.query_ids = per_partition_query_ids[partition_id]; { - std::lock_guard lock(result_mutex_); - jobs_[job_id] = scan_job; - } - - // Assign a worker: e.g., partition_id mod num_workers_ - int worker_id = partition_id % num_workers_; - jobs_queue_[worker_id].enqueue(job_id); - } - } - - // ============================ - // 2) SINGLE-QUERY SCAN BRANCH - // ============================ - else { - // If shape is 1D, we expand so that each query sees the same partitions - if (partition_ids.dim() == 1) { - partition_ids = partition_ids.unsqueeze(0).expand({num_queries, partition_ids.size(0)}); + for (auto &kv : per_partition_query_ids) { + ScanJob job; + job.is_batched = true; + job.partition_id = kv.first; + job.k = k; + job.query_vector = x.data_ptr(); + job.num_queries = kv.second.size(); + job.query_ids = kv.second; + int core_id = partition_manager_->get_partition_core_id(kv.first); + core_resources_[core_id].job_queue.enqueue(job); } + } else { auto partition_ids_accessor = partition_ids.accessor(); - // Create a job for each (query, partition) - vector > worker_ids_per_job_id; - for (int64_t q = 0; q < num_queries; q++) { + int64_t start = 0; + int64_t end = num_queries; + parallel_for(start, end, [&](int64_t q) { for (int64_t p = 0; p < partition_ids.size(1); p++) { - int64_t partition_id = partition_ids_accessor[q][p]; - if (partition_id == -1) { - continue; // skip invalid - } - - // Generate a unique job_id - // For big nlist, watch for overflow - int job_id = static_cast(q * nlist + partition_id); - - ScanJob scan_job; - scan_job.is_batched = false; - scan_job.query_ids = {q}; - scan_job.partition_id = partition_id; - scan_job.k = k; - scan_job.query_vector = x[q].data_ptr(); - scan_job.num_queries = 1; { - std::lock_guard lock(result_mutex_); - jobs_[job_id] = scan_job; - } - - int worker_id = partition_id % num_workers_; - worker_ids_per_job_id.push_back({worker_id, job_id}); + int64_t pid = partition_ids_accessor[q][p]; + if (pid == -1) continue; + + ScanJob job; + job.is_batched = false; + job.query_ids = {q}; + job.partition_id = pid; + job.k = k; + job.query_vector = x_ptr + q * dimension; + job.num_queries = 1; + job.rank = p; + + int core_id = partition_manager_->get_partition_core_id(pid); + core_resources_[core_id].job_queue.enqueue(job); } - } - for (const auto &pair: worker_ids_per_job_id) { - jobs_queue_[pair.first].enqueue(pair.second); - } + }, search_params->num_threads); } - end_time = std::chrono::high_resolution_clock::now(); - timing_info->job_enqueue_time_ns = std::chrono::duration_cast(end_time - start_time). - count(); - - auto last_flush_time = std::chrono::high_resolution_clock::now(); + end_time = high_resolution_clock::now(); + timing_info->job_enqueue_time_ns = duration_cast(end_time - start_time).count(); - // ============================ - // 3) WAIT FOR WORKERS - // ============================ - start_time = std::chrono::high_resolution_clock::now(); - vector boundary_distances = vector(num_queries); + auto last_flush_time = high_resolution_clock::now(); + vector> boundary_distances(num_queries); if (use_aps) { for (int64_t q = 0; q < num_queries; q++) { - Tensor cluster_centroids = parent_->get(partition_ids[q]); + vector partition_ids_to_scan_vec = vector(partition_ids[q].data_ptr(), + partition_ids[q].data_ptr() + partition_ids[q].size(0)); + vector cluster_centroids = parent_->partition_manager_->get_vectors(partition_ids_to_scan_vec); boundary_distances[q] = compute_boundary_distances(x[q], - cluster_centroids, - metric_ == faiss::METRIC_L2); + cluster_centroids, + metric_ == faiss::METRIC_L2); } } - end_time = std::chrono::high_resolution_clock::now(); - timing_info->boundary_distance_time_ns = std::chrono::duration_cast(end_time - start_time) - .count(); - start_time = std::chrono::high_resolution_clock::now(); + start_time = high_resolution_clock::now(); + last_flush_time = high_resolution_clock::now(); + + vector query_radius(num_queries, 0.0f); + vector> probs(num_queries); + while (true) { - { - int time_since_last_flush_us = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - last_flush_time).count(); - if (use_aps && time_since_last_flush_us > search_params->aps_flush_period_us) { - for (int64_t q = 0; q < num_queries; q++) { - shared_ptr curr_buffer = query_topk_buffers_[q]; - Tensor curr_boundary_distances = boundary_distances[q]; - - int partitions_scanned = curr_buffer->get_num_partitions_scanned(); - if (curr_buffer->currently_processing_query() && partitions_scanned > 0 && partitions_scanned < - curr_boundary_distances.size(0)) { - float query_radius = curr_buffer->get_kth_distance(); - Tensor partition_probabilities = compute_recall_profile(curr_boundary_distances, - query_radius, - dimension, - {}, - search_params->use_precomputed, - metric_ == faiss::METRIC_L2); - Tensor recall_profile = torch::cumsum(partition_probabilities, 0); - last_flush_time = std::chrono::high_resolution_clock::now(); - if (recall_profile[partitions_scanned - 1].item() > search_params->recall_target) { - curr_buffer->set_processing_query(false); - break; + // check if all jobs have been processed + bool all_done = true; + for (int64_t q = 0; q < num_queries; q++) { + all_done = all_done && (global_topk_buffer_pool_[q]->jobs_left_ == 0); + } + + if (all_done) { + break; + } + + // check if recall target has been reached + if (use_aps && duration_cast(high_resolution_clock::now() - last_flush_time).count() + > search_params->aps_flush_period_us) { + for (int64_t q = 0; q < num_queries; q++) { + auto curr_buffer = global_topk_buffer_pool_[q]; + int scanned = curr_buffer->get_num_partitions_scanned(); + if (curr_buffer->currently_processing_query() && + scanned > 0 && scanned < (int) boundary_distances[q].size()) { + float radius = curr_buffer->get_kth_distance(); + + if (query_radius[q] != radius) { + query_radius[q] = radius; + + // recompute recall profile if the radius has changed + probs[q] = compute_recall_profile(boundary_distances[q], + radius, + dimension, + {}, + search_params->use_precomputed, + metric_ == faiss::METRIC_L2); + } + + float cum = 0.0f; + for (int i = 0; i < partition_ids.size(1); i++) { + if (job_flags_[q][i]) { + cum += probs[q][i]; } } + if (cum > search_params->recall_target) { + curr_buffer->set_processing_query(false); + } } } - - std::lock_guard lock(result_mutex_); - if (jobs_.empty()) { - break; // all jobs done - } + last_flush_time = high_resolution_clock::now(); } - std::this_thread::sleep_for(std::chrono::microseconds(1)); + std::this_thread::sleep_for(microseconds(1)); } - end_time = std::chrono::high_resolution_clock::now(); - timing_info->job_wait_time_ns = std::chrono::duration_cast(end_time - start_time).count(); + end_time = high_resolution_clock::now(); + timing_info->job_wait_time_ns = + duration_cast(end_time - start_time).count(); - // ============================ - // 4) AGGREGATE RESULTS - // ============================ - start_time = std::chrono::high_resolution_clock::now(); + // Aggregate results. + start_time = high_resolution_clock::now(); auto topk_ids = torch::full({num_queries, k}, -1, torch::kInt64); - auto topk_distances = torch::full({num_queries, k}, - std::numeric_limits::infinity(), torch::kFloat32); - auto topk_ids_accessor = topk_ids.accessor(); - auto topk_distances_accessor = topk_distances.accessor(); { - std::lock_guard lock(result_mutex_); + auto topk_dists = torch::full({num_queries, k}, + std::numeric_limits::infinity(), torch::kFloat32); + auto ids_accessor = topk_ids.accessor(); + auto dists_accessor = topk_dists.accessor(); + { + std::lock_guard lock(global_mutex_); for (int64_t q = 0; q < num_queries; q++) { - auto topk = query_topk_buffers_[q]->get_topk(); - auto ids = query_topk_buffers_[q]->get_topk_indices(); - auto distances = topk; // same vector - + auto topk = global_topk_buffer_pool_[q]->get_topk(); + auto ids = global_topk_buffer_pool_[q]->get_topk_indices(); for (int i = 0; i < k; i++) { if (i < (int) ids.size()) { - topk_ids_accessor[q][i] = ids[i]; - topk_distances_accessor[q][i] = distances[i]; + ids_accessor[q][i] = ids[i]; + dists_accessor[q][i] = topk[i]; } else { - // if metric is inner product, fill with -infinity beyond topk - topk_ids_accessor[q][i] = -1; - topk_distances_accessor[q][i] = (metric_ == faiss::METRIC_INNER_PRODUCT) - ? -std::numeric_limits::infinity() - : std::numeric_limits::infinity(); + ids_accessor[q][i] = -1; + dists_accessor[q][i] = (metric_ == faiss::METRIC_INNER_PRODUCT) + ? -std::numeric_limits::infinity() + : std::numeric_limits::infinity(); } } - timing_info->partitions_scanned = query_topk_buffers_[q]->get_num_partitions_scanned(); + timing_info->partitions_scanned = global_topk_buffer_pool_[q]->get_num_partitions_scanned(); } } - end_time = std::chrono::high_resolution_clock::now(); - timing_info->result_aggregate_time_ns = std::chrono::duration_cast(end_time - start_time). - count(); - - // Final SearchResult + end_time = high_resolution_clock::now(); + timing_info->result_aggregate_time_ns = duration_cast(end_time - start_time).count(); auto search_result = std::make_shared(); search_result->ids = topk_ids; - search_result->distances = topk_distances; + search_result->distances = topk_dists; search_result->timing_info = timing_info; return search_result; } + bool* create_bitmap(std::unordered_map id_to_price, int64_t* list_ids, int64_t num_ids, shared_ptr search_params) { @@ -507,7 +503,7 @@ shared_ptr QueryCoordinator::serial_scan(Tensor x, Tensor partitio return empty_result; } - auto start_time = std::chrono::high_resolution_clock::now(); + auto start_time = high_resolution_clock::now(); int64_t num_queries = x.size(0); int64_t dimension = x.size(1); @@ -516,7 +512,7 @@ shared_ptr QueryCoordinator::serial_scan(Tensor x, Tensor partitio // Preallocate output tensors. auto ret_ids = torch::full({num_queries, k}, -1, torch::kInt64); auto ret_dists = torch::full({num_queries, k}, - std::numeric_limits::infinity(), torch::kFloat32); + std::numeric_limits::infinity(), torch::kFloat32); auto timing_info = std::make_shared(); timing_info->n_queries = num_queries; @@ -531,11 +527,11 @@ shared_ptr QueryCoordinator::serial_scan(Tensor x, Tensor partitio partition_ids_to_scan = partition_ids_to_scan.unsqueeze(0).expand({num_queries, partition_ids_to_scan.size(0)}); } auto partition_ids_accessor = partition_ids_to_scan.accessor(); - float* x_ptr = x.data_ptr(); + float *x_ptr = x.data_ptr(); // Allocate per-query result vectors. - std::vector> all_topk_dists(num_queries); - std::vector> all_topk_ids(num_queries); + vector> all_topk_dists(num_queries); + vector> all_topk_ids(num_queries); // Use our custom parallel_for to process queries in parallel. parallel_for(0, num_queries, [&](int64_t q) { @@ -544,25 +540,22 @@ shared_ptr QueryCoordinator::serial_scan(Tensor x, Tensor partitio const float* query_vec = x_ptr + q * dimension; int num_parts = partition_ids_to_scan.size(1); - Tensor boundary_distances; - Tensor partition_probabilities; + vector boundary_distances; + vector partition_probs; float query_radius = 1000000.0; if (metric_ == faiss::METRIC_INNER_PRODUCT) { query_radius = -1000000.0; } - Tensor sort_args = torch::arange(partition_ids_to_scan.size(1), torch::kInt64); Tensor partition_sizes = partition_manager_->get_partition_sizes(partition_ids_to_scan[q]); if (use_aps) { - Tensor cluster_centroids = parent_->get(partition_ids_to_scan[q]); + vector partition_ids_to_scan_vec = std::vector(partition_ids_to_scan[q].data_ptr(), + partition_ids_to_scan[q].data_ptr() + partition_ids_to_scan[q].size(0)); + vector cluster_centroids = parent_->partition_manager_->get_vectors(partition_ids_to_scan_vec); boundary_distances = compute_boundary_distances(x[q], cluster_centroids, metric_ == faiss::METRIC_L2); - - // sort order by boundary distance - sort_args = torch::argsort(boundary_distances, 0, false); } - auto sort_args_accessor = sort_args.accessor(); for (int p = 0; p < num_parts; p++) { int64_t pi = partition_ids_accessor[q][p]; @@ -570,12 +563,12 @@ shared_ptr QueryCoordinator::serial_scan(Tensor x, Tensor partitio continue; // Skip invalid partitions } - start_time = std::chrono::high_resolution_clock::now(); - float *list_vectors = (float *) partition_manager_->partitions_->get_codes(pi); - int64_t *list_ids = (int64_t *) partition_manager_->partitions_->get_ids(pi); + start_time = high_resolution_clock::now(); + float *list_vectors = (float *) partition_manager_->partition_store_->get_codes(pi); + int64_t *list_ids = (int64_t *) partition_manager_->partition_store_->get_ids(pi); std::shared_ptr partition_attributes_table = - partition_manager_->partitions_->partitions_[pi]->attributes_table_; - int64_t list_size = partition_manager_->partitions_->list_size(pi); + partition_manager_->partition_store_->partitions_[pi]->attributes_table_; + int64_t list_size = partition_manager_->partition_store_->list_size(pi); std::shared_ptr id_array = nullptr; std::shared_ptr price_array = nullptr; @@ -600,7 +593,7 @@ shared_ptr QueryCoordinator::serial_scan(Tensor x, Tensor partitio scan_list(query_vec, list_vectors, list_ids, - partition_manager_->partitions_->list_size(pi), + partition_manager_->partition_store_->list_size(pi), dimension, *topk_buf, metric_, @@ -619,18 +612,23 @@ shared_ptr QueryCoordinator::serial_scan(Tensor x, Tensor partitio float curr_radius = topk_buf->get_kth_distance(); float percent_change = abs(curr_radius - query_radius) / curr_radius; - start_time = std::chrono::high_resolution_clock::now(); + start_time = high_resolution_clock::now(); if (use_aps) { if (percent_change > search_params->recompute_threshold) { query_radius = curr_radius; - partition_probabilities = compute_recall_profile(boundary_distances, - query_radius, - dimension, - {}, - search_params->use_precomputed, - metric_ == faiss::METRIC_L2).cumsum(0); + + partition_probs = compute_recall_profile(boundary_distances, + query_radius, + dimension, + {}, + search_params->use_precomputed, + metric_ == faiss::METRIC_L2); } - if (partition_probabilities[p].item() >= search_params->recall_target) { + float recall_estimate = 0.0; + for (int i = 0; i < p; i++) { + recall_estimate += partition_probs[i]; + } + if (recall_estimate >= search_params->recall_target) { break; } } @@ -657,8 +655,8 @@ shared_ptr QueryCoordinator::serial_scan(Tensor x, Tensor partitio } } - auto end_time = std::chrono::high_resolution_clock::now(); - timing_info->total_time_ns = std::chrono::duration_cast(end_time - start_time).count(); + auto end_time = high_resolution_clock::now(); + timing_info->total_time_ns = duration_cast(end_time - start_time).count(); auto search_result = std::make_shared(); search_result->ids = ret_ids; @@ -674,7 +672,7 @@ shared_ptr QueryCoordinator::search(Tensor x, shared_ptr(); - auto start = std::chrono::high_resolution_clock::now(); + auto start = high_resolution_clock::now(); // if there is no parent, then the coordinator is operating on a flat index and we need to scan all partitions Tensor partition_ids_to_scan; @@ -695,21 +693,12 @@ shared_ptr QueryCoordinator::search(Tensor x, shared_ptrnlist() * search_params->initial_search_fraction), 1); parent_search_params->k = initial_num_partitions_to_search; } else { - parent_search_params->k = search_params->nprobe; + parent_search_params->k = std::min(search_params->nprobe, (int) partition_manager_->nlist()); } auto parent_search_result = parent_->search(x, parent_search_params); partition_ids_to_scan = parent_search_result->ids; parent_timing_info = parent_search_result->timing_info; - - // if (maintenance_policy_ != nullptr) { - // for (int i = 0; i < partition_ids_to_scan.size(0); i++) { - // vector hit_partition_ids_vec = vector(partition_ids_to_scan[i].data_ptr(), - // partition_ids_to_scan[i].data_ptr() + - // partition_ids_to_scan[i].size(0)); - // maintenance_policy_->increment_hit_count(hit_partition_ids_vec); - // } - // } } auto search_result = scan_partitions(x, partition_ids_to_scan, search_params); @@ -754,8 +743,8 @@ shared_ptr QueryCoordinator::search(Tensor x, shared_ptrtiming_info->parent_info = parent_timing_info; - auto end = std::chrono::high_resolution_clock::now(); - search_result->timing_info->total_time_ns = std::chrono::duration_cast(end - start). + auto end = high_resolution_clock::now(); + search_result->timing_info->total_time_ns = duration_cast(end - start). count(); return search_result; @@ -781,7 +770,6 @@ shared_ptr QueryCoordinator::batched_serial_scan( Tensor x, Tensor partition_ids, shared_ptr search_params) { - if (!partition_manager_) { throw std::runtime_error("[QueryCoordinator::batched_serial_scan] partition_manager_ is null."); } @@ -795,7 +783,7 @@ shared_ptr QueryCoordinator::batched_serial_scan( // Timing info (could be extended as needed) auto timing_info = std::make_shared(); - auto start = std::chrono::high_resolution_clock::now(); + auto start = high_resolution_clock::now(); int64_t num_queries = x.size(0); int k = (search_params && search_params->k > 0) ? search_params->k : 1; @@ -820,10 +808,15 @@ shared_ptr QueryCoordinator::batched_serial_scan( } } - // For each unique partition, process the corresponding batch of queries. - for (auto &entry : queries_by_partition) { - int64_t pid = entry.first; - vector query_indices = entry.second; + std::vector>> queries_vec; + queries_vec.reserve(queries_by_partition.size()); + for (const auto &entry : queries_by_partition) { + queries_vec.push_back(entry); + } + + parallel_for((int64_t) 0, (int64_t) queries_by_partition.size(), [&](int64_t i) { + int64_t pid = queries_vec[i].first; + auto query_indices = queries_vec[i].second; // Create a tensor for the indices and then a subset of the queries. Tensor indices_tensor = torch::tensor(query_indices, torch::kInt64); @@ -831,9 +824,9 @@ shared_ptr QueryCoordinator::batched_serial_scan( int64_t batch_size = x_subset.size(0); // Get the partition’s data. - const float *list_codes = (float *) partition_manager_->partitions_->get_codes(pid); - const int64_t *list_ids = partition_manager_->partitions_->get_ids(pid); - int64_t list_size = partition_manager_->partitions_->list_size(pid); + const float *list_codes = (float *) partition_manager_->partition_store_->get_codes(pid); + const int64_t *list_ids = partition_manager_->partition_store_->get_ids(pid); + int64_t list_size = partition_manager_->partition_store_->list_size(pid); int64_t d = partition_manager_->d(); // Create temporary Top-K buffers for this sub-batch. @@ -857,7 +850,9 @@ shared_ptr QueryCoordinator::batched_serial_scan( // Merge: global buffer adds the new candidate distances/ids. global_buffers[global_q]->batch_add(local_dists.data(), local_ids.data(), local_ids.size()); } - } + + + }, search_params->num_threads); // Aggregate the final results into output tensors. auto topk_ids = torch::full({num_queries, k}, -1, torch::kInt64); @@ -886,8 +881,8 @@ shared_ptr QueryCoordinator::batched_serial_scan( // Optionally record per-query partition scan counts here. } - auto end = std::chrono::high_resolution_clock::now(); - timing_info->total_time_ns = std::chrono::duration_cast(end - start).count(); + auto end = high_resolution_clock::now(); + timing_info->total_time_ns = duration_cast(end - start).count(); // Prepare and return the final search result. auto search_result = std::make_shared(); @@ -895,4 +890,4 @@ shared_ptr QueryCoordinator::batched_serial_scan( search_result->distances = topk_dists; search_result->timing_info = timing_info; return search_result; -} \ No newline at end of file +} diff --git a/src/python/__init__.py b/src/python/__init__.py index 3bab7521..90addaa6 100644 --- a/src/python/__init__.py +++ b/src/python/__init__.py @@ -1,13 +1,8 @@ try: - import torch + import torch # noqa: F401 - """ - Quake Python API - - This module provides the Python bindings for the Quake index - """ - from ._bindings import * + from ._bindings import * # noqa: F401 F403 except ModuleNotFoundError as e: print(e) - print("Bindings not installed") \ No newline at end of file + print("Bindings not installed") diff --git a/src/python/datasets/ann_datasets.py b/src/python/datasets/ann_datasets.py index 638ec957..89297392 100644 --- a/src/python/datasets/ann_datasets.py +++ b/src/python/datasets/ann_datasets.py @@ -6,10 +6,11 @@ import numpy as np import torch -from quake.utils import download_url, extract_file, fvecs_to_tensor, ivecs_to_tensor, to_path, to_torch +from quake.utils import download_url, extract_file, fvecs_to_tensor, ivecs_to_tensor, to_path DEFAULT_DOWNLOAD_DIR = Path("data/") + class Dataset(abc.ABC): url: str = None download_dir: Path = None @@ -72,7 +73,7 @@ def load_ground_truth(self) -> Union[np.ndarray, torch.Tensor]: def load_dataset( - name: str, download_dir: str = DEFAULT_DOWNLOAD_DIR, overwrite_download: bool = False + name: str, download_dir: str = DEFAULT_DOWNLOAD_DIR, overwrite_download: bool = False ) -> List[Union[np.ndarray, torch.Tensor]]: if name.lower() == "sift1m": dataset = Sift1m(download_dir=download_dir) @@ -80,4 +81,4 @@ def load_dataset( raise RuntimeError("Unimplemented dataset") dataset.download(overwrite=overwrite_download) - return dataset.load() \ No newline at end of file + return dataset.load() diff --git a/src/python/index_wrappers/diskann.py b/src/python/index_wrappers/diskann.py index d5f8b363..9b0d834f 100644 --- a/src/python/index_wrappers/diskann.py +++ b/src/python/index_wrappers/diskann.py @@ -1,14 +1,14 @@ # from pathlib import Path +import time from typing import Tuple import diskannpy as dap import numpy as np import torch -import time +from quake import SearchTimingInfo from quake.index_wrappers.wrapper import IndexWrapper from quake.utils import to_numpy, to_torch -from quake import SearchTimingInfo class DiskANNDynamic(IndexWrapper): @@ -97,7 +97,6 @@ def search( assert self.index is not None assert query.ndim == 2 - query = to_numpy(query) timing_info = SearchTimingInfo() diff --git a/src/python/index_wrappers/faiss_hnsw.py b/src/python/index_wrappers/faiss_hnsw.py index 75081010..9fc7dcc2 100644 --- a/src/python/index_wrappers/faiss_hnsw.py +++ b/src/python/index_wrappers/faiss_hnsw.py @@ -1,13 +1,13 @@ +import time from typing import Optional, Tuple, Union import faiss import torch -import time +from quake import SearchTimingInfo +from quake.index_wrappers.faiss_ivf import metric_str_to_faiss from quake.index_wrappers.wrapper import IndexWrapper from quake.utils import to_numpy, to_torch -from quake.index_wrappers.faiss_wrapper import faiss_metric_to_str, metric_str_to_faiss -from quake import SearchTimingInfo class FaissHNSW(IndexWrapper): @@ -36,7 +36,14 @@ def d(self) -> int: """ return self.index.d - def build(self, vectors: torch.Tensor, m: int = 32, ef_construction: int = 40, metric: str = "l2", ids: Optional[torch.Tensor] = None): + def build( + self, + vectors: torch.Tensor, + m: int = 32, + ef_construction: int = 40, + metric: str = "l2", + ids: Optional[torch.Tensor] = None, + ): """ Build the index with the given vectors and arguments. @@ -82,7 +89,6 @@ def search(self, query: torch.Tensor, k: int, ef_search: int = 16) -> Tuple[torc distances = to_torch(distances) indices = to_torch(indices) - return indices, distances, timing_info def save(self, filename: str): diff --git a/src/python/index_wrappers/faiss_ivf.py b/src/python/index_wrappers/faiss_ivf.py index 76813fac..4326cfac 100644 --- a/src/python/index_wrappers/faiss_ivf.py +++ b/src/python/index_wrappers/faiss_ivf.py @@ -1,13 +1,13 @@ +import time from enum import Enum from typing import Optional, Tuple, Union import faiss import torch -import time +from quake import SearchResult, SearchTimingInfo from quake.index_wrappers.wrapper import IndexWrapper from quake.utils import to_torch -from quake import SearchTimingInfo, SearchResult def metric_str_to_faiss(metric: str) -> int: @@ -91,20 +91,20 @@ def index_state(self) -> dict: return { "n_list": self.index.nlist, "n_total": self.index.ntotal, - "metric": faiss_metric_to_str(self.index.metric_type) + "metric": faiss_metric_to_str(self.index.metric_type), } def maintenance(self): return def build( - self, - vectors: torch.Tensor, - nc: int, - m: int = 0, - b: int = 0, - metric: str = "l2", - ids: Optional[torch.Tensor] = None, + self, + vectors: torch.Tensor, + nc: int, + m: int = 0, + b: int = 0, + metric: str = "l2", + ids: Optional[torch.Tensor] = None, ): """ Build the index with the given vectors and arguments. @@ -196,7 +196,9 @@ def remove(self, ids: torch.Tensor): return None - def search(self, query: torch.Tensor, k: int, nprobe: int = 1, rf: int = 1, batched_scan: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + def search( + self, query: torch.Tensor, k: int, nprobe: int = 1, rf: int = 1, batched_scan: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Find the k-nearest neighbors of the query vectors. diff --git a/src/python/index_wrappers/quake.py b/src/python/index_wrappers/quake.py index 65851146..6c1fffb8 100644 --- a/src/python/index_wrappers/quake.py +++ b/src/python/index_wrappers/quake.py @@ -1,9 +1,9 @@ -from typing import Optional, Tuple, Union, List +from typing import Optional, Tuple, Union -import quake import torch -from quake import QuakeIndex +import quake +from quake import QuakeIndex from quake.index_wrappers.wrapper import IndexWrapper @@ -47,8 +47,16 @@ def index_state(self) -> dict: "n_total": self.index.ntotal(), } - def build(self, vectors: torch.Tensor, nc: int, metric: str = "l2", ids: Optional[torch.Tensor] = None, - n_workers: int = 0, m: int = -1, code_size: int = 8): + def build( + self, + vectors: torch.Tensor, + nc: int, + metric: str = "l2", + ids: Optional[torch.Tensor] = None, + num_workers: int = 0, + m: int = -1, + code_size: int = 8, + ): """ Build the index with the given vectors and arguments. @@ -63,11 +71,13 @@ def build(self, vectors: torch.Tensor, nc: int, metric: str = "l2", ids: Optiona vec_dim = vectors.shape[1] metric = metric.lower() print( - f"Building index with {vectors.shape[0]} vectors of dimension {vec_dim} and {nc} centroids, with metric {metric}.") + f"Building index with {vectors.shape[0]} vectors of dimension {vec_dim} " + f"and {nc} centroids, with metric {metric}." + ) build_params = quake.IndexBuildParams() build_params.metric = metric build_params.nlist = nc - build_params.num_workers = n_workers + build_params.num_workers = num_workers self.index = QuakeIndex() @@ -102,8 +112,20 @@ def remove(self, ids: torch.Tensor): assert ids.ndim == 1 return self.index.remove(ids) - def search(self, query: torch.Tensor, k: int, nprobe: int = 1, batched_scan = False, recall_target: float = -1, k_factor=4.0, use_precomputed = True) -> Tuple[ - torch.Tensor, torch.Tensor]: + def search( + self, + query: torch.Tensor, + k: int, + nprobe: int = 1, + batched_scan=False, + recall_target: float = -1, + k_factor=4.0, + use_precomputed=True, + initial_search_fraction=0.05, + recompute_threshold=0.1, + aps_flush_period_us=50, + n_threads=1, + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Find the k-nearest neighbors of the query vectors. @@ -118,7 +140,12 @@ def search(self, query: torch.Tensor, k: int, nprobe: int = 1, batched_scan = Fa search_params.recall_target = recall_target search_params.use_precomputed = use_precomputed search_params.batched_scan = batched_scan + search_params.initial_search_fraction = initial_search_fraction + search_params.recompute_threshold = recompute_threshold + search_params.aps_flush_period_us = aps_flush_period_us search_params.k = k + search_params.num_threads = n_threads + return self.index.search(query, search_params) def maintenance(self): @@ -136,17 +163,28 @@ def save(self, filename: str): """ self.index.save(str(filename)) - def load(self, filename: str, n_workers: int = 1, use_numa: bool = False, verbose: bool = False, - verify_numa: bool = False, same_core: bool = True, use_centroid_workers: bool = False, use_adaptive_n_probe : bool = False): + def load( + self, + filename: str, + n_workers: int = 0, + use_numa: bool = False, + verbose: bool = False, + verify_numa: bool = False, + same_core: bool = True, + use_centroid_workers: bool = False, + use_adaptive_n_probe: bool = False, + ): """ Load the index from a file. :param filename: The name of the file to load the index from. """ print( - f"Loading index from {filename}, with {n_workers} workers, use_numa={use_numa}, verbose={verbose}, verify_numa={verify_numa}, same_core={same_core}, use_centroid_workers={use_centroid_workers}") + f"Loading index from {filename}, with {n_workers} workers, use_numa={use_numa}, verbose={verbose}, " + f"verify_numa={verify_numa}, same_core={same_core}, use_centroid_workers={use_centroid_workers}" + ) self.index = QuakeIndex() - self.index.load(str(filename), True) + self.index.load(str(filename), n_workers) def centroids(self) -> torch.Tensor: """ @@ -154,7 +192,8 @@ def centroids(self) -> torch.Tensor: :return: The centroids of the index """ - return self.index.centroids() + centroid_ids = self.index.parent.get_ids() + return self.index.parent.get(centroid_ids) def cluster_ids(self) -> torch.Tensor: """ @@ -162,7 +201,7 @@ def cluster_ids(self) -> torch.Tensor: :return: The cluster ids of the index """ - return self.index.parent.get_ids() + return self.index.cluster_assignments() def metric(self) -> str: """ diff --git a/src/python/index_wrappers/scann.py b/src/python/index_wrappers/scann.py index 74a4498f..2b224026 100644 --- a/src/python/index_wrappers/scann.py +++ b/src/python/index_wrappers/scann.py @@ -7,7 +7,8 @@ import torch from quake.index_wrappers.wrapper import IndexWrapper -from quake.utils import to_numpy, to_path, to_torch +from quake.utils import to_numpy, to_torch + class Scann(IndexWrapper): index: scann.scann_ops_pybind.ScannSearcher @@ -98,7 +99,7 @@ def build( num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=training_sample_size, - incremental_threshold=1000 + incremental_threshold=1000, ) .score_brute_force() .build(docids=ids) diff --git a/src/python/index_wrappers/vamana.py b/src/python/index_wrappers/vamana.py index 69698fc0..fde8877e 100644 --- a/src/python/index_wrappers/vamana.py +++ b/src/python/index_wrappers/vamana.py @@ -26,11 +26,11 @@ def build( ids: Optional[torch.Tensor] = None, metric: str = "l2", num_threads: int = 1, - alpha: float = .95, + alpha: float = 0.95, graph_max_degree: int = 128, window_size: int = 128, max_candidate_pool_size: int = 128, - prune_to: int = 128 + prune_to: int = 128, ): parameters = svs.VamanaBuildParameters( alpha=alpha, @@ -43,7 +43,7 @@ def build( distance = svs.DistanceType.L2 elif metric == "ip": distance = svs.DistanceType.MIP - + if ids is not None: ids = to_numpy(ids).astype(np.uint64) else: @@ -93,7 +93,7 @@ def load(self, directory: str, metrics: str = "ip"): self.index = svs.DynamicVamana( config_path=str(config_dir), graph_loader=graph_loader, data_loader=data_loader, distance=distance ) - + self.index.search_window_size = 100 self.index.num_threads = 16 diff --git a/src/python/index_wrappers/wrapper.py b/src/python/index_wrappers/wrapper.py index 906848c9..f7c047bb 100644 --- a/src/python/index_wrappers/wrapper.py +++ b/src/python/index_wrappers/wrapper.py @@ -4,12 +4,13 @@ import torch + def get_index_class(index_name): - if index_name == 'Quake': + if index_name == "Quake": from quake.index_wrappers.quake import QuakeWrapper as IndexClass - elif index_name == 'HNSW': + elif index_name == "HNSW": from quake.index_wrappers.faiss_hnsw import FaissHNSW as IndexClass - elif index_name == 'IVF': + elif index_name == "IVF": from quake.index_wrappers.faiss_ivf import FaissIVF as IndexClass elif index_name == "DiskANN": from quake.index_wrappers.diskann import DiskANNDynamic as IndexClass @@ -17,10 +18,12 @@ def get_index_class(index_name): raise ValueError(f"Unknown index type: {index_name}") return IndexClass + class IndexWrapper(abc.ABC): """ Wrapper interface of various index implementations (faiss, leviathan, etc.) """ + @abstractmethod def build(self, vectors: torch.Tensor, *args, ids: Optional[torch.Tensor] = None): """Build the index with the provided build arguments""" @@ -74,4 +77,4 @@ def d(self) -> int: @abstractmethod def index_state(self) -> dict: """Return the state of the index""" - raise NotImplementedError("Subclasses must implement index_state method") \ No newline at end of file + raise NotImplementedError("Subclasses must implement index_state method") diff --git a/src/python/utils.py b/src/python/utils.py index f010615c..36576b4f 100644 --- a/src/python/utils.py +++ b/src/python/utils.py @@ -2,16 +2,16 @@ import shutil import tarfile import zipfile +from pathlib import Path +from typing import Tuple, Union from urllib.parse import urlparse from urllib.request import urlretrieve from zipfile import ZipFile -from pathlib import Path -from typing import Union, Tuple - import numpy as np import torch + def to_path(path: Union[str, Path]) -> Path: """ Convert a string to a Path object. @@ -29,7 +29,7 @@ def to_path(path: Union[str, Path]) -> Path: def to_torch(tensor: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: """ Convert a numpy array to a torch tensor. - :param tensor: input tensor. Can be a numpy array or a torch tensor. If it is a torch tensor, it will be returned as is. + :param tensor: input tensor. Can be a numpy array or a torch tensor. If a torch tensor, it will be returned as is. :return: torch tensor. """ if isinstance(tensor, np.ndarray): @@ -43,7 +43,7 @@ def to_torch(tensor: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: def to_numpy(tensor: Union[np.ndarray, torch.Tensor]) -> np.ndarray: """ Convert a torch tensor to a numpy array. - :param tensor: input tensor. Can be a numpy array or a torch tensor. If it is a numpy array, it will be returned as is. + :param tensor: input tensor. Can be a numpy array or a torch tensor. If a numpy array, it will be returned as is. :return: numpy array. """ if isinstance(tensor, torch.Tensor): @@ -158,6 +158,7 @@ def ibin_to_tensor(filename, header_size=0): numpy_array = np.fromfile(filename, dtype=np.int32, offset=8).reshape(2 * n, d)[:n] return torch.from_numpy(numpy_array) + def compute_recall(ids: torch.Tensor, gt_ids: torch.Tensor, k: int) -> torch.Tensor: ids = to_torch(ids) gt_ids = to_torch(gt_ids) @@ -191,7 +192,7 @@ def compute_distance(x: torch.Tensor, y: torch.Tensor, metric: str = "l2") -> to def knn( - queries: torch.Tensor, vectors: torch.Tensor, k: int = 1, metric: str = "l2" + queries: torch.Tensor, vectors: torch.Tensor, k: int = 1, metric: str = "l2" ) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute the k-nearest neighbors of the queries in the vectors. @@ -215,7 +216,6 @@ def knn( assert queries.size(1) == vectors.size(1) num_queries = queries.size(0) - num_vectors = vectors.size(0) distances = compute_distance(queries, vectors, metric) @@ -226,4 +226,4 @@ def knn( topk = torch.topk(distances, k, largest=is_metric_descending(metric)) indices, values = topk.indices, topk.values - return indices, values \ No newline at end of file + return indices, values diff --git a/src/python/workload_generator.py b/src/python/workload_generator.py index d8f467ba..3e90c814 100644 --- a/src/python/workload_generator.py +++ b/src/python/workload_generator.py @@ -4,24 +4,21 @@ from pathlib import Path from typing import Optional, Union +import matplotlib.pyplot as plt import numpy as np import torch -from quake.utils import compute_recall, knn, to_path +from quake import SearchParams from quake.index_wrappers.quake import QuakeWrapper -from quake import MaintenancePolicyParams -import hashlib -import matplotlib.pyplot as plt +from quake.utils import compute_recall, knn, to_path + def run_query( - index, - queries: torch.Tensor, - search_k: int, - search_params: dict, - gt_ids: torch.Tensor, - single_query: bool = True, - tune_nprobe: bool = False, - nprobes: Optional[torch.Tensor] = None, + index, + queries: torch.Tensor, + search_k: int, + search_params: dict, + gt_ids: torch.Tensor, ): """ Run queries on the index and compute the recall. @@ -42,6 +39,7 @@ def run_query( class VectorSampler(ABC): """Abstract class for sampling vectors.""" + @abstractmethod def sample(self, size: int): """Sample vectors for an operation.""" @@ -49,6 +47,7 @@ def sample(self, size: int): class UniformSampler(VectorSampler): """Uniformly sample vectors.""" + def __init__(self): pass @@ -65,6 +64,7 @@ class StratifiedClusterSampler(VectorSampler): This sampler uses cluster assignments and centroid distances to sample in a stratified fashion. """ + def __init__(self, assignments: torch.Tensor, centroids: torch.Tensor): self.assignments = assignments self.centroids = centroids @@ -74,34 +74,54 @@ def __init__(self, assignments: torch.Tensor, centroids: torch.Tensor): self.update_ranks(self.root_cluster) def update_ranks(self, root_cluster: int): + print("Updating cluster ranks: root cluster", root_cluster) self.root_cluster = root_cluster nearest_cluster_ids, _ = knn(self.centroids[root_cluster], self.centroids, -1, "l2") self.cluster_ranks = nearest_cluster_ids.flatten() def sample(self, sample_pool: torch.Tensor, size: int, update_ranks: bool = True): + # Get the cluster assignments for all indices in the sample pool. sample_assignments = self.assignments[sample_pool] - non_empty_clusters = torch.unique(sample_assignments) - mask = torch.zeros(self.cluster_size, dtype=torch.bool) - mask[non_empty_clusters] = True - cluster_order = self.cluster_ranks[mask[self.cluster_ranks]] - cluster_samples = [] - curr_sample_size = 0 - for cluster_id in cluster_order: - mask = sample_assignments == cluster_id - cluster_sample_pool = sample_pool[mask] - if cluster_sample_pool.shape[0] == 0: + + # Identify which clusters are present in the sample pool. + present_clusters = set(sample_assignments.tolist()) + + # Filter self.cluster_ranks to only include clusters that are present. + cluster_order = [c for c in self.cluster_ranks.tolist() if c in present_clusters] + + sampled_indices = [] + num_collected = 0 + + # Loop over clusters in the order defined by the filtered ranks. + for cluster in cluster_order: + print("Sampling from cluster", cluster) + # Find indices in sample_pool that belong to this cluster. + cluster_mask = (sample_assignments == cluster).nonzero(as_tuple=True)[0] + if cluster_mask.numel() == 0: continue - cluster_sample_size = min(size - curr_sample_size, cluster_sample_pool.shape[0]) - cluster_sample = cluster_sample_pool[torch.randperm(cluster_sample_pool.shape[0])[:cluster_sample_size]] - cluster_samples.append(cluster_sample) - curr_sample_size += cluster_sample_size - if curr_sample_size >= size: + + # Determine how many samples to draw from this cluster. + n_to_sample = min(size - num_collected, cluster_mask.numel()) + + # Randomly sample from the indices in this cluster. + perm = torch.randperm(cluster_mask.numel()) + chosen = cluster_mask[perm[:n_to_sample]] + sampled_indices.append(sample_pool[chosen]) + + num_collected += n_to_sample + if num_collected >= size: break - sample_ids = torch.cat(cluster_samples) - if update_ranks: - self.update_ranks(cluster_id) - sample_ids = torch.unique(sample_ids) - return sample_ids + + # Concatenate the sampled indices. + result = torch.cat(sampled_indices) if sampled_indices else torch.tensor([], dtype=torch.long) + + # Optionally update ranks with the last cluster that contributed samples. + if update_ranks and cluster_order: + self.update_ranks(cluster_order[1]) + + # Remove duplicates, if any. + result = torch.unique(result) + return result class DynamicWorkloadGenerator: @@ -114,25 +134,26 @@ class DynamicWorkloadGenerator: 3. Generate operations (insert, delete, query) according to given ratios. 4. Save each operation and a runbook that includes a summary. """ + def __init__( - self, - workload_dir: Union[str, Path], - base_vectors: np.ndarray, - metric: str, - insert_ratio: float, - delete_ratio: float, - query_ratio: float, - update_batch_size: int, - query_batch_size: int, - number_of_operations: int, - initial_size: int, - cluster_size: int, - cluster_sample_distribution: str, - queries: np.ndarray, - query_cluster_sample_distribution: str = "uniform", - seed: int = 1738, - initial_clustering_path: Optional[Union[str, Path]] = None, - overwrite: bool = False, + self, + workload_dir: Union[str, Path], + base_vectors: np.ndarray, + metric: str, + insert_ratio: float, + delete_ratio: float, + query_ratio: float, + update_batch_size: int, + query_batch_size: int, + number_of_operations: int, + initial_size: int, + cluster_size: int, + cluster_sample_distribution: str, + queries: np.ndarray, + query_cluster_sample_distribution: str = "uniform", + seed: int = 1738, + initial_clustering_path: Optional[Union[str, Path]] = None, + overwrite: bool = False, ): # (Initialization code unchanged) self.workload_dir = to_path(workload_dir) @@ -165,6 +186,9 @@ def __init__( self.cluster_ranks = None self.sampler = None + self.resident_history = [] + self.query_history = [] + def workload_exists(self): return (self.workload_dir / "runbook.json").exists() @@ -179,7 +203,7 @@ def validate_parameters(self): assert self.number_of_operations > 0 assert self.initial_size > 0 assert self.cluster_size > 0 - assert self.cluster_sample_distribution in ["uniform", "skewed"] + assert self.cluster_sample_distribution in ["uniform", "skewed", "skewed_fixed"] def initialize_clustered_index(self): if self.initial_clustering_path is not None: @@ -193,13 +217,16 @@ def initialize_clustered_index(self): else: n_clusters = self.base_vectors.shape[0] // self.cluster_size index = QuakeWrapper() - index.build(self.base_vectors, - nc=n_clusters, - metric=self.metric, - ids=torch.arange(self.base_vectors.shape[0])) + index.build( + self.base_vectors, nc=n_clusters, metric=self.metric, ids=torch.arange(self.base_vectors.shape[0]) + ) index.save(str(self.workload_dir / "clustered_index.bin")) - self.assignments = index.cluster_ids() - print("Cluster assignments shape:", self.assignments.shape) + + search_params = SearchParams() + search_params.k = 1 + search_params.batched_scan = True + self.assignments = index.index.parent.search(self.base_vectors, search_params).ids.flatten() + return index def sample(self, size: int, operation_type: str): @@ -208,7 +235,9 @@ def sample(self, size: int, operation_type: str): elif operation_type == "delete": sample_pool = self.all_ids[self.resident_set] elif operation_type == "query": - sample_pool = torch.arange(self.queries.shape[0]) if self.queries is not None else self.all_ids[~self.resident_set] + sample_pool = ( + torch.arange(self.queries.shape[0]) if self.queries is not None else self.all_ids[~self.resident_set] + ) else: raise ValueError(f"Invalid operation type {operation_type}.") if sample_pool.shape[0] == 0: @@ -216,12 +245,13 @@ def sample(self, size: int, operation_type: str): if operation_type in ["insert", "delete"]: sample_ids = self.sampler.sample(sample_pool, size) else: - update_ranks = (self.query_cluster_sample_distribution != "skewed_fixed") + # update_ranks = (self.query_cluster_sample_distribution != "skewed_fixed") + update_ranks = True sample_ids = self.query_sampler.sample(sample_pool, size, update_ranks=update_ranks) return sample_ids def initialize_workload(self): - if self.cluster_sample_distribution == "skewed": + if self.cluster_sample_distribution in ["skewed", "skewed_fixed"]: self.sampler = StratifiedClusterSampler(self.assignments, self.clustered_index.centroids()) elif self.cluster_sample_distribution == "uniform": self.sampler = UniformSampler() @@ -268,9 +298,14 @@ def generate_workload(self): self.initialize_workload() n_inserts = n_deletes = n_queries = 0 n_operations = 0 + + initial_uniques, initial_counts = torch.unique(self.assignments, return_counts=True) + all_sizes = torch.zeros(initial_uniques.shape[0]) + all_sizes[initial_uniques] = initial_counts.float() for i in range(self.number_of_operations): - operation_type = np.random.choice(["insert", "delete", "query"], - p=[self.insert_ratio, self.delete_ratio, self.query_ratio]) + operation_type = np.random.choice( + ["insert", "delete", "query"], p=[self.insert_ratio, self.delete_ratio, self.query_ratio] + ) if operation_type == "insert": sample_size = self.update_batch_size resident = True @@ -310,39 +345,68 @@ def generate_workload(self): torch.save(dists, self.operations_dir / f"{i}_gt_dists.pt") print("Operation", i, entry) self.runbook["operations"][i] = entry + + # Determine the number of clusters. Assuming clusters are labeled from 0 to max_cluster. + n_clusters = int(self.assignments.max().item()) + 1 + fractions = np.zeros(n_clusters) + + # get resident assignments + resident_assignments = self.assignments[self.resident_set] + + uniques, counts = torch.unique(resident_assignments, return_counts=True) + fractions[uniques] = counts.float() / all_sizes[uniques] + # Append the vector of fractions for this operation. + self.resident_history.append(fractions) + self.runbook["summary"] = { "n_inserts": n_inserts, "n_deletes": n_deletes, "n_queries": n_queries, "n_operations": n_operations, } + + # Convert the history to a NumPy array with shape (n_clusters, n_operations) + heatmap_array = np.array(self.resident_history).T + fig, ax = plt.subplots(figsize=(10, 6)) + cax = ax.imshow(heatmap_array, cmap="viridis", aspect="auto") + ax.set_xlabel("Operation Number") + ax.set_ylabel("Cluster ID") + cbar = fig.colorbar(cax) + cbar.set_label("Resident Fraction") + plt.tight_layout() + plt.savefig(self.workload_dir / "resident_history.png") + print("\nWorkload Generation Summary:") print(f"Total Operations: {n_operations}") print(f"Inserts: {n_inserts}, Deletes: {n_deletes}, Queries: {n_queries}") print(f"Final resident set size: {self.resident_set.sum().item()}") with open(self.workload_dir / "runbook.json", "w") as f: json.dump(self.runbook, f, indent=4) + # ---------------------------------------------------------------------------- class WorkloadEvaluator: """ Evaluates a generated workload on a given index and produces summary statistics and plots. """ + def __init__( - self, - workload_dir: Union[str, Path], - output_dir: Union[str, Path], - base_vectors_path: Optional[Union[str, Path]] = None, + self, + workload_dir: Union[str, Path], + output_dir: Union[str, Path], + base_vectors_path: Optional[Union[str, Path]] = None, ): self.workload_dir = to_path(workload_dir) self.output_dir = to_path(output_dir) self.runbook_path = self.workload_dir / "runbook.json" self.operations_dir = self.workload_dir / "operations" self.initial_indices_path = self.workload_dir / "initial_indices.pt" - self.base_vectors_path = to_path(base_vectors_path) if base_vectors_path else self.workload_dir / "base_vectors.pt" + self.base_vectors_path = ( + to_path(base_vectors_path) if base_vectors_path else self.workload_dir / "base_vectors.pt" + ) self.runbook = None - def initialize_index(self, name, index, build_params): + def initialize_index(self, name, index, build_params, m_params): index_dir = self.workload_dir / "init_indexes" index_dir.mkdir(parents=True, exist_ok=True) index_path = index_dir / f"{name}.index" @@ -354,26 +418,36 @@ def initialize_index(self, name, index, build_params): index.save(index_path) print(f"Index {name} built and saved to {index_path}") else: - index.load(index_path) + index.load(index_path, n_workers=build_params.get("num_workers", 0)) print(f"Index {name} loaded from {index_path}") + + if isinstance(index, QuakeWrapper) and m_params is not None: + index.index.initialize_maintenance_policy(m_params) + print(f"Maintenance policy initialized: {m_params}") + return index - def evaluate_workload(self, name, index, build_params, search_params, do_maintenance=False): + def evaluate_workload( + self, name, index, build_params, search_params, do_maintenance=False, m_params=None, batch=False + ): """ Evaluate the workload on the index. At the end a summary is printed and a multi-panel plot is saved. """ # validate search_params - assert 'k' in search_params, "search_params must contain 'k' for number of neighbors" + assert "k" in search_params, "search_params must contain 'k' for number of neighbors" # --- Load Workload and Index --- base_vectors = torch.load(self.base_vectors_path, weights_only=True).to(torch.float32) initial_indices = torch.load(self.initial_indices_path, weights_only=True).to(torch.int64) - index = self.initialize_index(name, index, build_params) + index = self.initialize_index(name, index, build_params, m_params) self.runbook = json.load(open(self.runbook_path, "r")) - query_vectors = (base_vectors if self.runbook["parameters"]["sample_queries"] - else torch.load(self.workload_dir / "query_vectors.pt", weights_only=True)) + query_vectors = ( + base_vectors + if self.runbook["parameters"]["sample_queries"] + else torch.load(self.workload_dir / "query_vectors.pt", weights_only=True) + ) query_vectors = query_vectors.to(torch.float32) start_time = time.time() init_time = time.time() - start_time @@ -402,19 +476,27 @@ def evaluate_workload(self, name, index, build_params, search_params, do_mainten mean_recall = None elif operation_type == "query": gt_ids = torch.load(self.operations_dir / f"{operation_id}_gt_ids.pt", weights_only=True) - gt_dist = torch.load(self.operations_dir / f"{operation_id}_gt_dists.pt", weights_only=True) + # gt_dist = torch.load(self.operations_dir / f"{operation_id}_gt_dists.pt", weights_only=True) queries = query_vectors[operation_ids] - Is, Ds, timing_infos = [], [], [] + start_time = time.time() - for query in queries: - query = query.unsqueeze(0) - search_result = index.search(query, **search_params) - Is.append(search_result.ids) - Ds.append(search_result.distances) - timing_infos.append(search_result.timing_info) + if batch: + search_result = index.search(queries, **search_params) + pred_ids = search_result.ids + timing_infos = [search_result.timing_info] + else: + Is, Ds, timing_infos = [], [], [] + + for query in queries: + query = query.unsqueeze(0) + search_result = index.search(query, **search_params) + Is.append(search_result.ids) + Ds.append(search_result.distances) + timing_infos.append(search_result.timing_info) + pred_ids = torch.cat(Is) op_time = time.time() - start_time - pred_ids = torch.cat(Is) - recalls = compute_recall(pred_ids, gt_ids, search_params['k']) + + recalls = compute_recall(pred_ids, gt_ids, search_params["k"]) mean_recall = recalls.mean().item() self.runbook["operations"][operation_id]["recall"] = mean_recall total_time = sum([ti.total_time_ns for ti in timing_infos]) @@ -422,16 +504,17 @@ def evaluate_workload(self, name, index, build_params, search_params, do_mainten print(f"Query Time: {mean_time:.2f} ns, Recall: {mean_recall:.2f}") if do_maintenance: + print("Running maintenance...") index.maintenance() n_resident = operation.get("n_resident", None) index_state = index.index_state() result = { - 'operation_number': operation_id_int, - 'operation_type': operation_type, - 'latency_ms': op_time * 1000, - 'recall': mean_recall, - 'n_resident': n_resident, + "operation_number": operation_id_int, + "operation_type": operation_type, + "latency_ms": op_time * 1000, + "recall": mean_recall, + "n_resident": n_resident, } result.update(index_state) result.update(search_params) @@ -439,15 +522,14 @@ def evaluate_workload(self, name, index, build_params, search_params, do_mainten # --- Print Evaluation Summary --- # Gather per-operation metrics - op_nums = [r['operation_number'] for r in results] - latencies_insert = [r['latency_ms'] for r in results if r['operation_type']=='insert'] - op_nums_insert = [r['operation_number'] for r in results if r['operation_type']=='insert'] - latencies_delete = [r['latency_ms'] for r in results if r['operation_type']=='delete'] - op_nums_delete = [r['operation_number'] for r in results if r['operation_type']=='delete'] - latencies_query = [r['latency_ms'] for r in results if r['operation_type']=='query'] - op_nums_query = [r['operation_number'] for r in results if r['operation_type']=='query'] - query_recalls = [r['recall'] for r in results if r['operation_type']=='query' and r['recall'] is not None] - n_vectors = [r['n_resident'] for r in results if r['n_resident'] is not None] + latencies_insert = [r["latency_ms"] for r in results if r["operation_type"] == "insert"] + op_nums_insert = [r["operation_number"] for r in results if r["operation_type"] == "insert"] + latencies_delete = [r["latency_ms"] for r in results if r["operation_type"] == "delete"] + op_nums_delete = [r["operation_number"] for r in results if r["operation_type"] == "delete"] + latencies_query = [r["latency_ms"] for r in results if r["operation_type"] == "query"] + op_nums_query = [r["operation_number"] for r in results if r["operation_type"] == "query"] + query_recalls = [r["recall"] for r in results if r["operation_type"] == "query" and r["recall"] is not None] + n_vectors = [r["n_resident"] for r in results if r["n_resident"] is not None] avg_latency_insert = np.mean(latencies_insert) if latencies_insert else None avg_latency_delete = np.mean(latencies_delete) if latencies_delete else None @@ -469,52 +551,56 @@ def evaluate_workload(self, name, index, build_params, search_params, do_mainten # Plot A: Latency per operation type ax = axs[0, 0] if op_nums_insert: - ax.plot(op_nums_insert, latencies_insert, label='Insert', marker='o') + ax.plot(op_nums_insert, latencies_insert, label="Insert", marker="o") if op_nums_delete: - ax.plot(op_nums_delete, latencies_delete, label='Delete', marker='s') + ax.plot(op_nums_delete, latencies_delete, label="Delete", marker="s") if op_nums_query: - ax.plot(op_nums_query, latencies_query, label='Query', marker='^') - ax.set_xlabel('Operation Number') - ax.set_ylabel('Latency (ms)') - ax.set_title('Operation Latency') + ax.plot(op_nums_query, latencies_query, label="Query", marker="^") + ax.set_xlabel("Operation Number") + ax.set_ylabel("Latency (ms)") + ax.set_title("Operation Latency") ax.legend() # Plot B: Number of partitions per operation (if available) ax = axs[0, 1] - partitions = [r['n_list'] for r in results if r['n_list'] is not None] - op_nums_part = [r['operation_number'] for r in results if r['n_list'] is not None] + partitions = [r["n_list"] for r in results if r["n_list"] is not None] + op_nums_part = [r["operation_number"] for r in results if r["n_list"] is not None] if partitions: - ax.plot(op_nums_part, partitions, marker='o') - ax.set_xlabel('Operation Number') - ax.set_ylabel('Number of Partitions') - ax.set_title('Partitions per Operation') + ax.plot(op_nums_part, partitions, marker="o") + ax.set_xlabel("Operation Number") + ax.set_ylabel("Number of Partitions") + ax.set_title("Partitions per Operation") else: - ax.text(0.5, 0.5, 'No partition info', ha='center', va='center') - ax.axis('off') + ax.text(0.5, 0.5, "No partition info", ha="center", va="center") + ax.axis("off") # Plot C: Resident set size per operation ax = axs[1, 0] if n_vectors: - op_nums_vect = [r['operation_number'] for r in results if r['n_resident'] is not None] - ax.plot(op_nums_vect, n_vectors, marker='o') - ax.set_xlabel('Operation Number') - ax.set_ylabel('Resident Vectors') - ax.set_title('Resident Set Size') + op_nums_vect = [r["operation_number"] for r in results if r["n_resident"] is not None] + ax.plot(op_nums_vect, n_vectors, marker="o") + ax.set_xlabel("Operation Number") + ax.set_ylabel("Resident Vectors") + ax.set_title("Resident Set Size") else: - ax.text(0.5, 0.5, 'No resident set info', ha='center', va='center') - ax.axis('off') + ax.text(0.5, 0.5, "No resident set info", ha="center", va="center") + ax.axis("off") # Plot D: Query recall per query operation ax = axs[1, 1] if op_nums_query and query_recalls: - ax.plot(op_nums_query, query_recalls, marker='o') - ax.set_xlabel('Operation Number') - ax.set_ylabel('Query Recall') - ax.set_title('Query Recall') + ax.plot(op_nums_query, query_recalls, marker="o") + ax.set_xlabel("Operation Number") + ax.set_ylabel("Query Recall") + ax.set_title("Query Recall") else: - ax.text(0.5, 0.5, 'No query recall info', ha='center', va='center') - ax.axis('off') + ax.text(0.5, 0.5, "No query recall info", ha="center", va="center") + ax.axis("off") plt.tight_layout() + + # create output directory if it doesn't exist + self.output_dir.mkdir(parents=True, exist_ok=True) + plot_path = self.output_dir / "evaluation_plots.png" plt.savefig(plot_path) print(f"Saved evaluation plots to {plot_path}") plt.close() - return results \ No newline at end of file + return results diff --git a/test/cpp/benchmark.cpp b/test/cpp/benchmark.cpp index 50501204..8e666c4b 100644 --- a/test/cpp/benchmark.cpp +++ b/test/cpp/benchmark.cpp @@ -29,9 +29,10 @@ using torch::Tensor; static const int64_t DIM = 128; static const int64_t NUM_VECTORS = 100000; // number of database vectors static const int64_t N_LIST = 100; // number of clusters for IVF -static const int64_t NUM_QUERIES = 1000; // number of queries for search benchmark +static const int64_t NUM_QUERIES = 10; // number of queries for search benchmark static const int64_t K = 10; // top-K neighbors -static const int64_t N_PROBE = 8; // number of probes for IVF +static const int64_t N_PROBE = 32; // number of probes for IVF +static const int64_t N_WORKERS = 12; // number of workers for parallel query coordinator // Helper functions to generate random data and sequential IDs static Tensor generate_data(int64_t num, int64_t dim) { @@ -110,9 +111,8 @@ class QuakeWorkerFlatBenchmark : public ::testing::Test { auto build_params = std::make_shared(); build_params->nlist = 1; // flat index build_params->metric = "l2"; - // Use as many workers as hardware concurrency - build_params->num_workers = std::thread::hardware_concurrency(); - index_->build(data_, ids_, build_params, attributes_table_); + build_params->num_workers = N_WORKERS; + index_->build(data_, ids_, build_params,attributes_table_); } }; @@ -153,9 +153,8 @@ class QuakeWorkerIVFBenchmark : public ::testing::Test { build_params->nlist = N_LIST; // IVF index build_params->metric = "l2"; build_params->niter = 3; - // Use as many workers as hardware concurrency - build_params->num_workers = std::thread::hardware_concurrency(); - index_->build(data_, ids_, build_params, attributes_table_); + build_params->num_workers = N_WORKERS; + index_->build(data_, ids_, build_params,attributes_table_); } }; @@ -163,7 +162,7 @@ class QuakeWorkerIVFBenchmark : public ::testing::Test { // ===== Faiss BENCHMARK FIXTURES ===== // -// For Faiss Flat we use IndexFlatL2. +// For Faiss Flat we use IndexFlatL2 class FaissFlatBenchmark : public ::testing::Test { protected: std::unique_ptr index_; @@ -207,7 +206,9 @@ TEST_F(QuakeSerialFlatBenchmark, Search) { search_params->batched_scan = false; auto start = high_resolution_clock::now(); - auto result = index_->search(queries, search_params); + for (int i = 0; i < queries.size(0); i++) { + auto result = index_->search(queries[i].unsqueeze(0), search_params); + } auto end = high_resolution_clock::now(); auto elapsed = duration_cast(end - start).count(); @@ -238,8 +239,14 @@ TEST_F(QuakeWorkerFlatBenchmark, Search) { search_params->nprobe = 1; // not used for flat index search_params->batched_scan = false; + for (int i = 0; i < queries.size(0); i++) { + auto result = index_->search(queries[i].unsqueeze(0), search_params); + } + auto start = high_resolution_clock::now(); - auto result = index_->search(queries, search_params); + for (int i = 0; i < queries.size(0); i++) { + auto result = index_->search(queries[i].unsqueeze(0), search_params); + } auto end = high_resolution_clock::now(); auto elapsed = duration_cast(end - start).count(); @@ -254,6 +261,8 @@ TEST_F(QuakeWorkerFlatBenchmark, SearchBatch) { search_params->nprobe = 1; // not used for flat index search_params->batched_scan = true; + index_->search(queries, search_params); + auto start = high_resolution_clock::now(); auto result = index_->search(queries, search_params); auto end = high_resolution_clock::now(); @@ -269,8 +278,15 @@ TEST_F(QuakeSerialIVFBenchmark, Search) { search_params->k = K; search_params->nprobe = N_PROBE; search_params->batched_scan = false; + + for (int i = 0; i < queries.size(0); i++) { + auto result = index_->search(queries[i].unsqueeze(0), search_params); + } + auto start = high_resolution_clock::now(); - auto result = index_->search(queries, search_params); + for (int i = 0; i < queries.size(0); i++) { + auto result = index_->search(queries[i].unsqueeze(0), search_params); + } auto end = high_resolution_clock::now(); auto elapsed = duration_cast(end - start).count(); @@ -284,6 +300,8 @@ TEST_F(QuakeSerialIVFBenchmark, SearchBatch) { search_params->k = K; search_params->nprobe = N_PROBE; search_params->batched_scan = true; + + index_->search(queries, search_params); auto start = high_resolution_clock::now(); auto result = index_->search(queries, search_params); auto end = high_resolution_clock::now(); @@ -298,10 +316,19 @@ TEST_F(QuakeWorkerIVFBenchmark, Search) { auto search_params = std::make_shared(); search_params->k = K; search_params->nprobe = N_PROBE; + // search_params->recall_target = .75; + // search_params->aps_flush_period_us = 1; // For worker-based search, batched_scan can be false (or true) depending on your implementation. search_params->batched_scan = false; + + for (int i = 0; i < queries.size(0); i++) { + auto result = index_->search(queries[i].unsqueeze(0), search_params); + } + auto start = high_resolution_clock::now(); - auto result = index_->search(queries, search_params); + for (int i = 0; i < queries.size(0); i++) { + auto result = index_->search(queries[i].unsqueeze(0), search_params); + } auto end = high_resolution_clock::now(); auto elapsed = duration_cast(end - start).count(); @@ -315,6 +342,9 @@ TEST_F(QuakeWorkerIVFBenchmark, SearchBatch) { search_params->k = K; search_params->nprobe = N_PROBE; search_params->batched_scan = true; + + index_->search(queries, search_params); + auto start = high_resolution_clock::now(); auto result = index_->search(queries, search_params); auto end = high_resolution_clock::now(); @@ -377,12 +407,32 @@ TEST_F(QuakeSerialIVFBenchmark, Remove) { // // ===== Faiss BENCHMARK TESTS ===== // - TEST_F(FaissFlatBenchmark, Search) { int64_t k = K; std::vector distances(NUM_QUERIES * k); std::vector labels(NUM_QUERIES * k); Tensor queries = generate_data(NUM_QUERIES, DIM); + + for (int i = 0; i < queries.size(0); i++) { + index_->search(1, queries[i].data_ptr(), k, distances.data() + i * k, labels.data() + i * k); + } + auto start = high_resolution_clock::now(); + for (int i = 0; i < queries.size(0); i++) { + index_->search(1, queries[i].data_ptr(), k, distances.data() + i * k, labels.data() + i * k); + } + auto end = high_resolution_clock::now(); + auto elapsed = duration_cast(end - start).count(); + std::cout << "[Faiss Flat] Search time: " << elapsed << " ms" << std::endl; + ASSERT_GT(elapsed, 0); +} + +TEST_F(FaissFlatBenchmark, SearchBatch) { + int64_t k = K; + std::vector distances(NUM_QUERIES * k); + std::vector labels(NUM_QUERIES * k); + Tensor queries = generate_data(NUM_QUERIES, DIM); + + index_->search(NUM_QUERIES, queries.data_ptr(), k, distances.data(), labels.data()); auto start = high_resolution_clock::now(); index_->search(NUM_QUERIES, queries.data_ptr(), k, distances.data(), labels.data()); auto end = high_resolution_clock::now(); @@ -422,6 +472,27 @@ TEST_F(FaissIVFBenchmark, Search) { std::vector labels(NUM_QUERIES * k); index_->nprobe = N_PROBE; Tensor queries = generate_data(NUM_QUERIES, DIM); + + for (int i = 0; i < queries.size(0); i++) { + index_->search(1, queries[i].data_ptr(), k, distances.data() + i * k, labels.data() + i * k); + } + + auto start = high_resolution_clock::now(); + for (int i = 0; i < queries.size(0); i++) { + index_->search(1, queries[i].data_ptr(), k, distances.data() + i * k, labels.data() + i * k); + } + auto end = high_resolution_clock::now(); + auto elapsed = duration_cast(end - start).count(); + std::cout << "[Faiss IVF] Search time: " << elapsed << " ms" << std::endl; + ASSERT_GT(elapsed, 0); +} + +TEST_F(FaissIVFBenchmark, SearchBatch) { + int64_t k = K; + std::vector distances(NUM_QUERIES * k); + std::vector labels(NUM_QUERIES * k); + index_->nprobe = N_PROBE; + Tensor queries = generate_data(NUM_QUERIES, DIM); auto start = high_resolution_clock::now(); index_->search(NUM_QUERIES, queries.data_ptr(), k, distances.data(), labels.data()); auto end = high_resolution_clock::now(); diff --git a/test/cpp/hit_count_tracker.cpp b/test/cpp/hit_count_tracker.cpp new file mode 100644 index 00000000..b74665e0 --- /dev/null +++ b/test/cpp/hit_count_tracker.cpp @@ -0,0 +1,110 @@ +#include "gtest/gtest.h" +#include "hit_count_tracker.h" + +#include +#include +#include + +// Fixture for realistic HitCountTracker tests. +class HitCountTrackerTest : public ::testing::Test { +protected: + // Parameters used in multiple tests. + int window_size = 5; + int total_vectors = 1000; + + // Utility: Given a vector of scanned sizes, compute fraction. + float ComputeExpectedFraction(const std::vector& scanned_sizes) { + int sum = std::accumulate(scanned_sizes.begin(), scanned_sizes.end(), 0); + return static_cast(sum) / total_vectors; + } +}; + +TEST_F(HitCountTrackerTest, RandomQueriesTest) { + // Simulate a realistic workload with random queries. + // We'll generate 20 queries; each query has a random number of scanned partitions (between 1 and 5) + // and random scanned sizes in [0,300]. After each query, we compare the expected average scan fraction. + HitCountTracker tracker(window_size, total_vectors); + + // For expected simulation, store each query's fraction. + std::vector queryFractions; + + // Set up a random generator with fixed seed for reproducibility. + std::default_random_engine rng(42); + std::uniform_int_distribution partitionsDist(1, 5); + std::uniform_int_distribution sizeDist(0, 300); + + const int numQueries = 20; + for (int q = 0; q < numQueries; ++q) { + int numPartitions = partitionsDist(rng); + std::vector hit_ids; + std::vector scanned_sizes; + for (int i = 0; i < numPartitions; ++i) { + // For simplicity, use partition id equal to i. + hit_ids.push_back(i); + scanned_sizes.push_back(sizeDist(rng)); + } + tracker.add_query_data(hit_ids, scanned_sizes); + float fraction = tracker.get_current_scan_fraction(); + // Manually compute the expected average over the effective window. + queryFractions.push_back( ComputeExpectedFraction(scanned_sizes) ); + int effective_window = (q + 1 < window_size) ? q + 1 : window_size; + float expectedAvg = 0.0f; + for (int j = (q + 1 <= window_size ? 0 : q + 1 - window_size); j <= q; ++j) { + expectedAvg += queryFractions[j]; + } + expectedAvg /= effective_window; + EXPECT_NEAR(fraction, expectedAvg, 1e-5f) + << "Failure at query " << q << ": expected " << expectedAvg << ", got " << fraction; + } +} + +TEST_F(HitCountTrackerTest, MultipleWindowCyclesTest) { + // Simulate many queries (e.g. 50 queries) to force multiple cycles through the sliding window. + HitCountTracker tracker(window_size, total_vectors); + std::vector queryFractions; + + std::default_random_engine rng(123); + std::uniform_int_distribution partitionsDist(1, 4); + std::uniform_int_distribution sizeDist(0, 300); + + const int numQueries = 50; + for (int q = 0; q < numQueries; ++q) { + int numPartitions = partitionsDist(rng); + std::vector hit_ids; + std::vector scanned_sizes; + for (int i = 0; i < numPartitions; ++i) { + hit_ids.push_back(i); + scanned_sizes.push_back(sizeDist(rng)); + } + tracker.add_query_data(hit_ids, scanned_sizes); + float fraction = tracker.get_current_scan_fraction(); + queryFractions.push_back( std::accumulate(scanned_sizes.begin(), scanned_sizes.end(), 0) / static_cast(total_vectors) ); + + int effective_window = (q + 1 < window_size) ? q + 1 : window_size; + float expectedAvg = 0.0f; + for (int j = (q + 1 <= window_size ? 0 : q + 1 - window_size); j <= q; ++j) { + expectedAvg += queryFractions[j]; + } + expectedAvg /= effective_window; + EXPECT_NEAR(fraction, expectedAvg, 1e-5f) + << "At query " << q << ", expected average " << expectedAvg << ", got " << fraction; + } +} + +TEST_F(HitCountTrackerTest, EdgeCaseZeroScannedTest) { + // Test a query where all scanned sizes are zero. + HitCountTracker tracker(window_size, total_vectors); + std::vector hit_ids = {1, 2, 3}; + std::vector scanned_sizes = {0, 0, 0}; // Expect fraction 0. + tracker.add_query_data(hit_ids, scanned_sizes); + EXPECT_NEAR(tracker.get_current_scan_fraction(), 0.0f, 1e-5f); +} + +TEST_F(HitCountTrackerTest, FullScanTest) { + // Test a query where the scanned sizes sum equals total_vectors. + HitCountTracker tracker(window_size, total_vectors); + std::vector hit_ids = {0}; + std::vector scanned_sizes = {total_vectors}; // Fraction should be 1.0. + tracker.add_query_data(hit_ids, scanned_sizes); + EXPECT_NEAR(tracker.get_current_scan_fraction(), 1.0f, 1e-5f); +} \ No newline at end of file diff --git a/test/cpp/index_partition.cpp b/test/cpp/index_partition.cpp index b040d0d1..21a73d8d 100644 --- a/test/cpp/index_partition.cpp +++ b/test/cpp/index_partition.cpp @@ -85,7 +85,7 @@ TEST_F(IndexPartitionTest, DefaultConstructorTest) { EXPECT_EQ(default_partition.codes_, nullptr); EXPECT_EQ(default_partition.ids_, nullptr); EXPECT_EQ(default_partition.numa_node_, -1); - EXPECT_EQ(default_partition.thread_id_, -1); + EXPECT_EQ(default_partition.core_id_, -1); } // Test parameterized constructor @@ -314,7 +314,7 @@ TEST_F(IndexPartitionTest, ClearTest) { EXPECT_EQ(partition->codes_, nullptr); EXPECT_EQ(partition->ids_, nullptr); EXPECT_EQ(partition->numa_node_, -1); - EXPECT_EQ(partition->thread_id_, -1); + EXPECT_EQ(partition->core_id_, -1); } // Test find_id method diff --git a/test/cpp/latency_estimator.cpp b/test/cpp/latency_estimator.cpp index 5ab7c708..80ed570a 100644 --- a/test/cpp/latency_estimator.cpp +++ b/test/cpp/latency_estimator.cpp @@ -6,9 +6,10 @@ // #include -#include "latency_estimation.h" +#include "maintenance_cost_estimator.h" #include "list_scanning.h" // Must include your scan_list(...) definition #include // For remove() +#include // For file I/O // Helper function to measure actual latency for given n and k static float measure_actual_latency(const ListScanLatencyEstimator& estimator, @@ -141,46 +142,50 @@ TEST(ListScanLatencyEstimatorTest, MismatchedGridsForFile) { std::remove(test_filename.c_str()); } -TEST(ListScanLatencyEstimatorTest, EstimateVsActualLatency) { - int d = 32; - std::vector n_values = {16, 64, 256}; - std::vector k_values = {1, 4, 16}; - int n_trials = 25; - - // In practice, you might have a bigger grid, but let's keep it short for test - ListScanLatencyEstimator estimator(d, n_values, k_values, n_trials); - - // Profile is performed in constructor if not already loaded from file - std::vector> test_cases = { - {16, 4}, - {64, 1}, - {64, 16}, - {256, 16}, - {100, 5}, // interpolation - {300, 16}, // interpolation - {512, 16}, // extrapolation - }; - - for (auto& tc : test_cases) { - int n = tc.first; - int k = tc.second; - if (n < n_values.front() || k < k_values.front()) { - EXPECT_THROW(estimator.estimate_scan_latency(n, k), std::out_of_range); - continue; - } - - float estimated_latency_ns = estimator.estimate_scan_latency(n, k); - float actual_latency_ns = measure_actual_latency(estimator, n, k); - - float estimated_ms = estimated_latency_ns / 1e6f; - float actual_ms = actual_latency_ns / 1e6f; - std::cout << "n=" << n << ", k=" << k - << " => estimated=" << estimated_ms << "ms, actual=" << actual_ms - << "ms\n"; - - // Tolerance of 40% because these are quite approximate with small n_trials - float tolerance = 0.4f * actual_ms; - EXPECT_NEAR(estimated_ms, actual_ms, tolerance) - << "Difference is too large for n=" << n << ", k=" << k; - } -} \ No newline at end of file +// TEST(ListScanLatencyEstimatorTest, EstimateVsActualLatency) { +// int d = 32; +// std::vector n_values = {64, 256, 1024}; +// std::vector k_values = {1, 4, 16}; +// int n_trials = 25; +// +// // clear old profile file if it exists +// std::string test_filename = "latency_profile.csv"; +// std::remove(test_filename.c_str()); +// +// // In practice, you might have a bigger grid, but let's keep it short for test +// ListScanLatencyEstimator estimator(d, n_values, k_values, n_trials); +// +// // Profile is performed in constructor if not already loaded from file +// std::vector> test_cases = { +// {16, 4}, +// {64, 1}, +// {64, 16}, +// {256, 16}, +// {100, 5}, // interpolation +// {300, 16}, // interpolation +// {512, 16}, // extrapolation +// }; +// +// for (auto& tc : test_cases) { +// int n = tc.first; +// int k = tc.second; +// if (n < n_values.front() || k < k_values.front()) { +// EXPECT_THROW(estimator.estimate_scan_latency(n, k), std::out_of_range); +// continue; +// } +// +// float estimated_latency_ns = estimator.estimate_scan_latency(n, k); +// float actual_latency_ns = measure_actual_latency(estimator, n, k); +// +// float estimated_ms = estimated_latency_ns / 1e6f; +// float actual_ms = actual_latency_ns / 1e6f; +// std::cout << "n=" << n << ", k=" << k +// << " => estimated=" << estimated_ms << "ms, actual=" << actual_ms +// << "ms\n"; +// +// // Tolerance of 40% because these are quite approximate with small n_trials +// float tolerance = 0.4f * actual_ms; +// EXPECT_NEAR(estimated_ms, actual_ms, tolerance) +// << "Difference is too large for n=" << n << ", k=" << k; +// } +// } \ No newline at end of file diff --git a/test/cpp/maintenance.cpp b/test/cpp/maintenance.cpp index 078d9867..cabf2723 100644 --- a/test/cpp/maintenance.cpp +++ b/test/cpp/maintenance.cpp @@ -1,31 +1,31 @@ +// maintenance_policy_refactored_test.cpp // -// Created by Jason on 10/7/24. -// Prompt for GitHub Copilot: -// - Conform to the google style guide -// - Use descriptive variable names +// Unit tests for the refactored MaintenancePolicy class. // -// This file contains basic unit tests for MaintenancePolicy and -// its derived classes. Each test checks individual functionality -// within the maintenance code. If the approach for testing a -// particular method isn't obvious, the test is left blank. +// These tests use a helper function to create a parent QuakeIndex and PartitionManager, +// and then exercise the new MaintenancePolicy interface (recording hits, resetting, +// and performing maintenance operations based on cost-estimation). #include #include +#include #include #include "maintenance_policies.h" #include "partition_manager.h" -#include "list_scanning.h" #include "quake_index.h" +#include "list_scanning.h" // Needed for latency estimator, etc. +using std::make_shared; using std::shared_ptr; +using std::tuple; using std::vector; using torch::Tensor; -// Helper function to create a parent QuakeIndex + PartitionManager: -static std::tuple, shared_ptr> CreateParentAndManager( +// Helper function to create a parent QuakeIndex and PartitionManager. +static tuple, shared_ptr> CreateParentAndManager( int64_t nlist, int dimension, int64_t ntotal) { - auto clustering = std::make_shared(); + auto clustering = make_shared(); clustering->partition_ids = torch::arange(nlist, torch::kInt64); Tensor vectors = torch::randn({ntotal, dimension}, torch::kFloat32); @@ -34,210 +34,187 @@ static std::tuple, shared_ptr> CreatePa Tensor centroids = torch::empty({nlist, dimension}, torch::kFloat32); for (int i = 0; i < nlist; i++) { - Tensor v = vectors.index_select(0, torch::nonzero(assignments == i).squeeze(1)); - Tensor id = ids.index_select(0, torch::nonzero(assignments == i).squeeze(1)); + Tensor idx = torch::nonzero(assignments == i).squeeze(1); + Tensor v = vectors.index_select(0, idx); + Tensor id = ids.index_select(0, idx); clustering->vectors.push_back(v); clustering->vector_ids.push_back(id); centroids[i] = v.mean(0); } clustering->centroids = centroids; - auto parent = std::make_shared(); - auto build_params = std::make_shared(); + auto parent = make_shared(); + auto build_params = make_shared(); parent->build(clustering->centroids, clustering->partition_ids, build_params); - auto manager = std::make_shared(); + auto manager = make_shared(); manager->init_partitions(parent, clustering); return {parent, manager}; } -TEST(MaintenancePolicyTest, IncrementAndCheckHitCount) { - // Checks that increment_hit_count updates per_partition_hits_ - // and curr_query_id_ is incremented correctly. - - auto [p, pm] = CreateParentAndManager(3, 4, 100); - auto params = std::make_shared(); +// +// Test that without any hit events, perform_maintenance() does nothing (i.e. no deletion or splitting). +// +TEST(MaintenancePolicyRefactoredTest, NoMaintenanceWithoutHits) { + auto [parent, manager] = CreateParentAndManager(3, 4, 100); + auto params = make_shared(); params->window_size = 3; params->alpha = 0.5f; - auto policy = std::make_shared(pm, params); + // Set thresholds to 0 so that any deviation would trigger maintenance if there were hits. + params->delete_threshold_ns = 0.0f; + params->split_threshold_ns = 0.0f; + // Assume params->latency_estimator is already set up (or set it to a dummy if needed). - EXPECT_EQ(policy->curr_query_id_, 0); - int n = 5; - vector hits = {1, 2}; - for (int i = 0; i < n; i++) { - policy->increment_hit_count(hits); - } - EXPECT_EQ(policy->curr_query_id_, n); - - auto st = policy->get_partition_state(false); - for (size_t i = 0; i < st->partition_ids.size(); i++) { - auto pid = st->partition_ids[i]; - if (pid == 1 || pid == 2) { - // With n queries, each incrementing partitions 1 & 2 once, - // hit_rate => (n) / n = 1.0 - EXPECT_FLOAT_EQ(st->partition_hit_rate[i], 1.0f); - } else { - EXPECT_FLOAT_EQ(st->partition_hit_rate[i], 0.0f); - } - } -} + auto policy = make_shared(manager, params); + shared_ptr info = policy->perform_maintenance(); -TEST(MaintenancePolicyTest, DecrementAndCheckHitCount) { - // Checks that decrement_hit_count reduces the partition's per_partition_hits_. - - auto [p, pm] = CreateParentAndManager(2, 4, 50); - auto params = std::make_shared(); - params->window_size = 5; - auto policy = std::make_shared(pm, params); - - vector hits = {0}; - for (int i = 0; i < 3; i++) policy->increment_hit_count(hits); - - policy->decrement_hit_count(0); - auto st = policy->get_partition_state(false); - for (size_t i = 0; i < st->partition_ids.size(); i++) { - if (st->partition_ids[i] == 0) { - // After 3 increments and 1 decrement => total hits = 2 - // The window size is effectively 3 queries, - // so hit rate => 2/3 - EXPECT_NEAR(st->partition_hit_rate[i], (2.0f / 3.0f), 1e-6); - } - } + // With no hit events recorded, no deletion or splitting should occur. + EXPECT_EQ(info->delete_time_us, 0); + EXPECT_EQ(info->split_time_us, 0); } -TEST(MaintenancePolicyTest, GetSplitHistory) { - // Checks get_split_history after manually calling add_split. - - auto [p, pm] = CreateParentAndManager(3, 4, 30); - auto params = std::make_shared(); - auto policy = std::make_shared(pm, params); - - // Pretend that partition 10 was split into 11,12. Provide some hits to partition 10. - policy->add_partition(10, 5); - policy->add_split(10, 11, 12); - - auto history = policy->get_split_history(); - // The tuple => (parent, parent_hits, left_id, left_hits, right_id, right_hits) - // We expect one record: (10, 5, 11, 5, 12, 5) - ASSERT_EQ(history.size(), 1u); - auto rec = history[0]; - EXPECT_EQ(std::get<0>(rec), 10); - EXPECT_EQ(std::get<1>(rec), 5); - EXPECT_EQ(std::get<2>(rec), 11); - EXPECT_EQ(std::get<3>(rec), 5); - EXPECT_EQ(std::get<4>(rec), 12); - EXPECT_EQ(std::get<5>(rec), 5); +// +// Test that hit events are recorded and then reset properly. +// +TEST(MaintenancePolicyRefactoredTest, RecordAndResetHitCount) { + auto [parent, manager] = CreateParentAndManager(3, 4, 100); + auto params = make_shared(); + params->window_size = 3; + params->alpha = 0.5f; + // Use non-triggering thresholds. + params->delete_threshold_ns = 1000.0f; + params->split_threshold_ns = 1000.0f; + + auto policy = make_shared(manager, params); + + // Record several hit events. + policy->record_query_hits({1}); + policy->record_query_hits({2}); + policy->record_query_hits({1}); + + // Perform maintenance. Since thresholds are high, no maintenance should occur. + shared_ptr info1 = policy->perform_maintenance(); + EXPECT_EQ(info1->delete_time_us, 0); + EXPECT_EQ(info1->split_time_us, 0); + + // Now reset and verify that subsequent maintenance has no effect. + policy->reset(); + shared_ptr info2 = policy->perform_maintenance(); + EXPECT_EQ(info2->delete_time_us, 0); + EXPECT_EQ(info2->split_time_us, 0); } -TEST(MaintenancePolicyTest, GetPartitionState) { - // Checks get_partition_state for only_modified = false. - - auto [p, pm] = CreateParentAndManager(4, 4, 100); - auto params = std::make_shared(); - auto policy = std::make_shared(pm, params); - - auto st = policy->get_partition_state(false); +// +// Test that underutilized partitions are selected for deletion. +// Here we record hits only on some partitions so that others remain underutilized. +// +TEST(MaintenancePolicyRefactoredTest, TriggerDeletion) { + auto [parent, manager] = CreateParentAndManager(1000, 4, 100000); + auto params = make_shared(); + params->window_size = 999; + params->alpha = 0.5f; + // Set delete threshold low so that partitions with few hits are marked. + params->delete_threshold_ns = 0.0f; + // Set split threshold high so that splitting does not trigger. + params->split_threshold_ns = 1000.0f; - // check state makes sense: - // hit rate should be 0.0 for all partitions - for (size_t i = 0; i < st->partition_ids.size(); i++) { - EXPECT_FLOAT_EQ(st->partition_hit_rate[i], 0.0f); - } + auto policy = make_shared(manager, params); - // number of total vectors should be 100 - int curr_size = 0; - for (size_t i = 0; i < st->partition_sizes.size(); i++) { - curr_size += st->partition_sizes[i]; + // set hits for all partitions besides 0 + for (int i = 1; i < 1000; i++) { + policy->record_query_hits({i, i, i, i, i, i, i, i, i, i, i}); } - EXPECT_EQ(curr_size, 100); -} -TEST(MaintenancePolicyTest, SetPartitionModified) { - // Checks that set_partition_modified updates the internal set. + // make partition 0 small + manager->partition_store_->partitions_[0]->resize(10); - auto [p, pm] = CreateParentAndManager(2, 4, 20); - auto params = std::make_shared(); - auto policy = std::make_shared(pm, params); - - policy->set_partition_modified(99); - EXPECT_TRUE(policy->modified_partitions_.find(99) != policy->modified_partitions_.end()); -} + // Run maintenance. Partition 0, being unhit, should be deleted. + shared_ptr info = policy->perform_maintenance(); -TEST(MaintenancePolicyTest, SetPartitionUnmodified) { - // Checks that set_partition_unmodified removes a partition from the modified set. - - auto [p, pm] = CreateParentAndManager(2, 4, 20); - auto params = std::make_shared(); - auto policy = std::make_shared(pm, params); - - policy->set_partition_modified(42); - EXPECT_TRUE(policy->modified_partitions_.find(42) != policy->modified_partitions_.end()); - policy->set_partition_unmodified(42); - EXPECT_TRUE(policy->modified_partitions_.find(42) == policy->modified_partitions_.end()); + // Retrieve current partition IDs from the manager. + Tensor pids = manager->get_partition_ids(); + auto pids_accessor = pids.accessor(); + bool found0 = false; + for (int i = 0; i < pids.size(0); i++) { + if (pids_accessor[i] == 0) { + found0 = true; + break; + } + } + EXPECT_FALSE(found0); } -TEST(MaintenancePolicyTest, Maintenance) { - // Checks that maintenance runs without error and returns a timing info struct. - auto [p, pm] = CreateParentAndManager(2, 4, 10000); - auto params = std::make_shared(); - params->delete_threshold_ns = 0.0; - params->split_threshold_ns = 0.0; - auto policy = std::make_shared(pm, params); - - // this should not trigger any maintenance operations since no queries have been made - auto info = policy->maintenance(); - - // check that the timing info struct is populated and shows no splits or deletes - EXPECT_EQ(info->n_splits, 0); - EXPECT_EQ(info->n_deletes, 0); - - // Increment the hit count for partition 0 and run maintenance again - policy->increment_hit_count({0}); - info = policy->maintenance(); +// +// Test that overutilized partitions are selected for splitting. +// Here we record many hits on a partition so that it is overutilized and triggers splitting. +// +TEST(MaintenancePolicyRefactoredTest, TriggerSplitting) { + auto [parent, manager] = CreateParentAndManager(3, 4, 100); + auto params = make_shared(); + params->window_size = 3; + params->alpha = 0.5f; + // Set split threshold low to force splitting when hit count is high. + params->split_threshold_ns = 0.0f; + // Set delete threshold high so deletion will not be triggered. + params->delete_threshold_ns = 1000.0f; + // Ensure partitions are large enough to split. + params->min_partition_size = 1; + + auto policy = make_shared(manager, params); + + // Record multiple hits on partition 1. + for (int i = 0; i < 5; i++) { + policy->record_query_hits({1}); + } - // check that we did a split - EXPECT_EQ(info->n_splits, 1); -} + shared_ptr info = policy->perform_maintenance(); -TEST(MaintenancePolicyTest, AddSplit) { - // Checks that add_split correctly updates data structures. - auto [p, pm] = CreateParentAndManager(2, 4, 20); - auto params = std::make_shared(); - auto policy = std::make_shared(pm, params); - - policy->add_partition(50, 7); - policy->add_split(50, 51, 52); - - // check that the split was added - auto history = policy->get_split_history(); - ASSERT_EQ(history.size(), 1); - auto rec = history[0]; - EXPECT_EQ(std::get<0>(rec), 50); // old partition - EXPECT_EQ(std::get<1>(rec), 7); // old partition hits - EXPECT_EQ(std::get<2>(rec), 51); // left partition - EXPECT_EQ(std::get<3>(rec), 7); // left partition hits - EXPECT_EQ(std::get<4>(rec), 52); // right partition - EXPECT_EQ(std::get<5>(rec), 7); // right partition hits - - // check 50 is not in per_partition_hits_ and 51, 52 are - policy->per_partition_hits_.find(50) == policy->per_partition_hits_.end(); - EXPECT_EQ(policy->per_partition_hits_[51], 7); - EXPECT_EQ(policy->per_partition_hits_[52], 7); - - // check that 50 is in deleted_partition_hit_rate_ and ancestor_partition_hits_ - // TODO check if these are even needed - EXPECT_EQ(policy->deleted_partition_hit_rate_[50], 7); - EXPECT_EQ(policy->ancestor_partition_hits_[50], 7); + // Check that partition 1 has been replaced by new partition IDs. + Tensor pids = manager->get_partition_ids(); + auto pids_accessor = pids.accessor(); + bool found1 = false; + for (int i = 0; i < pids.size(0); i++) { + if (pids_accessor[i] == 1) { + found1 = true; + break; + } + } + EXPECT_FALSE(found1); } -TEST(MaintenancePolicyTest, AddPartition) { - // Checks that add_partition inserts new partition hits in the map. - auto [p, pm] = CreateParentAndManager(2, 4, 20); - auto params = std::make_shared(); - auto policy = std::make_shared(pm, params); - - policy->add_partition(1234, 99); - - EXPECT_EQ(policy->per_partition_hits_[1234], 99); +// +// Optionally, if you have implemented local refinement in PartitionManager, +// you can add a test that simulates a split and then verifies that refine_partitions() +// is called (perhaps by using a mock or a subclass of PartitionManager that records calls). +// For simplicity, this test only checks that perform_maintenance() runs without error +// when splits occur. +// +TEST(MaintenancePolicyRefactoredTest, MaintenanceRunsSuccessfully) { + auto [parent, manager] = CreateParentAndManager(100, 4, 100000); + auto params = make_shared(); + params->window_size = 3; + params->alpha = 0.5f; + // Set thresholds to force both deletion and splitting if conditions are met. + params->delete_threshold_ns = 0.0f; + params->split_threshold_ns = 0.0f; + params->min_partition_size = 1; + params->refinement_radius = 10; + params->refinement_iterations = 3; + + auto policy = make_shared(manager, params); + + // Record some hit events to trigger maintenance operations. + policy->record_query_hits({0}); + policy->record_query_hits({1}); + policy->record_query_hits({0}); + policy->record_query_hits({1}); + policy->record_query_hits({2}); + + // Run maintenance and verify that timing info is returned. + shared_ptr info = policy->perform_maintenance(); + // We expect non-negative timing values. + EXPECT_GE(info->delete_time_us, 0); + EXPECT_GE(info->split_time_us, 0); + EXPECT_GE(info->total_time_us, 0); } \ No newline at end of file diff --git a/test/cpp/maintenance_cost_estimator.cpp b/test/cpp/maintenance_cost_estimator.cpp new file mode 100644 index 00000000..22c02617 --- /dev/null +++ b/test/cpp/maintenance_cost_estimator.cpp @@ -0,0 +1,110 @@ +#include "gtest/gtest.h" +#include "maintenance_cost_estimator.h" +#include +#include +#include +#include + +// For convenience. +using std::make_shared; +using std::shared_ptr; +using std::vector; + +// Test fixture for MaintenanceCostEstimator tests. +class MaintenanceCostEstimatorTest : public ::testing::Test { +protected: + int d = 128; // dimension (not used in our fake, but required) + float alpha = 0.9f; + int k = 10; + shared_ptr latency_estimator_; + MaintenanceCostEstimator* estimator; + + virtual void SetUp() { + estimator = new MaintenanceCostEstimator(d, alpha, k); + latency_estimator_ = estimator->get_latency_estimator(); + } + + virtual void TearDown() { + delete estimator; + } +}; + +TEST_F(MaintenanceCostEstimatorTest, ComputeSplitDelta) { + // Given: partition_size = 1000, hit_rate = 0.3, total_partitions = 100. + int partition_size = 1000; + float hit_rate = 0.3f; + int total_partitions = 100; + + float expected_delta_overhead = latency_estimator_->estimate_scan_latency(total_partitions + 1, k) - + latency_estimator_->estimate_scan_latency(total_partitions, k); + float expected_delta_split = (latency_estimator_->estimate_scan_latency(partition_size / 2, k) * hit_rate * (2 * alpha) - + latency_estimator_->estimate_scan_latency(partition_size, k) * hit_rate); + float expected_delta = expected_delta_overhead + expected_delta_split; + + float computed_delta = estimator->compute_split_delta(partition_size, hit_rate, total_partitions); + + std::cout << "Computed delta: " << computed_delta << std::endl; + std::cout << "Expected delta: " << expected_delta << std::endl; + + EXPECT_NEAR(computed_delta, expected_delta, 1.0); +} + +TEST_F(MaintenanceCostEstimatorTest, ComputeDeleteDelta) { + // Given: partition_size = 1000, hit_rate = 0.3, total_partitions = 100, current_scan_fraction = 0.25. + int partition_size = 1000; + float hit_rate = 0.3f; + int total_partitions = 100; + float avg_partition_hit_rate = 0.25f; + int k = estimator->get_k(); + int avg_partition_size = partition_size; + + // Let T = total_partitions, n = partition_size, and p = hit_rate. + // Compute the structural benefit: the reduction in overhead when one partition is removed. + float latency_T = latency_estimator_->estimate_scan_latency(total_partitions, k); + float latency_T_minus_1 = latency_estimator_->estimate_scan_latency(total_partitions - 1, k); + float delta_overhead = latency_T_minus_1 - latency_T; + + float cost_old = (total_partitions - 1) * avg_partition_hit_rate + * latency_estimator_->estimate_scan_latency(avg_partition_size, k) + + hit_rate + * latency_estimator_->estimate_scan_latency(partition_size, k); + + // Compute the "new" size and scan fraction after merging + float merged_size = avg_partition_size + static_cast(partition_size) / (total_partitions - 1); + float merged_hit_rate = avg_partition_hit_rate + hit_rate / static_cast(total_partitions - 1); + + float cost_new; + if (partition_size < total_partitions) { + // assume at most partition_size partitions get the extra vectors + cost_new = partition_size * merged_hit_rate * latency_estimator_->estimate_scan_latency(avg_partition_size + 1, k) + + (total_partitions - partition_size - 1) * merged_hit_rate * latency_estimator_->estimate_scan_latency(avg_partition_size, k); + } else { + cost_new = (total_partitions - 1) * merged_hit_rate * latency_estimator_->estimate_scan_latency(ceil(merged_size), k); + } + + float delta_scanning = cost_new - cost_old; + float expected_delta = delta_overhead + delta_scanning; + + int average_partition_size = partition_size; + + float computed_delta = estimator->compute_delete_delta(partition_size, + hit_rate, + total_partitions, + avg_partition_hit_rate, + average_partition_size); + + std::cout << "Computed delta: " << computed_delta << std::endl; + std::cout << "Expected delta: " << expected_delta << std::endl; + + EXPECT_NEAR(computed_delta, expected_delta, 1.0); +} + +TEST_F(MaintenanceCostEstimatorTest, InvalidParametersThrow) { + // Test that invalid parameters cause exceptions in the constructor. + EXPECT_THROW({ + MaintenanceCostEstimator est(d, -0.5f, k); + }, std::invalid_argument); + EXPECT_THROW({ + MaintenanceCostEstimator est(d, alpha, 0); + }, std::invalid_argument); +} \ No newline at end of file diff --git a/test/cpp/partition_manager.cpp b/test/cpp/partition_manager.cpp index a66d5c9e..31370291 100644 --- a/test/cpp/partition_manager.cpp +++ b/test/cpp/partition_manager.cpp @@ -200,8 +200,8 @@ TEST_F(PartitionManagerTest, AddPartitionsTest) { EXPECT_EQ(partition_manager_->nlist(), 3 + 2); Tensor p_ids = partition_manager_->get_partition_ids(); auto p_ids_acc = p_ids.accessor(); - EXPECT_EQ(partition_manager_->partitions_->list_size(3), 1); - EXPECT_EQ(partition_manager_->partitions_->list_size(4), 1); + EXPECT_EQ(partition_manager_->partition_store_->list_size(3), 1); + EXPECT_EQ(partition_manager_->partition_store_->list_size(4), 1); } // Test: Verify that select_partitions returns the expected data. @@ -245,7 +245,7 @@ TEST_F(PartitionManagerTest, GetPartitionSizesTest) { auto p_ids_acc = p_ids.accessor(); auto sizes_acc = sizes.accessor(); for (int i = 0; i < p_ids.size(0); i++) { - EXPECT_EQ(sizes_acc[i], partition_manager_->partitions_->list_size(p_ids_acc[i])); + EXPECT_EQ(sizes_acc[i], partition_manager_->partition_store_->list_size(p_ids_acc[i])); } } diff --git a/test/cpp/quake_index.cpp b/test/cpp/quake_index.cpp index a693ef13..3794911e 100644 --- a/test/cpp/quake_index.cpp +++ b/test/cpp/quake_index.cpp @@ -431,7 +431,6 @@ TEST(QuakeIndexStressTest, RapidAddRemoveAddTest) { // Add auto add_vectors = generate_random_data(batch_size, dimension); auto add_ids = generate_sequential_ids(batch_size, i * batch_size); - std::cout << add_ids << std::endl; auto add_info = index.add(add_vectors, add_ids); ASSERT_EQ(add_info->n_vectors, batch_size); @@ -498,8 +497,8 @@ TEST(QuakeIndexStressTest, SearchAddRemoveMaintenanceTest) { // Repeatedly search, add, remove, and perform maintenance to see if the index remains consistent. int64_t dimension = 16; - int64_t num_vectors = 10000; - int64_t num_queries = 1; + int64_t num_vectors = 100000; + int64_t num_queries = 100; int64_t batch_size = 10; QuakeIndex index; @@ -515,7 +514,8 @@ TEST(QuakeIndexStressTest, SearchAddRemoveMaintenanceTest) { for (int i = 0; i < 100; i++) { // Search - auto query_vectors = generate_random_data(num_queries, dimension) * .0001; + std::cout << "Iteration " << i << std::endl; + auto query_vectors = generate_random_data(num_queries, dimension) * .1; auto search_params = std::make_shared(); search_params->nprobe = 1; search_params->k = 5; diff --git a/test/experiments/adaptive_partition_scanning/aps.py b/test/experiments/adaptive_partition_scanning/aps.py index a5dc1a02..78ee9e7f 100644 --- a/test/experiments/adaptive_partition_scanning/aps.py +++ b/test/experiments/adaptive_partition_scanning/aps.py @@ -1,20 +1,18 @@ -import hydra -from hydra.utils import get_original_cwd -from omegaconf import DictConfig -import os +import logging from pathlib import Path -import pandas as pd -import numpy as np + +import hydra import matplotlib.pyplot as plt +import numpy as np +import pandas as pd import seaborn as sns -from multiprocessing import Pool -import logging import torch -import time +from hydra.utils import get_original_cwd +from omegaconf import DictConfig -from quake.utils import compute_recall, to_path +from quake import IndexBuildParams, QuakeIndex, SearchParams from quake.datasets.ann_datasets import load_dataset -from quake import MaintenancePolicyParams, QuakeIndex, IndexBuildParams, SearchParams +from quake.utils import compute_recall, to_path # Set up logging log = logging.getLogger(__name__) @@ -22,11 +20,13 @@ # Constants MIN_FANOUT = 8 + # Dataset Management def get_dataset(cfg): dataset_path = get_original_cwd() / to_path(cfg.dataset.path) vectors, queries, gt = load_dataset(cfg.dataset.name, dataset_path) - return vectors, queries[:cfg.experiment.nq], gt[:cfg.experiment.nq] + return vectors, queries[: cfg.experiment.nq], gt[: cfg.experiment.nq] + # Index Management def build_or_load_index(cfg, num_workers): @@ -45,30 +45,26 @@ def build_or_load_index(cfg, num_workers): # Load existing index index = QuakeIndex() - index.load( - str(index_path.absolute()), - num_workers - ) + index.load(str(index_path.absolute()), num_workers) log.info(f"Index loaded from {index_path} with {num_workers} workers") return index + # Experiment Execution def run_single_experiment(args): method, recall_target, recompute_ratio, use_precompute, cfg, action, n_workers = args index = build_or_load_index(cfg, n_workers) _, queries, gt = get_dataset(cfg) k = cfg.experiment.k - metric = cfg.index.metric - nlist = index.nlist() result_dir = get_original_cwd() / cfg.paths.results_dir / method result_dir.mkdir(parents=True, exist_ok=True) print(f"Running experiment for {method} at recall {recall_target} with action {action}") - if method == 'Oracle': + if method == "Oracle": result_path = result_dir / f"recall_{recall_target:.2f}.csv" data_df = run_experiment_for_configuration( @@ -82,7 +78,7 @@ def run_single_experiment(args): # Save per-query data data_df.to_csv(result_path, index=False) log.info(f"Results saved to {result_path}") - elif method == 'FixedNProbe': + elif method == "FixedNProbe": result_path = result_dir / f"recall_{recall_target:.2f}.csv" if result_path.exists() and not cfg.overwrite.results: log.info(f"Results for {method} at recall {recall_target} already exist. Skipping.") @@ -99,7 +95,7 @@ def run_single_experiment(args): # Save per-query data data_df.to_csv(result_path, index=False) log.info(f"Results saved to {result_path}") - elif method.startswith('APS'): + elif method.startswith("APS"): result_path = result_dir / f"recall_{recall_target:.2f}.csv" if result_path.exists() and not cfg.overwrite.results: log.info(f"Results for {method} at recall {recall_target} already exist. Skipping.") @@ -119,6 +115,8 @@ def run_single_experiment(args): log.info(f"Results saved to {result_path}") else: raise ValueError(f"Unknown method: {method}") + + # Results Management and Plotting def collect_and_plot_results(cfg): methods = cfg.methods @@ -127,17 +125,17 @@ def collect_and_plot_results(cfg): for method in methods: for recall_target in recall_targets: - result_dir = cfg.paths.results_dir / method - result_path = result_dir / f"recall_{recall_target:.2f}.csv" - if not result_path.exists(): - log.warning(f"Result file {result_path} does not exist. Skipping.") - continue - data_df = pd.read_csv(result_path) - data_df['Recall Target'] = recall_target - - complete_name = method - data_df['Method'] = complete_name - all_data.append(data_df) + result_dir = cfg.paths.results_dir / method + result_path = result_dir / f"recall_{recall_target:.2f}.csv" + if not result_path.exists(): + log.warning(f"Result file {result_path} does not exist. Skipping.") + continue + data_df = pd.read_csv(result_path) + data_df["Recall Target"] = recall_target + + complete_name = method + data_df["Method"] = complete_name + all_data.append(data_df) if not all_data: log.error("No data available for plotting.") @@ -146,186 +144,223 @@ def collect_and_plot_results(cfg): df_plot = pd.concat(all_data, ignore_index=True) # Clean data - df_plot = df_plot.dropna(subset=['total_time_ms', 'nprobe', 'recall']) - df_plot = df_plot[(df_plot['total_time_ms'] >= 0) & (df_plot['nprobe'] >= 0) & (df_plot['recall'] >= 0)] + df_plot = df_plot.dropna(subset=["total_time_ms", "nprobe", "recall"]) + df_plot = df_plot[(df_plot["total_time_ms"] >= 0) & (df_plot["nprobe"] >= 0) & (df_plot["recall"] >= 0)] - df_plot['Query Time (ms)'] = df_plot['total_time_ms'] - df_plot['Recall'] = df_plot['recall'] + df_plot["Query Time (ms)"] = df_plot["total_time_ms"] + df_plot["Recall"] = df_plot["recall"] # Compute stats - grouped = df_plot.groupby(['Recall Target', 'Method']) - stats = grouped.agg({ - 'Query Time (ms)': ['min', 'mean', 'max'], - 'nprobe': ['min', 'mean', 'max'], - 'Recall': ['min', 'mean', 'max'], - 'buffer_init_time_ms': ['min', 'mean', 'max'], - 'job_enqueue_time_ms': ['min', 'mean', 'max'], - 'boundary_distance_time_ms': ['min', 'mean', 'max'], - 'job_wait_time_ms': ['min', 'mean', 'max'], - 'result_aggregate_time_ms': ['min', 'mean', 'max'], - }).reset_index() - stats.columns = [' '.join(col).strip() for col in stats.columns.values] + grouped = df_plot.groupby(["Recall Target", "Method"]) + stats = grouped.agg( + { + "Query Time (ms)": ["min", "mean", "max"], + "nprobe": ["min", "mean", "max"], + "Recall": ["min", "mean", "max"], + "buffer_init_time_ms": ["min", "mean", "max"], + "job_enqueue_time_ms": ["min", "mean", "max"], + "boundary_distance_time_ms": ["min", "mean", "max"], + "job_wait_time_ms": ["min", "mean", "max"], + "result_aggregate_time_ms": ["min", "mean", "max"], + } + ).reset_index() + stats.columns = [" ".join(col).strip() for col in stats.columns.values] # Compute 'other_time_ms' and include in stats - df_plot['other_time_ms'] = df_plot['total_time_ms'] - ( - df_plot['buffer_init_time_ms'] + - df_plot['job_enqueue_time_ms'] + - df_plot['boundary_distance_time_ms'] + - df_plot['job_wait_time_ms'] + - df_plot['result_aggregate_time_ms'] + df_plot["other_time_ms"] = df_plot["total_time_ms"] - ( + df_plot["buffer_init_time_ms"] + + df_plot["job_enqueue_time_ms"] + + df_plot["boundary_distance_time_ms"] + + df_plot["job_wait_time_ms"] + + df_plot["result_aggregate_time_ms"] ) - df_plot['other_time_ms'] = df_plot['other_time_ms'].apply(lambda x: x if x >= 0 else 0) - other_time_stats = grouped['other_time_ms'].agg(['min', 'mean', 'max']).reset_index() - other_time_stats.columns = ['Recall Target', 'Method', 'other_time_ms min', 'other_time_ms mean', 'other_time_ms max'] - stats = pd.merge(stats, other_time_stats, on=['Recall Target', 'Method']) + df_plot["other_time_ms"] = df_plot["other_time_ms"].apply(lambda x: x if x >= 0 else 0) + other_time_stats = grouped["other_time_ms"].agg(["min", "mean", "max"]).reset_index() + other_time_stats.columns = [ + "Recall Target", + "Method", + "other_time_ms min", + "other_time_ms mean", + "other_time_ms max", + ] + stats = pd.merge(stats, other_time_stats, on=["Recall Target", "Method"]) # Compute p99 for 'Query Time (ms)' and 'Recall' - p95_latency = grouped['Query Time (ms)'].quantile(0.95).reset_index() - p95_recall = grouped['Recall'].quantile(0.05).reset_index() - p95_latency.rename(columns={'Query Time (ms)': 'Query Time (ms) p95'}, inplace=True) - p95_recall.rename(columns={'Recall': 'Recall p95'}, inplace=True) - stats = stats.merge(p95_latency, on=['Recall Target', 'Method']) - stats = stats.merge(p95_recall, on=['Recall Target', 'Method']) + p95_latency = grouped["Query Time (ms)"].quantile(0.95).reset_index() + p95_recall = grouped["Recall"].quantile(0.05).reset_index() + p95_latency.rename(columns={"Query Time (ms)": "Query Time (ms) p95"}, inplace=True) + p95_recall.rename(columns={"Recall": "Recall p95"}, inplace=True) + stats = stats.merge(p95_latency, on=["Recall Target", "Method"]) + stats = stats.merge(p95_recall, on=["Recall Target", "Method"]) # Save this as intermediate results - stats_save_path = cfg.paths.plot_dir / 'all_intermediate_stats.csv' - stats.to_csv(stats_save_path, index = False) - + stats_save_path = cfg.paths.plot_dir / "all_intermediate_stats.csv" + stats.to_csv(stats_save_path, index=False) + # Plotting plot_recall_only(df_plot, stats, cfg.paths.plot_dir) plot_mean_line_plots(df_plot, stats, cfg.paths.plot_dir) plot_query_overheads(stats, cfg.paths.plot_dir) -palette = {'Oracle': 'C0', 'APS': 'C1'} + +palette = {"Oracle": "C0", "APS": "C1"} + + def plot_recall_only(df_plot, stats, plot_dir): sns.set_style("whitegrid") - sns.set_context("talk", font_scale=.8) - plt.rcParams['font.weight'] = 'bold' - plt.rcParams['axes.labelweight'] = 'bold' + sns.set_context("talk", font_scale=0.8) + plt.rcParams["font.weight"] = "bold" + plt.rcParams["axes.labelweight"] = "bold" fig, ax = plt.subplots(figsize=(6, 4)) - stats['Method'] = stats['Method'].replace('Adaptive nprobe', 'APS') + stats["Method"] = stats["Method"].replace("Adaptive nprobe", "APS") # Shared x-axis range - x_min = stats['Recall Target'].min() - x_max = stats['Recall Target'].max() + x_min = stats["Recall Target"].min() + x_max = stats["Recall Target"].max() # Generate detailed x-axis ticks x_ticks = [0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0] - ax.set_title('Sift1M: Measured Recall vs. Recall Target', fontsize=14, fontweight='bold') - - for method in stats['Method'].unique(): - method_stats = stats[stats['Method'] == method].sort_values('Recall Target') - ax.plot(method_stats['Recall Target'], method_stats['Recall mean'], - label=method, color=palette[method], linewidth=4, marker='o', markersize=10) + ax.set_title("Sift1M: Measured Recall vs. Recall Target", fontsize=14, fontweight="bold") + + for method in stats["Method"].unique(): + method_stats = stats[stats["Method"] == method].sort_values("Recall Target") + ax.plot( + method_stats["Recall Target"], + method_stats["Recall mean"], + label=method, + color=palette[method], + linewidth=4, + marker="o", + markersize=10, + ) # plot p95 recall # p95_recall = method_stats['Recall p95'] # ax.plot(method_stats['Recall Target'], p95_recall, color=palette[method], linestyle='--', linewidth=8) x_vals = np.linspace(x_min, x_max, 100) - ax.plot(x_vals, x_vals, color='black', linestyle='--', linewidth=4, label='Recall Target') + ax.plot(x_vals, x_vals, color="black", linestyle="--", linewidth=4, label="Recall Target") - ax.set_ylabel('Recall', fontsize=16, fontweight='bold') - ax.tick_params(axis='both', which='major', labelsize=16) + ax.set_ylabel("Recall", fontsize=16, fontweight="bold") + ax.tick_params(axis="both", which="major", labelsize=16) ax.set_yticks([0.7, 0.8, 0.9, 1.0]) # Custom y-axis ticks ax.set_xticks(x_ticks) - ax.legend(fontsize=8, prop={'weight': 'bold'}, loc='lower right') # Legend moved to top-left + ax.legend(fontsize=8, prop={"weight": "bold"}, loc="lower right") # Legend moved to top-left ax.grid(True) plt.tight_layout() - plot_path = plot_dir / 'recall_plot.pdf' - plt.savefig(plot_path, bbox_inches='tight') + plot_path = plot_dir / "recall_plot.pdf" + plt.savefig(plot_path, bbox_inches="tight") log.info(f"Plot saved to {plot_path}") plt.show() - # Plotting Functions def plot_mean_line_plots(df_plot, stats, plot_dir): sns.set_style("whitegrid") - sns.set_context("talk", font_scale=.8) - plt.rcParams['font.weight'] = 'bold' - plt.rcParams['axes.labelweight'] = 'bold' + sns.set_context("talk", font_scale=0.8) + plt.rcParams["font.weight"] = "bold" + plt.rcParams["axes.labelweight"] = "bold" # Create a vertically stacked layout with shared x-axis fig, axes = plt.subplots(3, 1, figsize=(8, 6), sharex=True) # Stacked plots with shared x-axis - stats['Method'] = stats['Method'].replace('Adaptive nprobe', 'APS') + stats["Method"] = stats["Method"].replace("Adaptive nprobe", "APS") # Shared x-axis range - x_min = stats['Recall Target'].min() - x_max = stats['Recall Target'].max() + x_min = stats["Recall Target"].min() + x_max = stats["Recall Target"].max() # Generate detailed x-axis ticks x_ticks = [0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0] - axes[0].set_title('Sift1M: Recall, Query Latency, and Nprobe vs. Recall Target', fontsize=20, fontweight='bold') + axes[0].set_title("Sift1M: Recall, Query Latency, and Nprobe vs. Recall Target", fontsize=20, fontweight="bold") # Plot 1: Recall - for method in stats['Method'].unique(): - method_stats = stats[stats['Method'] == method].sort_values('Recall Target') - axes[0].plot(method_stats['Recall Target'], method_stats['Recall mean'], - label=method, color=palette[method], linewidth=8, marker='o', markersize=10) + for method in stats["Method"].unique(): + method_stats = stats[stats["Method"] == method].sort_values("Recall Target") + axes[0].plot( + method_stats["Recall Target"], + method_stats["Recall mean"], + label=method, + color=palette[method], + linewidth=8, + marker="o", + markersize=10, + ) # plot p95 recall # p95_recall = method_stats['Recall p95'] # axes[0].plot(method_stats['Recall Target'], p95_recall, color=palette[method], linestyle='--', linewidth=8) - - # Diagonal reference line x_vals = np.linspace(x_min, x_max, 100) - axes[0].plot(x_vals, x_vals, color='black', linestyle='--', linewidth=8, label='Recall Target') + axes[0].plot(x_vals, x_vals, color="black", linestyle="--", linewidth=8, label="Recall Target") - axes[0].set_ylabel('Recall', fontsize=16, fontweight='bold') - axes[0].tick_params(axis='both', which='major', labelsize=16) + axes[0].set_ylabel("Recall", fontsize=16, fontweight="bold") + axes[0].tick_params(axis="both", which="major", labelsize=16) axes[0].set_yticks([0.7, 0.8, 0.9, 1.0]) # Custom y-axis ticks axes[0].set_xticks(x_ticks) - axes[0].legend(fontsize=8, prop={'weight': 'bold'}, loc='lower right') # Legend moved to top-left + axes[0].legend(fontsize=8, prop={"weight": "bold"}, loc="lower right") # Legend moved to top-left axes[0].grid(True) # Plot 2: Query Time - for method in stats['Method'].unique(): - method_stats = stats[stats['Method'] == method].sort_values('Recall Target') - axes[1].plot(method_stats['Recall Target'], method_stats['Query Time (ms) mean'], - label=method, color=palette[method], linewidth=8, marker='o', markersize=10) + for method in stats["Method"].unique(): + method_stats = stats[stats["Method"] == method].sort_values("Recall Target") + axes[1].plot( + method_stats["Recall Target"], + method_stats["Query Time (ms) mean"], + label=method, + color=palette[method], + linewidth=8, + marker="o", + markersize=10, + ) # plot p95 query time - p95_query_time = method_stats['Query Time (ms) p95'] - axes[1].plot(method_stats['Recall Target'], p95_query_time, color=palette[method], linestyle='--', linewidth=8) + p95_query_time = method_stats["Query Time (ms) p95"] + axes[1].plot(method_stats["Recall Target"], p95_query_time, color=palette[method], linestyle="--", linewidth=8) # axes[1].set_ylabel('Query Time (ms)', fontsize=20, fontweight='bold') - axes[1].set_yscale('log') - axes[1].tick_params(axis='both', which='major', labelsize=16) + axes[1].set_yscale("log") + axes[1].tick_params(axis="both", which="major", labelsize=16) axes[1].set_yticks([0.2, 0.4, 0.8, 1.6, 3.2]) # Custom y-axis ticks - axes[1].set_yticklabels(['0.2ms', '0.4ms', '0.8ms', '1.6ms', '3.2ms']) # Add units to tick labels + axes[1].set_yticklabels(["0.2ms", "0.4ms", "0.8ms", "1.6ms", "3.2ms"]) # Add units to tick labels axes[1].set_xticks(x_ticks) # axes[1].legend(fontsize=16, prop={'weight': 'bold'}, loc='upper left') # Legend moved to top-left axes[1].grid(True) # # # Plot 3: nprobe - for method in stats['Method'].unique(): - method_stats = stats[stats['Method'] == method].sort_values('Recall Target') - axes[2].plot(method_stats['Recall Target'], method_stats['nprobe mean'], - label=method, color=palette[method], linewidth=8, marker='o', markersize=10) - - axes[2].set_xlabel('Recall Target', fontsize=16, fontweight='bold') # Shared x-axis - axes[2].set_ylabel('nprobe', fontsize=16, fontweight='bold') - axes[2].set_yscale('log') - axes[2].tick_params(axis='both', which='major', labelsize=16) + for method in stats["Method"].unique(): + method_stats = stats[stats["Method"] == method].sort_values("Recall Target") + axes[2].plot( + method_stats["Recall Target"], + method_stats["nprobe mean"], + label=method, + color=palette[method], + linewidth=8, + marker="o", + markersize=10, + ) + + axes[2].set_xlabel("Recall Target", fontsize=16, fontweight="bold") # Shared x-axis + axes[2].set_ylabel("nprobe", fontsize=16, fontweight="bold") + axes[2].set_yscale("log") + axes[2].tick_params(axis="both", which="major", labelsize=16) axes[2].set_yticks([4, 8, 16, 32, 64]) # Custom y-axis ticks - axes[2].set_yticklabels(['4', '8', '16', '32', '64']) # Add units to tick labels + axes[2].set_yticklabels(["4", "8", "16", "32", "64"]) # Add units to tick labels axes[2].set_xticks(x_ticks) # axes[2].legend(fontsize=16, prop={'weight': 'bold'}, loc='upper left') # Legend moved to top-left axes[2].grid(True) plt.tight_layout() - plot_path = plot_dir / 'mean_line_plots_stacked.pdf' - plt.savefig(plot_path, bbox_inches='tight') + plot_path = plot_dir / "mean_line_plots_stacked.pdf" + plt.savefig(plot_path, bbox_inches="tight") log.info(f"Plot saved to {plot_path}") plt.show() + def plot_query_overheads(stats, plot_dir): """ Plot stacked bar charts showing the mean time taken by each component for each method and recall target, @@ -337,35 +372,35 @@ def plot_query_overheads(stats, plot_dir): # Define the components and their labels components = [ - 'buffer_init_time_ms', - 'job_enqueue_time_ms', - 'boundary_distance_time_ms', - 'job_wait_time_ms', - 'result_aggregate_time_ms', - 'other_time_ms' # Added 'Other' category + "buffer_init_time_ms", + "job_enqueue_time_ms", + "boundary_distance_time_ms", + "job_wait_time_ms", + "result_aggregate_time_ms", + "other_time_ms", # Added 'Other' category ] component_labels = [ - 'Buffer Init', - 'Job Enqueue', - 'Boundary Distance', - 'Job Wait', - 'Result Aggregate', - 'Other' # Label for 'Other' + "Buffer Init", + "Job Enqueue", + "Boundary Distance", + "Job Wait", + "Result Aggregate", + "Other", # Label for 'Other' ] # Blue, Orange, Green, Red, Purple, Gray] - component_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#8c564b', '#7f7f7f'] + component_colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#8c564b", "#7f7f7f"] # remap the method names: Adaptive nprobe -> Adaptive Partition Scanning # stats[methods == 'Adaptive nprobe'] = 'Adaptive Partition Scanning' - stats['Method'] = stats['Method'].replace('Adaptive nprobe', 'APS') + stats["Method"] = stats["Method"].replace("Adaptive nprobe", "APS") # Define methods and their hatching patterns - methods = stats['Method'].unique() + methods = stats["Method"].unique() - method_hatches = {method: '' if method == 'Oracle' else '//' for method in methods} + method_hatches = {method: "" if method == "Oracle" else "//" for method in methods} # Prepare data for plotting - recall_targets = sorted(stats['Recall Target'].unique()) + recall_targets = sorted(stats["Recall Target"].unique()) # Initialize plotting fig, ax = plt.subplots(figsize=(14, 8)) @@ -379,11 +414,11 @@ def plot_query_overheads(stats, plot_dir): # Set positions for each method's bars # method_offsets = np.linspace(-bar_width/2, bar_width/2, len(methods)) positions = {} - positions['Oracle'] = index - 3 * bar_width / 2 - positions['APS'] = index - bar_width / 2 - positions['APS-R'] = index + bar_width / 2 - positions['APS-RP'] = index + 3 * bar_width/2 - positions['FixedNProbe'] = index + 5 * bar_width / 2 + positions["Oracle"] = index - 3 * bar_width / 2 + positions["APS"] = index - bar_width / 2 + positions["APS-R"] = index + bar_width / 2 + positions["APS-RP"] = index + 3 * bar_width / 2 + positions["FixedNProbe"] = index + 5 * bar_width / 2 # Initialize cumulative bottoms for stacking cumulative_bottoms = {method: np.zeros(n_groups) for method in methods} @@ -395,9 +430,9 @@ def plot_query_overheads(stats, plot_dir): comp_means = [] for rt in recall_targets: - row = stats[(stats['Recall Target'] == rt) & (stats['Method'] == method)] + row = stats[(stats["Recall Target"] == rt) & (stats["Method"] == method)] if not row.empty: - comp_time = row[f'{comp} mean'].values[0] + comp_time = row[f"{comp} mean"].values[0] else: comp_time = 0 comp_means.append(comp_time) @@ -405,7 +440,7 @@ def plot_query_overheads(stats, plot_dir): method_lookup = method.split("+")[0].strip() pos = positions[method_lookup] - bars = ax.bar( + ax.bar( pos, comp_means, bar_width, @@ -413,72 +448,50 @@ def plot_query_overheads(stats, plot_dir): color=comp_color, alpha=opacity, hatch=method_hatches[method], - edgecolor='black', - label=comp_label if (method == methods[0]) else "" + edgecolor="black", + label=comp_label if (method == methods[0]) else "", ) cumulative_bottoms[method] += comp_means # Set x-axis labels and ticks - ax.set_xlabel('Recall Target', fontsize=14, weight='bold') - ax.set_ylabel('Time (ms)', fontsize=14, weight='bold') - ax.set_title('Query Overheads by Method and Recall Target', fontsize=16, weight='bold') + ax.set_xlabel("Recall Target", fontsize=14, weight="bold") + ax.set_ylabel("Time (ms)", fontsize=14, weight="bold") + ax.set_title("Query Overheads by Method and Recall Target", fontsize=16, weight="bold") ax.set_xticks(index) ax.set_xticklabels([f"{rt:.2f}" for rt in recall_targets], fontsize=12) - ax.tick_params(axis='y', labelsize=12) + ax.tick_params(axis="y", labelsize=12) # Create custom legend for components - component_handles = [ - Patch(facecolor=color, edgecolor='black') for color in component_colors - ] + component_handles = [Patch(facecolor=color, edgecolor="black") for color in component_colors] component_legend = ax.legend( - component_handles, - component_labels, - title='Components', - title_fontsize=14, - fontsize=12, - loc='upper left' + component_handles, component_labels, title="Components", title_fontsize=14, fontsize=12, loc="upper left" ) # Create custom legend for methods method_handles = [ - Patch(facecolor='white', edgecolor='black', hatch=method_hatches[method], label=method) - for method in methods + Patch(facecolor="white", edgecolor="black", hatch=method_hatches[method], label=method) for method in methods ] - method_legend = ax.legend( - method_handles, - methods, - title='Method', - title_fontsize=14, - fontsize=12, - loc='upper right' - ) + ax.legend(method_handles, methods, title="Method", title_fontsize=14, fontsize=12, loc="upper right") # Add the component legend back to the plot ax.add_artist(component_legend) # Add grid for better readability - ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.5) + ax.yaxis.grid(True, linestyle="--", which="major", color="grey", alpha=0.5) # set ylimit ax.set_ylim(0, 2.5) plt.tight_layout() - plot_path = plot_dir / 'query_overheads.png' + plot_path = plot_dir / "query_overheads.png" plt.savefig(plot_path) log.info(f"Overhead plot saved to {plot_path}") plt.show() + # Helper Functions def run_experiment_for_configuration( - index, - queries, - gt, - k, - recall_target, - oracle=False, - recompute_ratio=.05, - use_precompute=True, - fixed_nprobe=False + index, queries, gt, k, recall_target, oracle=False, recompute_ratio=0.05, use_precompute=True, fixed_nprobe=False ): timing_infos = [] @@ -584,14 +597,17 @@ def run_experiment_for_configuration( search_params = SearchParams() search_params.nprobe = -1 search_params.k = k - search_params.initial_search_fraction = .1 + search_params.initial_search_fraction = 0.1 search_params.recall_target = recall_target search_params.recompute_threshold = 0.01 search_params.use_precomputed = use_precompute search_params.num_threads = 1 # debug print search params - print(f"Search Params: {search_params.nprobe}, {search_params.k}, {search_params.recall_target}, {search_params.recompute_threshold}, {search_params.use_precomputed}") + print( + f"Search Params: {search_params.nprobe}, {search_params.k}, {search_params.recall_target}, " + f"{search_params.recompute_threshold}, {search_params.use_precomputed}" + ) for query in queries: search_result = index.search(query.unsqueeze(0), search_params) @@ -601,7 +617,7 @@ def run_experiment_for_configuration( timing_infos.append(timing_info) ids = torch.cat(all_ids, dim=0) - dists = torch.cat(all_dists, dim=0) + # dists = torch.cat(all_dists, dim=0) recalls = compute_recall(ids, gt, k) per_query_data = [] @@ -615,26 +631,30 @@ def run_experiment_for_configuration( total_time_ms = timing_info.total_time_ns / 1e6 per_query_recall = float(recalls[i]) - per_query_data.append({ - 'nprobe': per_query_nprobe, - 'recall': per_query_recall, - 'buffer_init_time_ms': buffer_init_time_ms, - 'job_enqueue_time_ms': job_enqueue_time_ms, - 'boundary_distance_time_ms': boundary_distance_time_ms, - 'job_wait_time_ms': job_wait_time_ms, - 'result_aggregate_time_ms': result_aggregate_time_ms, - 'total_time_ms': total_time_ms, - }) + per_query_data.append( + { + "nprobe": per_query_nprobe, + "recall": per_query_recall, + "buffer_init_time_ms": buffer_init_time_ms, + "job_enqueue_time_ms": job_enqueue_time_ms, + "boundary_distance_time_ms": boundary_distance_time_ms, + "job_wait_time_ms": job_wait_time_ms, + "result_aggregate_time_ms": result_aggregate_time_ms, + "total_time_ms": total_time_ms, + } + ) per_query_data_df = pd.DataFrame(per_query_data) # per_query_data_df = per_query_data_df[1:] # Skip the first query return per_query_data_df + def get_nprobe_for_recall_target(recall_target, nlist): nprobe = int(nlist * recall_target / 2) nprobe = max(1, min(nprobe, nlist)) return nprobe + @hydra.main(config_path="configs", config_name="sift1m") def main(cfg: DictConfig): # Set up directories @@ -647,7 +667,7 @@ def main(cfg: DictConfig): cfg.paths.plot_dir = base_dir / cfg.paths.plot_dir cfg.paths.plot_dir.mkdir(parents=True, exist_ok=True) - if cfg.mode == 'run': + if cfg.mode == "run": # Prepare experiment parameters methods = cfg.methods recall_targets = cfg.experiment.recall_targets @@ -655,23 +675,24 @@ def main(cfg: DictConfig): experiment_args = [] for method in methods: for recall_target in recall_targets: - if method == 'Oracle': - experiment_args.append((method, recall_target, .05, True, cfg, 'execute_queries', n_workers)) - elif method == 'APS': - experiment_args.append((method, recall_target, 0.0, True, cfg, 'execute_queries', n_workers)) - elif method == 'APS-R': - experiment_args.append((method, recall_target, -1, True, cfg, 'execute_queries', n_workers)) - elif method == 'APS-RP': - experiment_args.append((method, recall_target, -1, False, cfg, 'execute_queries', n_workers)) - elif method == 'FixedNProbe': - experiment_args.append((method, recall_target, -1, True, cfg, 'execute_queries', n_workers)) + if method == "Oracle": + experiment_args.append((method, recall_target, 0.05, True, cfg, "execute_queries", n_workers)) + elif method == "APS": + experiment_args.append((method, recall_target, 0.0, True, cfg, "execute_queries", n_workers)) + elif method == "APS-R": + experiment_args.append((method, recall_target, -1, True, cfg, "execute_queries", n_workers)) + elif method == "APS-RP": + experiment_args.append((method, recall_target, -1, False, cfg, "execute_queries", n_workers)) + elif method == "FixedNProbe": + experiment_args.append((method, recall_target, -1, True, cfg, "execute_queries", n_workers)) # Run experiments for args in experiment_args: run_single_experiment(args) - elif cfg.mode == 'plot': + elif cfg.mode == "plot": collect_and_plot_results(cfg) -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/test/experiments/adaptive_partition_scanning/configs/sift1m.yaml b/test/experiments/adaptive_partition_scanning/configs/sift1m.yaml index 9c5cfce1..51e8f982 100644 --- a/test/experiments/adaptive_partition_scanning/configs/sift1m.yaml +++ b/test/experiments/adaptive_partition_scanning/configs/sift1m.yaml @@ -7,7 +7,7 @@ dataset: experiment: nq: 100 k: 100 - n_workers: 4 + n_workers: 0 recall_targets: [0.7, 0.75, 0.8, 0.85, 0.9, 0.95, .99] index: metric: l2 diff --git a/test/experiments/maintenance_ablation/configs/sift1m_write_heavy.yaml b/test/experiments/maintenance_ablation/configs/sift1m_write_heavy.yaml new file mode 100644 index 00000000..234da398 --- /dev/null +++ b/test/experiments/maintenance_ablation/configs/sift1m_write_heavy.yaml @@ -0,0 +1,48 @@ +# sift1m_read_only.yaml +seed: 1739 +mode: run +name: sift1m_write_heavy +overwrite: True + +dataset: + name: sift1m + path: data/sift + +index: + metric: l2 + nc: 1024 + + search: + k: 10 + recall_target: .9 + +maintenance_configs: + - name: no_maintenance + do_maintenance: False + + - name: wo_reject + do_maintenance: True + delete_threshold: 10.0 + split_threshold: 50.0 + refinement_radius: 50 + enable_delete_rejection: False + + - name: w_reject + do_maintenance: True + delete_threshold: 10.0 + split_threshold: 50.0 + refinement_radius: 50 + +workload: + insert_ratio: .4 + delete_ratio: .1 + query_ratio: .5 + update_batch_size: 10000 + query_batch_size: 100 + number_of_operations: 1000 + initial_size: 100000 + cluster_size: 1000 + cluster_sample_distribution: skewed + +results_dir: results +workload_dir: workloads \ No newline at end of file diff --git a/test/experiments/maintenance_ablation/maintenance_ablation.py b/test/experiments/maintenance_ablation/maintenance_ablation.py new file mode 100644 index 00000000..76c38903 --- /dev/null +++ b/test/experiments/maintenance_ablation/maintenance_ablation.py @@ -0,0 +1,152 @@ +import subprocess +from pathlib import Path + +import pandas as pd +import yaml + +from quake import MaintenancePolicyParams +from quake.datasets.ann_datasets import load_dataset +from quake.index_wrappers.quake import QuakeWrapper +from quake.workload_generator import DynamicWorkloadGenerator, WorkloadEvaluator + + +def create_maintenance_params(m_config): + """ + Create a MaintenancePolicyParams object from a maintenance configuration. + Here we assume that the YAML thresholds are given in milliseconds and convert them to nanoseconds. + """ + m_params = MaintenancePolicyParams() + if "delete_threshold" in m_config: + m_params.delete_threshold_ns = m_config["delete_threshold"] + if "split_threshold" in m_config: + m_params.split_threshold_ns = m_config["split_threshold"] + if "refinement_radius" in m_config: + m_params.refinement_radius = m_config["refinement_radius"] + if "refinement_iterations" in m_config: + m_params.refinement_iterations = m_config["refinement_iterations"] + if "enable_delete_rejection" in m_config: + m_params.enable_delete_rejection = m_config["enable_delete_rejection"] + + return m_params + + +def run_experiment_for_config(m_config, config): + print(f"\n=== Running maintenance config: {m_config['name']} ===") + + # Load dataset using the configuration. + dataset_name = config["dataset"]["name"] + dataset_path = config["dataset"].get("path", "data") + vectors, queries, gt = load_dataset(dataset_name, dataset_path) + + # Set up workload parameters. + workload_cfg = config["workload"] + base_experiment_name = config["name"] + workload_dir = Path(config.get("workload_dir", "workloads")) / base_experiment_name + results_dir = Path(config.get("results_dir", "results")) / base_experiment_name / m_config["name"] + workload_dir.mkdir(parents=True, exist_ok=True) + results_dir.mkdir(parents=True, exist_ok=True) + + # Initialize the dynamic workload generator. + workload_gen = DynamicWorkloadGenerator( + workload_dir=workload_dir, + base_vectors=vectors, + metric=config["index"]["metric"], + insert_ratio=workload_cfg["insert_ratio"], + delete_ratio=workload_cfg["delete_ratio"], + query_ratio=workload_cfg["query_ratio"], + update_batch_size=workload_cfg["update_batch_size"], + query_batch_size=workload_cfg["query_batch_size"], + number_of_operations=workload_cfg["number_of_operations"], + initial_size=workload_cfg["initial_size"], + cluster_size=workload_cfg["cluster_size"], + cluster_sample_distribution=workload_cfg["cluster_sample_distribution"], + queries=queries, + seed=config.get("seed", 1738), + ) + + # Generate the workload if it doesn't already exist. + if not workload_gen.workload_exists(): + print("Generating workload...") + workload_gen.generate_workload() + else: + print("Workload already exists; reusing generated workload.") + + # Build a fresh index. + index = QuakeWrapper() + build_params = {"nc": config["index"].get("nc", 1024), "metric": config["index"]["metric"]} + + # Initialize maintenance policy if enabled. + do_maint = m_config.get("do_maintenance", False) + if do_maint: + m_params = create_maintenance_params(m_config) + print(f"Maintenance policy initialized: {m_params}") + else: + m_params = None + print("Maintenance disabled for this configuration.") + + # Set up and run the workload evaluator. + evaluator = WorkloadEvaluator(workload_dir=workload_dir, output_dir=results_dir) + search_params = config["index"]["search"] + results = evaluator.evaluate_workload( + name="quake_test", + index=index, + build_params=build_params, + search_params=search_params, + do_maintenance=do_maint, + m_params=m_params, + ) + + return results + + +def run_experiments_and_compare(): + # Load the overall configuration. + script_dir = Path(__file__).resolve().parent + # config_path = script_dir / Path("configs/sift1m_read_only.yaml") + config_path = script_dir / Path("configs/sift1m_write_heavy.yaml") + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + base_results_dir = Path(config.get("results_dir", "results")) / config["name"] + base_results_dir.mkdir(parents=True, exist_ok=True) + + overwrite = config.get("overwrite", False) + + experiments_results = {} + # Loop over each maintenance configuration. + for m_config in config.get("maintenance_configs", []): + result_file = base_results_dir / m_config["name"] / "results.csv" + + if result_file.exists() and not overwrite: + print(f"Results already exist for maintenance config '{m_config['name']}'. Loading results...") + df = pd.read_csv(result_file) + experiments_results[m_config["name"]] = df + else: + results = run_experiment_for_config(m_config, config) + df = pd.DataFrame(results) + result_file.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(result_file, index=False) + experiments_results[m_config["name"]] = df + + # After running experiments (or loading results), call the existing compare_results.py script + # to generate detailed per-operation plots and the aggregate matrix. + results_root = Path(config.get("results_dir", "results")) + output_aggregate = results_root / "aggregate_matrix.png" + print("Generating detailed comparison plots using compare_results.py ...") + subprocess.run( + [ + "python", + "test/python/regression/compare_results.py", + "--results_dir", + str(results_root), + "--plot_type", + "both", + "--output_aggregate", + str(output_aggregate), + ], + check=True, + ) + + +if __name__ == "__main__": + run_experiments_and_compare() diff --git a/test/experiments/query_processing_perf/configs/sift1m_read_only.yaml b/test/experiments/query_processing_perf/configs/sift1m_read_only.yaml new file mode 100644 index 00000000..47ed99ac --- /dev/null +++ b/test/experiments/query_processing_perf/configs/sift1m_read_only.yaml @@ -0,0 +1,82 @@ +seed: 1739 +mode: run +name: sift1m_read_only +overwrite: True + +dataset: + name: sift1m + path: data/sift + +index: + metric: l2 + nc: 1024 + do_maintenance: False + + search: + k: 10 + nprobe: 20 + batched_scan: True + +configs: +# - name: serial_scan1 +# build: +# n_workers: 0 +# search: +# n_threads: 1 +# +# - name: serial_scan2 +# build: +# n_workers: 0 +# search: +# n_threads: 2 +## +# - name: serial_scan4 +# build: +# n_workers: 0 +# search: +# n_threads: 4 +## +# - name: serial_scan8 +# build: +# n_workers: 0 +# search: +# n_threads: 8 + + - name: worker_scan1 + build: + n_workers: 1 + search: + n_threads: 1 + + - name: worker_scan2 + build: + n_workers: 2 + search: + n_threads: 1 + + - name: worker_scan4 + build: + n_workers: 4 + search: + n_threads: 1 + + + - name: worker_scan8 + build: + n_workers: 8 + search: + n_threads: 1 + +workload: + insert_ratio: 0.0 + delete_ratio: 0.0 + query_ratio: 1.0 + update_batch_size: 10000 + query_batch_size: 100 + number_of_operations: 100 + initial_size: 1000000 + cluster_size: 1000 + cluster_sample_distribution: uniform + +results_dir: results +workload_dir: workloads \ No newline at end of file diff --git a/test/experiments/query_processing_perf/vary_workers.py b/test/experiments/query_processing_perf/vary_workers.py new file mode 100644 index 00000000..3ed59549 --- /dev/null +++ b/test/experiments/query_processing_perf/vary_workers.py @@ -0,0 +1,150 @@ +import subprocess +from pathlib import Path + +import pandas as pd +import yaml + +from quake import MaintenancePolicyParams +from quake.datasets.ann_datasets import load_dataset +from quake.index_wrappers.quake import QuakeWrapper +from quake.workload_generator import DynamicWorkloadGenerator, WorkloadEvaluator + + +def create_maintenance_params(m_config): + """ + Create a MaintenancePolicyParams object from a maintenance configuration. + Here we assume that the YAML thresholds are given in milliseconds and convert them to nanoseconds. + """ + m_params = MaintenancePolicyParams() + if "delete_threshold" in m_config: + m_params.delete_threshold_ns = m_config["delete_threshold"] + if "split_threshold" in m_config: + m_params.split_threshold_ns = m_config["split_threshold"] + if "refinement_radius" in m_config: + m_params.refinement_radius = m_config["refinement_radius"] + if "refinement_iterations" in m_config: + m_params.refinement_iterations = m_config["refinement_iterations"] + if "enable_delete_rejection" in m_config: + m_params.enable_delete_rejection = m_config["enable_delete_rejection"] + + return m_params + + +def run_experiment_for_config(curr_config, config): + print(f"\n=== Running maintenance config: {curr_config['name']} ===") + + # Load dataset using the configuration. + dataset_name = config["dataset"]["name"] + dataset_path = config["dataset"].get("path", "data") + vectors, queries, gt = load_dataset(dataset_name, dataset_path) + + # Set up workload parameters. + workload_cfg = config["workload"] + base_experiment_name = config["name"] + workload_dir = Path(config.get("workload_dir", "workloads")) / base_experiment_name + results_dir = Path(config.get("results_dir", "results")) / base_experiment_name / curr_config["name"] + workload_dir.mkdir(parents=True, exist_ok=True) + results_dir.mkdir(parents=True, exist_ok=True) + + # Initialize the dynamic workload generator. + workload_gen = DynamicWorkloadGenerator( + workload_dir=workload_dir, + base_vectors=vectors, + metric=config["index"]["metric"], + insert_ratio=workload_cfg["insert_ratio"], + delete_ratio=workload_cfg["delete_ratio"], + query_ratio=workload_cfg["query_ratio"], + update_batch_size=workload_cfg["update_batch_size"], + query_batch_size=workload_cfg["query_batch_size"], + number_of_operations=workload_cfg["number_of_operations"], + initial_size=workload_cfg["initial_size"], + cluster_size=workload_cfg["cluster_size"], + cluster_sample_distribution=workload_cfg["cluster_sample_distribution"], + queries=queries, + seed=config.get("seed", 1738), + ) + + # Generate the workload if it doesn't already exist. + if not workload_gen.workload_exists(): + print("Generating workload...") + workload_gen.generate_workload() + else: + print("Workload already exists; reusing generated workload.") + + # Build a fresh index. + index = QuakeWrapper() + build_params = { + "nc": config["index"].get("nc", 1024), + "metric": config["index"]["metric"], + "num_workers": curr_config["build"]["n_workers"], + } + + # Initialize maintenance policy if enabled. + do_maint = config.get("do_maintenance", False) + + # Set up and run the workload evaluator. + evaluator = WorkloadEvaluator(workload_dir=workload_dir, output_dir=results_dir) + search_params = config["index"]["search"] + search_params["n_threads"] = curr_config["search"]["n_threads"] + results = evaluator.evaluate_workload( + name="quake_test", + index=index, + build_params=build_params, + search_params=search_params, + do_maintenance=do_maint, + batch=True, + ) + + return results + + +def run_experiments_and_compare(): + # Load the overall configuration. + script_dir = Path(__file__).resolve().parent + config_path = script_dir / Path("configs/sift1m_read_only.yaml") + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + base_results_dir = Path(config.get("results_dir", "results")) / config["name"] + base_results_dir.mkdir(parents=True, exist_ok=True) + + overwrite = config.get("overwrite", False) + + experiments_results = {} + # Loop over each maintenance configuration. + for curr_config in config.get("configs", []): + result_file = base_results_dir / curr_config["name"] / "results.csv" + + if result_file.exists() and not overwrite: + print(f"Results already exist for maintenance config '{curr_config['name']}'. Loading results...") + df = pd.read_csv(result_file) + experiments_results[curr_config["name"]] = df + else: + results = run_experiment_for_config(curr_config, config) + df = pd.DataFrame(results) + result_file.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(result_file, index=False) + experiments_results[curr_config["name"]] = df + + # After running experiments (or loading results), call the existing compare_results.py script + # to generate detailed per-operation plots and the aggregate matrix. + results_root = Path(config.get("results_dir", "results")) + output_aggregate = results_root / "aggregate_matrix.png" + print("Generating detailed comparison plots using compare_results.py ...") + subprocess.run( + [ + "python", + "test/python/regression/compare_results.py", + "--results_dir", + str(results_root), + "--plot_type", + "both", + "--output_aggregate", + str(output_aggregate), + ], + check=True, + ) + + +if __name__ == "__main__": + run_experiments_and_compare() diff --git a/test/python/regression/compare_results.py b/test/python/regression/compare_results.py new file mode 100644 index 00000000..9d1c787a --- /dev/null +++ b/test/python/regression/compare_results.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python +import argparse +import json +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + + +def setup_logging(log_level: int = logging.INFO) -> None: + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") + + +def load_results(results_dir: Path) -> pd.DataFrame: + """ + Recursively load all CSV files named "results.csv" from subdirectories of results_dir. + Each CSV is assumed to come from one configuration (workload) run. + """ + csv_files = list(results_dir.rglob("results.csv")) + if not csv_files: + raise ValueError(f"No CSV results found in {results_dir}") + + dfs = [] + for csv_file in csv_files: + df = pd.read_csv(csv_file) + # Assume the parent directory's name is the maintenance configuration name. + df["maintenance_config"] = csv_file.parent.name + dfs.append(df) + combined_df = pd.concat(dfs, ignore_index=True) + # If the 'method' column is missing, use maintenance_config as the method. + if "method" not in combined_df.columns: + combined_df["method"] = combined_df["maintenance_config"] + logging.info(f"Loaded {len(dfs)} CSV files with a total of {combined_df.shape[0]} rows.") + return combined_df + + +def compare_metrics( + df: pd.DataFrame, thresholds: Dict[str, float] +) -> Tuple[Dict[str, Dict[str, pd.Series]], List[Tuple[str, str, str, float, float, float]]]: + """ + Compute average metrics per maintenance configuration and flag regression failures. + For latency (lower is better) and recall (higher is better), compare averages against thresholds. + """ + comparisons = {} + regression_failures = [] + for config, subdf in df.groupby("maintenance_config"): + metrics_to_compare = [metric for metric in thresholds.keys() if metric in subdf.columns] + if not metrics_to_compare: + continue + avg_by_config = subdf.groupby("maintenance_config")[metrics_to_compare].mean(numeric_only=True) + comparisons[config] = {} + for metric, threshold in thresholds.items(): + if metric not in avg_by_config.columns: + continue + series = avg_by_config[metric] + comparisons[config][metric] = series + if metric == "latency_ms": + best = series.min() + for method, value in series.items(): + diff = (value - best) / best + if diff > threshold: + regression_failures.append((metric, config, method, best, value, diff)) + elif metric == "recall": + best = series.max() + for method, value in series.items(): + diff = (best - value) / best + if diff > threshold: + regression_failures.append((metric, config, method, best, value, diff)) + return comparisons, regression_failures + + +def plot_aggregate_matrix(df: pd.DataFrame, metrics: List[str], output_path: Path) -> None: + """ + Create an aggregate matrix plot showing the average value for each metric (row) p + er maintenance configuration (column). + """ + # Group by maintenance_config and compute the mean of each metric. + agg_data = {} + for metric in metrics: + agg_data[metric] = df.groupby("maintenance_config")[metric].mean() + agg_df = pd.DataFrame(agg_data).T # rows are metrics, columns are maintenance configurations + + fig, ax = plt.subplots(figsize=(8, 4)) + cax = ax.imshow(agg_df.values, cmap="viridis", aspect="auto") + # Set x-axis labels (maintenance configurations) + ax.set_xticks(np.arange(agg_df.shape[1])) + ax.set_xticklabels(agg_df.columns, rotation=45, ha="right") + # Set y-axis labels (metrics) + ax.set_yticks(np.arange(agg_df.shape[0])) + ax.set_yticklabels(agg_df.index) + ax.set_title("Aggregate Metrics per Workload") + # Annotate each cell with the value. + for i in range(agg_df.shape[0]): + for j in range(agg_df.shape[1]): + ax.text(j, i, f"{agg_df.iloc[i, j]:.2f}", ha="center", va="center", color="w") + fig.colorbar(cax, ax=ax) + plt.tight_layout() + plt.savefig(output_path) + logging.info(f"Aggregate matrix plot saved to {output_path}") + plt.show() + + +def plot_joint_detailed_per_operation_all(df: pd.DataFrame, output_path: Path) -> None: + """ + Generate a joint detailed per-operation plot that shows multiple maintenance configurations + (methods) on the same axes. Four subplots are created for key metrics: + - Query latency (ms) + - Query recall + - Resident set size (n_resident) + - Number of partitions (n_list) + For query-specific metrics (latency and recall), only operations of type 'query' are used. + """ + # Filter query operations for latency and recall. + query_df = df[df["operation_type"] == "query"] + configs = df["maintenance_config"].unique() + + fig, axs = plt.subplots(2, 2, figsize=(12, 10)) + axs = axs.flatten() + + # Define the metrics to plot along with labels and their source dataframe. + metrics = [ + ("latency_ms", "Query Latency (ms)", query_df), + ("recall", "Query Recall", query_df), + ("n_resident", "Resident Set Size", df), + ("n_list", "Number of Partitions", df), + ] + + for ax, (metric, ylabel, data_source) in zip(axs, metrics): + for config in configs: + data = data_source[data_source["maintenance_config"] == config] + data = data.sort_values("operation_number") + if not data.empty: + ax.plot(data["operation_number"], data[metric], marker="o", label=config) + ax.set_xlabel("Operation Number") + ax.set_ylabel(ylabel) + ax.set_title(f"{ylabel} per Operation") + ax.legend() + + plt.tight_layout() + plt.savefig(output_path) + logging.info(f"Joint detailed per-operation plot saved to {output_path}") + plt.show() + + +def main(): + setup_logging() + parser = argparse.ArgumentParser( + description="Compare regression test results and produce aggregate and joint detailed plots." + ) + parser.add_argument( + "--results_dir", + type=Path, + default=Path("results"), + help="Directory containing result CSV files (searched recursively).", + ) + parser.add_argument( + "--plot_type", + type=str, + default="both", + choices=["aggregate", "detailed", "both"], + help="Type of plot to generate: 'aggregate', 'detailed' (joint), or 'both'.", + ) + parser.add_argument( + "--output_aggregate", + type=Path, + default=Path("aggregate_matrix.png"), + help="Output path for the aggregate matrix plot.", + ) + parser.add_argument( + "--detailed_output", + type=Path, + default=Path("detailed_joint.png"), + help="Output path for the joint detailed per-operation plot.", + ) + parser.add_argument( + "--thresholds", + type=str, + default='{"latency_ms": 0.05, "recall": 0.01}', + help="JSON string of thresholds for regression checks.", + ) + args = parser.parse_args() + + thresholds = json.loads(args.thresholds) + df = load_results(args.results_dir) + + comparisons, failures = compare_metrics(df, thresholds) + if failures: + logging.error("Regression failures detected:") + for metric, config, method, best, current, diff in failures: + logging.error( + f"Workload '{config}' method '{method}' for metric '{metric}': best = {best:.2f}, " + f"current = {current:.2f}, diff = {diff*100:.1f}%" + ) + else: + logging.info("No inter-method regressions detected within workloads.") + + if args.plot_type in ["aggregate", "both"]: + # Here the aggregate matrix shows workloads (maintenance_config) on the x-axis + # and each metric (e.g., latency and recall) as rows. + plot_aggregate_matrix(df, metrics=["latency_ms", "recall"], output_path=args.output_aggregate) + if args.plot_type in ["detailed", "both"]: + plot_joint_detailed_per_operation_all(df, output_path=args.detailed_output) + + +if __name__ == "__main__": + main() diff --git a/test/python/regression/configs/sift1m_balanced.yaml b/test/python/regression/configs/sift1m_balanced.yaml new file mode 100644 index 00000000..7230ee02 --- /dev/null +++ b/test/python/regression/configs/sift1m_balanced.yaml @@ -0,0 +1,30 @@ +# Experiment configuration for the regression test suite +seed: 1738 +mode: run +name: sift1m_balanced + +dataset: + name: sift1m + path: data/sift + +index: + metric: l2 + nc: 1024 + do_maintenance: True + + search: + k: 10 + +workload: + insert_ratio: 0.33 + delete_ratio: 0.33 + query_ratio: 0.34 + update_batch_size: 1000 + query_batch_size: 1000 + number_of_operations: 1000 + initial_size: 100000 + cluster_size: 100 + cluster_sample_distribution: uniform + +results_dir: results +workload_dir: workloads \ No newline at end of file diff --git a/test/python/regression/configs/sift1m_insert_heavy.yaml b/test/python/regression/configs/sift1m_insert_heavy.yaml new file mode 100644 index 00000000..330fa5e6 --- /dev/null +++ b/test/python/regression/configs/sift1m_insert_heavy.yaml @@ -0,0 +1,30 @@ +# sift1m_heavy_insert.yaml +seed: 1738 +mode: run +name: sift1m_heavy_insert + +dataset: + name: sift1m + path: data/sift + +index: + metric: l2 + nc: 1024 + do_maintenance: True + + search: + k: 10 + +workload: + insert_ratio: 0.80 + delete_ratio: 0.10 + query_ratio: 0.10 + update_batch_size: 1000 + query_batch_size: 1000 + number_of_operations: 1000 + initial_size: 100000 + cluster_size: 100 + cluster_sample_distribution: uniform + +results_dir: results +workload_dir: workloads \ No newline at end of file diff --git a/test/python/regression/configs/sift1m_read_only.yaml b/test/python/regression/configs/sift1m_read_only.yaml new file mode 100644 index 00000000..4b3a970c --- /dev/null +++ b/test/python/regression/configs/sift1m_read_only.yaml @@ -0,0 +1,30 @@ +# sift1m_read_only.yaml +seed: 1738 +mode: run +name: sift1m_read_only + +dataset: + name: sift1m + path: data/sift + +index: + metric: l2 + nc: 1024 + do_maintenance: True + + search: + k: 10 + +workload: + insert_ratio: 0.0 + delete_ratio: 0.0 + query_ratio: 1.0 + update_batch_size: 1000 + query_batch_size: 1000 + number_of_operations: 100 + initial_size: 1000000 + cluster_size: 100 + cluster_sample_distribution: uniform + +results_dir: results +workload_dir: workloads \ No newline at end of file diff --git a/test/python/regression/run_all_workloads.py b/test/python/regression/run_all_workloads.py new file mode 100644 index 00000000..5e85ec64 --- /dev/null +++ b/test/python/regression/run_all_workloads.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +import argparse +import subprocess +import sys +from pathlib import Path + + +def run_workload(config: Path, results_dir: Path, run_name: str, overwrite: bool) -> None: + """Run the regression test for a given config and store its results.""" + print(f"Running workload for config: {config.name} (run: {run_name})") + # Create a dedicated subdirectory for this workload's results using the run name. + output_dir = results_dir / config.stem / run_name + output_dir.mkdir(parents=True, exist_ok=True) + + subprocess_args = [ + "python", + "run_workload.py", + "--config", + str(config), + "--output", + str(output_dir), + "--name", + run_name, + ] + if overwrite: + subprocess_args.append("--overwrite") + + result = subprocess.run(subprocess_args) + + if result.returncode != 0: + sys.exit(f"Error running workload: {config}") + + +def main(): + parser = argparse.ArgumentParser(description="Run all workload configurations with a given run name.") + parser.add_argument( + "--name", + type=str, + required=True, + help="Name of this run (e.g. 'baseline' or 'PR-123'). " + "This will be appended to the results directory for each config.", + ) + parser.add_argument("--overwrite", action="store_true") + args = parser.parse_args() + + configs = [ + Path("configs/sift1m_balanced.yaml"), + Path("configs/sift1m_insert_heavy.yaml"), + Path("configs/sift1m_read_only.yaml"), + ] + + # Base directory to store results for all workloads. + results_dir = Path("results") + results_dir.mkdir(exist_ok=True, parents=True) + + # Run each workload configuration. + for config in configs: + run_workload(config, results_dir, args.name, args.overwrite) + + compare_cmd = [ + "python", + "compare_results.py", + "--results_dir", + str(results_dir), + "--plot_type", + "aggregate", # or 'both' if you want detailed as well + "--output_aggregate", + str(results_dir / f"aggregate_matrix_{args.name}.png"), + ] + result = subprocess.run(compare_cmd) + if result.returncode != 0: + sys.exit("Error generating comparison plots.") + else: + print("All workloads completed and comparison plots generated.") + + +if __name__ == "__main__": + main() diff --git a/test/python/regression/run_workload.py b/test/python/regression/run_workload.py new file mode 100644 index 00000000..82492080 --- /dev/null +++ b/test/python/regression/run_workload.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python +import argparse +import logging +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +import yaml + +from quake.datasets.ann_datasets import load_dataset + +# Import your project-specific modules. +from quake.index_wrappers.quake import QuakeWrapper +from quake.workload_generator import DynamicWorkloadGenerator, WorkloadEvaluator + + +def setup_logging(log_level=logging.INFO): + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") + + +def load_configuration(config_path: Path) -> dict: + try: + with open(config_path, "r") as f: + config = yaml.safe_load(f) + return config + except Exception as e: + logging.error(f"Failed to load configuration file {config_path}: {e}") + raise + + +def run_performance_test( + index: QuakeWrapper, + build_params: dict, + search_params: dict, + workload_generator: DynamicWorkloadGenerator, + evaluator: WorkloadEvaluator, + csv_output_path: Path, +) -> None: + """ + Runs the workload on the index, collects per-operation metrics, and writes a CSV file. + """ + logging.info("Generating workload operations...") + if not workload_generator.workload_exists(): + workload_generator.generate_workload() + + logging.info("Evaluating workload...") + results = evaluator.evaluate_workload( + name="Quake", index=index, build_params=build_params, search_params=search_params + ) + + # Create output directory if it does not exist. + csv_output_path.mkdir(parents=True, exist_ok=True) + csv_output_file = csv_output_path / "results.csv" + + # Save results as CSV. + df = pd.DataFrame(results) + df.to_csv(csv_output_file, index=False) + logging.info(f"CSV results saved to {csv_output_file}") + + +def main(): + setup_logging() + parser = argparse.ArgumentParser(description="Run Regression Test for the Vector Search Library") + parser.add_argument("--config", type=str, default="configs/experiment.yaml", help="Path to YAML configuration file") + parser.add_argument("--output", type=str, default="results/current.csv", help="Output CSV file path") + parser.add_argument("--name", type=str, default=None, help="Name of the experiment") + parser.add_argument("--overwrite", action="store_true", help="Overwrite existing results if set") + args = parser.parse_args() + + config_path = Path(args.config) + config = load_configuration(config_path) + + # Set seeds for reproducibility. + seed = config.get("seed", 1738) + np.random.seed(seed) + torch.manual_seed(seed) + + # Use the stem of the config file as a default experiment name. + workload_name = config_path.stem + name = args.name if args.name is not None else workload_name + + # Setup workload and output directories. + workload_dir = Path(config.get("workload_dir", "workloads/experiment")) / workload_name + workload_dir.mkdir(parents=True, exist_ok=True) + output_dir = Path(config.get("results_dir", "results")) / workload_name / name + + # Check if results already exist. + results_csv = output_dir / "results.csv" + if results_csv.exists() and not args.overwrite: + logging.info( + f"Results already exist in {output_dir}. Skipping run_regression_test. Use --overwrite to force rerun." + ) + return + + # Load dataset. + dataset_config = config.get("dataset", {}) + dataset_name = dataset_config.get("name") + dataset_path = dataset_config.get("path", None) + logging.info(f"Loading dataset {dataset_name} from {dataset_path} ...") + vectors, queries, gt = load_dataset(dataset_name, dataset_path) + + # Extract index parameters from the configuration. + build_params = {"nc": config["index"].get("nc", 1000), "metric": config["index"].get("metric", "l2")} + search_params = { + "k": config["index"]["search"].get("k", 10), + "nprobe": config["index"]["search"].get("nprobe", 10), + "recall_target": config["index"]["search"].get("recall_target", -1), + "use_precomputed": config["index"]["search"].get("use_precomputed", False), + "batched_scan": config["index"]["search"].get("batched_scan", False), + } + + # Instantiate the workload generator. + workload_generator = DynamicWorkloadGenerator( + workload_dir=workload_dir, + base_vectors=vectors, + metric=build_params["metric"], + insert_ratio=config["workload"].get("insert_ratio", 0.3), + delete_ratio=config["workload"].get("delete_ratio", 0.2), + query_ratio=config["workload"].get("query_ratio", 0.5), + update_batch_size=config["workload"].get("update_batch_size", 100), + query_batch_size=config["workload"].get("query_batch_size", 100), + number_of_operations=config["workload"].get("number_of_operations", 1000), + initial_size=config["workload"].get("initial_size", 10000), + cluster_size=config["workload"].get("cluster_size", 100), + cluster_sample_distribution=config["workload"].get("cluster_sample_distribution", "uniform"), + queries=queries, + seed=seed, + ) + + # Instantiate the evaluator. + evaluator = WorkloadEvaluator(workload_dir=workload_dir, output_dir=output_dir) + + # Instantiate your index. + index = QuakeWrapper() + + # Run the regression test. + run_performance_test(index, build_params, search_params, workload_generator, evaluator, output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/python/test_basic.py b/test/python/test_basic.py index 15339232..c242c61a 100644 --- a/test/python/test_basic.py +++ b/test/python/test_basic.py @@ -1,11 +1,13 @@ -from quake import QuakeIndex, IndexBuildParams, SearchParams -from quake.datasets.ann_datasets import load_dataset -from quake.utils import compute_recall -import torch import time as t + import faiss +import torch -if __name__ == '__main__': +from quake import IndexBuildParams, QuakeIndex, SearchParams +from quake.datasets.ann_datasets import load_dataset +from quake.utils import compute_recall + +if __name__ == "__main__": print("Testing bindings") search_params = SearchParams() @@ -43,8 +45,7 @@ faiss_ivf_index.add(vectors.numpy()) start = t.time() faiss_ivf_index.nprobe = search_params.nprobe - D, I = faiss_ivf_index.search(queries.numpy(), search_params.k) + dists, ids = faiss_ivf_index.search(queries.numpy(), search_params.k) end = t.time() - recall = compute_recall(torch.from_numpy(I), gt, search_params.k) + recall = compute_recall(torch.from_numpy(ids), gt, search_params.k) print("Faiss IVF search", recall.mean(), "Time", end - start) - diff --git a/test/python/test_workload_generator.py b/test/python/test_workload_generator.py index 0e8516c0..affa9f95 100644 --- a/test/python/test_workload_generator.py +++ b/test/python/test_workload_generator.py @@ -1,14 +1,13 @@ import json import tempfile from pathlib import Path -import torch -import numpy as np + import pytest +import torch -# Import your workload generator and evaluator -import quake from quake.index_wrappers.quake import QuakeWrapper -from quake.workload_generator import DynamicWorkloadGenerator, WorkloadEvaluator, UniformSampler +from quake.workload_generator import DynamicWorkloadGenerator, UniformSampler, WorkloadEvaluator + # Create a small synthetic dataset for testing @pytest.fixture @@ -23,6 +22,7 @@ def synthetic_dataset(): return base_vectors, queries, workload_dir + def test_workload_generation(synthetic_dataset): base_vectors, queries, workload_dir = synthetic_dataset workload_dir.mkdir(exist_ok=True) @@ -31,7 +31,7 @@ def test_workload_generation(synthetic_dataset): # These ratios must sum to 1. insert_ratio = 0.3 delete_ratio = 0.2 - query_ratio = 0.5 + query_ratio = 0.5 update_batch_size = 50 query_batch_size = 10 number_of_operations = 20 @@ -57,7 +57,7 @@ def test_workload_generation(synthetic_dataset): cluster_sample_distribution=cluster_sample_distribution, queries=queries, query_cluster_sample_distribution=query_cluster_sample_distribution, - seed=seed + seed=seed, ) # For testing, use a simple sampler (could also test StratifiedClusterSampler) @@ -84,6 +84,7 @@ def test_workload_generation(synthetic_dataset): op_count = len(runbook["operations"]) assert op_count <= number_of_operations, "Too many operations generated." + def test_workload_evaluation(synthetic_dataset): base_vectors, queries, workload_dir = synthetic_dataset nc = 100 @@ -91,7 +92,6 @@ def test_workload_evaluation(synthetic_dataset): index = QuakeWrapper() - experiment_params = {"n_workers": 2} evaluator = WorkloadEvaluator( workload_dir=workload_dir, output_dir=workload_dir, @@ -103,7 +103,7 @@ def test_workload_evaluation(synthetic_dataset): index=index, build_params=build_params, search_params={"k": 5, "nprobe": 1}, - do_maintenance=True + do_maintenance=True, ) # Basic checks on the returned results @@ -111,4 +111,4 @@ def test_workload_evaluation(synthetic_dataset): for result in results: # For query operations, recall should be a float between 0 and 1 if result["operation_type"] == "query": - assert 0.0 <= result["recall"] <= 1.0, "Recall out of bounds." \ No newline at end of file + assert 0.0 <= result["recall"] <= 1.0, "Recall out of bounds."