
Commit 7f25d8e

Incorporate ANN (#119)
## Description

This PR introduces approximate nearest neighbor (ANN) search to our performance metrics notebook. It also adds a script that runs a grid search over search-time EF values to optimize the parameters for HNSW evaluation.

## Related Issues

Closes #106
1 parent 6f0cbca commit 7f25d8e

File tree

5 files changed: +279, -15 lines


model_tuning/build_hnsw_index.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
"""
build_hnsw_index.py

Simple script for creating an HNSW index for a specific set of model
vector embeddings. This index can be persisted to disk for faster
instantiation during performance metric computation.
"""

import os
import pickle

import hnswlib

# MODEL VARIABLES
MODEL_NAME = "intfloat/e5-base-v2"
EMBEDDING_SIZE = 768

# EMBEDDING VARIABLES
EMBEDDING_CACHE_DIR = "../data/training_files/embeddings/"
EMBEDDING_FILE = "loinc_lab_names_intfloat_e5-base-v2_20251007"

# ANN INDEX VARIABLES
INDEX_FP = f"hnswlib_index_{MODEL_NAME.replace('/', '_')}.index"
EF_VALUE = 200
M_VALUE = 64


if __name__ == "__main__":
    print("Checking for cached embeddings...")
    if os.path.exists(EMBEDDING_CACHE_DIR + EMBEDDING_FILE):
        print("  Found cached embeddings. Loading them...")
        with open(EMBEDDING_CACHE_DIR + EMBEDDING_FILE, "rb") as fp:
            cache_data = pickle.load(fp)
            name_codes = cache_data["codes"]
            embeddings = cache_data["embeddings"]
        embeddings = embeddings.cpu().numpy()

        index = hnswlib.Index(space="cosine", dim=EMBEDDING_SIZE)
        print("Checking for cached ANN index...")
        if os.path.exists(INDEX_FP):
            print("  Cached index already exists.")
        else:
            print(f"No local index found. Creating index for {MODEL_NAME}...")
            index.init_index(max_elements=len(embeddings), ef_construction=EF_VALUE, M=M_VALUE)
            print("  Index created, adding vectors...")
            index.add_items(embeddings, list(range(len(embeddings))))
            print("  Vectors embedded, saving index...")
            index.save_index(INDEX_FP)
    else:
        print("No embeddings found, please run embedding.py to compute vectors first.")
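Note: an index persisted by this script can later be reloaded without rebuilding the graph. A minimal sketch of that round trip, assuming the `INDEX_FP` and `EMBEDDING_SIZE` values above (the random query vector is a stand-in for a real `model.encode(...)` output):

import hnswlib
import numpy as np

# Re-create an empty index with the same space/dim used at build time,
# then load the persisted graph from disk.
index = hnswlib.Index(space="cosine", dim=768)
index.load_index("hnswlib_index_intfloat_e5-base-v2.index")
index.set_ef(200)  # search-time EF; see hnsw_estimator.py for tuning

# Stand-in query vector; in practice this comes from model.encode(...)
query = np.random.rand(768).astype(np.float32)
labels, distances = index.knn_query(query, k=10)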

model_tuning/cpu_convert.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
"""
cpu_convert.py

Simple script for converting collections of embedded vectors that were
built using GPU / tensor optimization into a purely CPU-compatible format.

Vector embeddings *must* be CPU-formatted for use with Azure ML Studio's
copy of the `performance.ipynb` notebook.
"""

import pickle

# Directory in which the embeddings are saved
EMBEDDING_CACHE_DIR = "../data/training_files/embeddings/"

# The original embedding file that may have been saved in a GPU-based
# format
GPU_PICKLE_FILE = "loinc_lab_names_intfloat_e5-base-v2_20251007"

# The new embedding file to write after conversion to pure CPU formatting
CPU_PICKLE_FILE = "loinc_lab_names_intfloat_e5-base-v2_20251007_cpu"


if __name__ == "__main__":
    print("Loading pickled tensor embeddings...")
    with open(EMBEDDING_CACHE_DIR + GPU_PICKLE_FILE, "rb") as fp:
        cache_data = pickle.load(fp)
        name_codes = cache_data["codes"]
        embeddings = cache_data["embeddings"]

    print("Converting to CPU and writing back...")
    embeddings = embeddings.cpu()
    with open(EMBEDDING_CACHE_DIR + CPU_PICKLE_FILE, "wb") as fp:
        pickle.dump({"codes": name_codes, "embeddings": embeddings}, fp)
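Note: a quick way to confirm the conversion worked (a sketch, not part of this commit) is to unpickle the new file on a CPU-only host. Loading the original GPU-formatted pickle there raises a RuntimeError while deserializing the CUDA tensors, whereas the converted file should load cleanly:

import pickle

# Path built from the constants above
CPU_FILE = "../data/training_files/embeddings/loinc_lab_names_intfloat_e5-base-v2_20251007_cpu"

# Must succeed without CUDA available
with open(CPU_FILE, "rb") as fp:
    data = pickle.load(fp)
print(data["embeddings"].device)  # expected: cpu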

model_tuning/hnsw_estimator.py

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
import os
import pickle
import random
import time
from typing import List

import hnswlib
from sentence_transformers import SentenceTransformer

# MODEL VARIABLES
MODEL_NAME = "intfloat/e5-base-v2"
EMBEDDING_SIZE = 768

# EMBEDDING VARIABLES
EMBEDDING_CACHE_DIR = "../data/training_files/embeddings/"
EMBEDDING_FILE = "loinc_lab_names_intfloat_e5-base-v2_20251007"

# GRID-SEARCH ANN PARAMS
# EF-value is described as the "speed/accuracy" tradeoff metric for HNSW
# search. EF typically ranges from 50 to 1000, with a default of 200.
# Higher values of EF will increase recall relative to exact search
# (i.e. results will tend to look more like exact kNN), but will increase
# search time in a nonlinear fashion.
EF_CONSTRUCTION = 200
# M-value is the number of connections/neighbors made per "node" in the
# search graph. It represents how many embedded vectors are considered to
# be in the "small world" defined around each other vector. Higher values
# of M increase recall relative to exact search, but also slow down
# the search time.
M_VALUE = 48
# This is the range of EF values we want to test during our grid search.
# The EF-value that an HNSW index is constructed with *does not* need to be
# the EF-value that the index is searched with. The search EF can range from
# 0 to 1000, just like the initial EF used during construction. The initial
# EF controls how many "small worlds" get attached as branches in the
# search graph, while this "search EF" controls how many actually get
# explored by the algorithm during ANN.
EFS_TO_TEST = [50, 100, 200, 400, 600, 800, 1000]

# VALIDATION VARIABLES
VALIDATION_FILE = "../data/training_files/validation_set_positive_pairs.txt"
# This is the "k" value in kNN, i.e. how many approximate neighbors we'll
# retrieve. The script does not optimize a search over K, but the choice of
# K does directly influence the ordered-recall calculation (e.g. more neighbors
# means a better sample to compare ANN to exact kNN).
NUM_NEIGHBORS_TO_SEARCH = 10

# IMPORTANT: Change this value to calculate stats using more or fewer
# examples drawn from the validation set.
NUM_EXAMPLES_TO_VALIDATE = 10000


def run_recall_trial(
    model: SentenceTransformer,
    hnsw_index: hnswlib.Index,
    bf_index: hnswlib.BFIndex,
    examples: List[List[str]],
    k: int,
    ef: int,
) -> None:
    """
    Perform a single search in a grid of trials to compare approximate search
    with exact search. Importantly, the goal of a recall trial is *not* to
    maximize accuracy. Model analysis is a separate task. The goal of ANN
    hyperparameter optimization is to get the approximate search to behave
    as closely as possible to exact search in terms of which results are
    retrieved and the relative rankings of those results. This allows other
    notebooks to optimize for Top-K performance.

    :param model: The sentence transformers model to evaluate.
    :param hnsw_index: An HNSW index computed over the embeddings.
    :param bf_index: A brute-force index computed over the embeddings.
    :param examples: A list of validation samples on which to evaluate recall.
    :param k: The number of search results to retrieve.
    :param ef: The search depth to use as part of this optimization.
    """
    num_correct = 0.0
    search_times = []

    for e in examples:
        nonstandard_in = e[1].strip()

        # Unlike embedding, which can convert to tensor on GPU, HNSW exists in
        # CPU memory, so we leave it as is
        enc = model.encode(nonstandard_in)
        start = time.time()
        labels_hnsw, _ = hnsw_index.knn_query(enc, k=k)
        search_times.append(time.time() - start)
        labels_bf, _ = bf_index.knn_query(enc, k=k)

        for label in labels_hnsw[0]:
            for correct_label in labels_bf[0]:
                # We're counting only the instances where the elements between
                # HNSW and brute force match
                if label == correct_label:
                    num_correct += 1
                    break

    recall = round(num_correct / float(k * len(examples)), 3)
    mean_search_time = round(float(sum(search_times)) / float(len(search_times)), 3)

    print(f"Speed/Accuracy Tradeoff for K = {k}, EF = {ef}")
    print(f"  Recall: {recall}")
    print(f"  Mean Search Time: {mean_search_time}")


if __name__ == "__main__":
    print("Instantiating language model...")
    model = SentenceTransformer(MODEL_NAME)

    print("Checking for cached embeddings...")
    if os.path.exists(EMBEDDING_CACHE_DIR + EMBEDDING_FILE):
        print("  Found cached embeddings. Loading them...")
        with open(EMBEDDING_CACHE_DIR + EMBEDDING_FILE, "rb") as fp:
            cache_data = pickle.load(fp)
            name_codes = cache_data["codes"]
            embeddings = cache_data["embeddings"]
        embeddings = embeddings.cpu().numpy()

        print("Loading validation set...")
        examples = []
        with open(VALIDATION_FILE, "r") as fp:
            for line in fp:
                if line.strip() != "":
                    examples.append(line.strip().split("|"))
        random.shuffle(examples)
        examples = examples[:NUM_EXAMPLES_TO_VALIDATE]

        print("Initializing Indices: Regular and Brute Force")
        hnsw_index = hnswlib.Index(space="cosine", dim=EMBEDDING_SIZE)
        bf_index = hnswlib.BFIndex(space="cosine", dim=EMBEDDING_SIZE)
        hnsw_index.init_index(
            max_elements=len(embeddings), ef_construction=EF_CONSTRUCTION, M=M_VALUE
        )
        bf_index.init_index(max_elements=len(embeddings))

        hnsw_index.add_items(embeddings)
        bf_index.add_items(embeddings)

        print("Performing grid-search on EF to identify optimal value...")
        for ef in EFS_TO_TEST:
            hnsw_index.set_ef(ef)
            run_recall_trial(model, hnsw_index, bf_index, examples, NUM_NEIGHBORS_TO_SEARCH, ef)

    else:
        print("No embeddings found, please run embedding.py to compute vectors first.")
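Note on the recall calculation in run_recall_trial: because labels within a single hnswlib result list are unique, the nested loop is equivalent to a set intersection between the HNSW results and the brute-force results. A minimal sketch of the same per-query computation under that uniqueness assumption:

# Recall@k for one query: the fraction of exact-kNN labels that ANN also returned.
def recall_at_k(ann_labels, exact_labels):
    return len(set(ann_labels) & set(exact_labels)) / len(exact_labels)

print(recall_at_k([3, 7, 9, 12], [3, 9, 12, 40]))  # 0.75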

model_tuning/performance.py

Lines changed: 43 additions & 15 deletions
@@ -4,13 +4,24 @@
 import time
 from typing import List

+import hnswlib
 from sentence_transformers import SentenceTransformer
-from sentence_transformers import util
-from torch import Tensor

+# MODEL VARIABLES
 MODEL_NAME = "intfloat/e5-base-v2"
+EMBEDDING_SIZE = 768
+
+# EMBEDDING VARIABLES
 EMBEDDING_CACHE_DIR = "../data/training_files/embeddings/"
 EMBEDDING_FILE = "loinc_lab_names_intfloat_e5-base-v2_20251007"
+
+# ANN INDEX VARIABLES
+INDEX_FP = "./hnswlib.index"
+EF_CONSTRUCTION = 200
+M_VALUE = 64
+EF_SEARCH = 100
+
+# VALIDATION VARIABLES
 VALIDATION_FILE = "../data/training_files/validation_set_positive_pairs.txt"
 K_VALUES = [1, 3, 5, 10]

@@ -21,7 +32,7 @@

 def predict_and_evaluate_validation_set(
     model: SentenceTransformer,
-    vector_db: Tensor,
+    ann_index: hnswlib.Index,
     standard_loinc_names: List[str],
     examples: List[List[str]],
     k_vals: List[int],

@@ -35,14 +46,15 @@ def predict_and_evaluate_validation_set(
     scoring result, and mean time to encode an input and perform semantic search.

     :param model: The sentence transformer model to evaluate.
-    :param vector_db: A list of pre-computed embeddings on the corpus in which
-        to semantic search (these are the embedded standard LOINC codes).
+    :param ann_index: A pre-computed HNSW index file over the embeddings that
+        we want to match nonstandard inputs to.
     :param standard_loinc_names: A list of strings representing the names of
         the LOINC codes embedded in the `vector_db`. Note that the order of
         strings in the list should match the order of embeddings in the DB.
     :param examples: A list of lists of strings representing the experimental
         examples to evaluate.
-    :param k: An integer for how many neighbors to retrieve from the DB.
+    :param k_vals: A list of integers indicating how many neighbors should be
+        retrieved from the DB across a range of trials.
     :returns: None
     """
     encoding_times = []

@@ -57,20 +69,19 @@ def predict_and_evaluate_validation_set(
         correct_code = e[0].strip()
         nonstandard_in = e[1].strip()

-        # This utility performs exact neighbor semantic search
-        # If approximate is desired, see
-        # https://sbert.net/examples/sentence_transformer/applications/semantic-search/README.html#approximate-nearest-neighbor  # noqa
-        # for details
         start = time.time()
-        enc = model.encode(nonstandard_in, convert_to_tensor=True)
+        enc = model.encode(nonstandard_in)
         encoding_times.append(time.time() - start)

         for k in k_vals:
             start = time.time()
-            hits = util.semantic_search(enc, vector_db, top_k=k)
-            hits = hits[0]
+            embedding_ids, distances = ann_index.knn_query(enc, k=k)
+            hits = [
+                {"corpus_id": id, "score": 1 - dist}
+                for id, dist in zip(embedding_ids[0], distances[0])
+            ]
+            hits = sorted(hits, key=lambda x: x["score"], reverse=True)

-            # Store some metrics
             times[k].append(time.time() - start)
             cosine_sims[k].append(hits[0]["score"])

@@ -110,6 +121,21 @@ def predict_and_evaluate_validation_set(
             cache_data = pickle.load(fp)
             name_codes = cache_data["codes"]
             embeddings = cache_data["embeddings"]
+        embeddings = embeddings.cpu().numpy()
+
+        index = hnswlib.Index(space="cosine", dim=EMBEDDING_SIZE)
+        print("Checking for cached ANN index...")
+        if os.path.exists(INDEX_FP):
+            print("  Found cached index. Loading it...")
+            index.load_index(INDEX_FP)
+        else:
+            print("No locally cached index found. Creating hierarchical index...")
+            index.init_index(
+                max_elements=len(embeddings), ef_construction=EF_CONSTRUCTION, M=M_VALUE
+            )
+            index.add_items(embeddings, list(range(len(embeddings))))
+            index.save_index(INDEX_FP)
+        index.set_ef(EF_SEARCH)

         print("Loading validation set...")
         examples = []

@@ -119,7 +145,9 @@ def predict_and_evaluate_validation_set(
                 examples.append(line.strip().split("|"))

         print("Predicting and computing stats for validation set...")
-        predict_and_evaluate_validation_set(model, embeddings, name_codes, examples, K_VALUES)
+        predict_and_evaluate_validation_set(
+            model, index, name_codes, examples, K_VALUES
+        )

    else:
        print("No embeddings found, please run embedding.py to compute vectors first.")
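Note on the `1 - dist` conversion above: hnswlib's "cosine" space returns cosine distance (1 minus cosine similarity), so subtracting from 1 recovers the similarity score that `util.semantic_search` previously reported. The explicit sort is technically redundant, since `knn_query` already returns neighbors in ascending distance order, but it preserves the ordering contract of the old `hits` structure. A quick sketch (not part of the commit) checking the equivalence on random vectors:

import hnswlib
import numpy as np

rng = np.random.default_rng(0)
corpus = rng.random((100, 768)).astype(np.float32)
query = rng.random(768, dtype=np.float32)

index = hnswlib.Index(space="cosine", dim=768)
index.init_index(max_elements=len(corpus))
index.add_items(corpus)

ids, dists = index.knn_query(query, k=1)
neighbor = corpus[ids[0][0]]
# Direct cosine similarity against the same returned neighbor
direct_sim = np.dot(query, neighbor) / (np.linalg.norm(query) * np.linalg.norm(neighbor))
print(1 - dists[0][0], direct_sim)  # should agree to float32 precision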

pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -26,6 +26,10 @@ dependencies = [
     "spacy-lookups-data",
     "sentence-transformers",
     "scikit-learn",
+    # HNSWLIB might actually not be needed in our main dependencies, but we won't know that
+    # until we get into AWS and see how OpenSearch structures things. We'll leave this in
+    # main for now, but will remove in prod if we're able to migrate it to a dev dependency.
+    "hnswlib",
     "pydantic-settings",
     # Typing
     "aws-lambda-typing",
