import os
import pickle
import random
import time
from typing import List

import hnswlib
from sentence_transformers import SentenceTransformer

# MODEL VARIABLES
MODEL_NAME = "intfloat/e5-base-v2"
EMBEDDING_SIZE = 768

# EMBEDDING VARIABLES
EMBEDDING_CACHE_DIR = "../data/training_files/embeddings/"
EMBEDDING_FILE = "loinc_lab_names_intfloat_e5-base-v2_20251007"

# GRID-SEARCH ANN PARAMS
# EF is the "speed/accuracy" tradeoff parameter for HNSW search. EF
# typically ranges from 50 to 1000, with 200 being a common default. Higher
# values of EF increase recall relative to exact search (i.e. results will
# tend to look more like exact kNN), but increase search time in a
# nonlinear fashion.
EF_CONSTRUCTION = 200
# M is the number of connections/neighbors made per "node" in the search
# graph. It represents how many embedded vectors are considered to be in
# the "small world" defined around each other vector. Higher values of M
# increase recall relative to exact search, but also slow down the search
# and increase the memory footprint of the index.
M_VALUE = 48
# These are the EF values we want to test during our grid search. The
# EF value an HNSW index is constructed with *does not* need to be the
# EF value the index is searched with; the search EF can be set anywhere
# in the same range, as long as it is at least k. The construction EF
# controls the size of the candidate list used while the graph is being
# built, while this "search EF" controls how many candidates the
# algorithm actually explores during each ANN query.
EFS_TO_TEST = [50, 100, 200, 400, 600, 800, 1000]

# VALIDATION VARIABLES
VALIDATION_FILE = "../data/training_files/validation_set_positive_pairs.txt"
# This is the "k" value in kNN: how many approximate neighbors we'll be
# retrieving. The script does not optimize over K, but the choice of K does
# directly influence the recall calculation (e.g. more neighbors means a
# larger sample over which to compare ANN to exact kNN).
NUM_NEIGHBORS_TO_SEARCH = 10

# IMPORTANT: Change this value to calculate stats using more or fewer
# examples drawn from the validation set.
NUM_EXAMPLES_TO_VALIDATE = 10000


def run_recall_trial(
    model: SentenceTransformer,
    hnsw_index: hnswlib.Index,
    bf_index: hnswlib.Index,
    examples: List[List[str]],
    k: int,
    ef: int,
) -> None:
    """
    Run a single trial in a grid of trials to compare approximate search
    with exact search. Importantly, the goal of a recall trial is *not* to
    maximize accuracy. Model analysis is a separate task. The goal of ANN
    hyperparameter optimization is to get the approximate search to behave
    as closely as possible to exact search in terms of which results are
    retrieved and the relative rankings of those results. This allows other
    notebooks to optimize for Top-K performance.

    :param model: The sentence transformers model to evaluate.
    :param hnsw_index: An HNSW index computed over the embeddings.
    :param bf_index: A brute force index computed over the embeddings.
    :param examples: A list of validation samples on which to evaluate recall.
    :param k: The number of search results to retrieve.
    :param ef: The search depth to use as part of this optimization.
    """
    num_correct = 0.0
    search_times = []

    for e in examples:
        nonstandard_in = e[1].strip()

        # Unlike embedding, which can convert to a tensor on the GPU, HNSW
        # lives in CPU memory, so we leave the encoded query as-is.
        enc = model.encode(nonstandard_in)
        start = time.time()
        labels_hnsw, _ = hnsw_index.knn_query(enc, k=k)
        search_times.append(time.time() - start)
        labels_bf, _ = bf_index.knn_query(enc, k=k)
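        # Only the HNSW query is timed; the brute-force query just provides
        # the exact result set to score against.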

        for label in labels_hnsw[0]:
            for correct_label in labels_bf[0]:
                # We're counting only the instances where the elements
                # between HNSW and brute force match
                if label == correct_label:
                    num_correct += 1
                    break

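    # Recall here is the fraction of HNSW results that also appear in the
    # exact kNN result set, averaged over all queries.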
    recall = round(num_correct / float(k * len(examples)), 3)
    mean_search_time = round(float(sum(search_times)) / float(len(search_times)), 3)

    print(f"Speed/Accuracy Tradeoff for K = {k}, EF = {ef}")
    print(f"    Recall: {recall}")
    print(f"    Mean Search Time: {mean_search_time}")


if __name__ == "__main__":
    print("Instantiating language model...")
    model = SentenceTransformer(MODEL_NAME)

    print("Checking for cached embeddings...")
    if os.path.exists(EMBEDDING_CACHE_DIR + EMBEDDING_FILE):
        print("    Found cached embeddings. Loading them...")
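        # The cache is a dict holding the lab name codes ("codes") and a
        # tensor of their embeddings ("embeddings").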
        with open(EMBEDDING_CACHE_DIR + EMBEDDING_FILE, "rb") as fp:
            cache_data = pickle.load(fp)
            name_codes = cache_data["codes"]
            embeddings = cache_data["embeddings"]
            embeddings = embeddings.cpu().numpy()

        print("Loading validation set...")
        examples = []
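        # Each line of the validation file is a pipe-delimited positive pair;
        # the second field is the nonstandard lab name used as the query.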
        with open(VALIDATION_FILE, "r") as fp:
            for line in fp:
                if line.strip() != "":
                    examples.append(line.strip().split("|"))
        random.shuffle(examples)
        examples = examples[:NUM_EXAMPLES_TO_VALIDATE]

        print("Initializing Indices: Regular and Brute Force")
        hnsw_index = hnswlib.Index(space="cosine", dim=EMBEDDING_SIZE)
        bf_index = hnswlib.BFIndex(space="cosine", dim=EMBEDDING_SIZE)
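        # ef_construction and M shape the HNSW graph; the brute-force index
        # performs exact search and only needs capacity for every embedding.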
        hnsw_index.init_index(
            max_elements=len(embeddings), ef_construction=EF_CONSTRUCTION, M=M_VALUE
        )
        bf_index.init_index(max_elements=len(embeddings))

        hnsw_index.add_items(embeddings)
        bf_index.add_items(embeddings)

        print("Performing grid-search on EF to identify optimal value...")
        for ef in EFS_TO_TEST:
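            # set_ef changes only the query-time search depth, so the same
            # index can be reused across trials without rebuilding.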
            hnsw_index.set_ef(ef)
            run_recall_trial(
                model, hnsw_index, bf_index, examples, NUM_NEIGHBORS_TO_SEARCH, ef
            )

    else:
        print("No embeddings found, please run embedding.py to compute vectors first.")