Skip to content

Commit 97dc4b7

Browse files
authored
cost_fn fixes and clarity (#22)
* cost_fn fixes and clarity
1 parent b2521c6 commit 97dc4b7

File tree

18 files changed

+986
-810
lines changed

18 files changed

+986
-810
lines changed

.github/workflows/test.yml

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
13+
services:
14+
redis:
15+
image: redis/redis-stack:latest
16+
ports:
17+
- 6379:6379
18+
options: >-
19+
--health-cmd "redis-cli ping"
20+
--health-interval 10s
21+
--health-timeout 5s
22+
--health-retries 5
23+
24+
strategy:
25+
matrix:
26+
python-version: ["3.12"]
27+
28+
steps:
29+
- uses: actions/checkout@v4
30+
31+
- name: Set up Python ${{ matrix.python-version }}
32+
uses: actions/setup-python@v4
33+
with:
34+
python-version: ${{ matrix.python-version }}
35+
36+
- name: Install Poetry
37+
uses: snok/install-poetry@v1
38+
with:
39+
version: latest
40+
virtualenvs-create: true
41+
virtualenvs-in-project: true
42+
43+
- name: Load cached venv
44+
id: cached-poetry-dependencies
45+
uses: actions/cache@v3
46+
with:
47+
path: .venv
48+
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
49+
50+
- name: Install dependencies
51+
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
52+
run: poetry install --all-extras
53+
54+
- name: Run tests
55+
run: poetry run test
56+
env:
57+
REDIS_URL: redis://localhost:6379/0

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ metrics = run_grid_study(
136136
```
137137

138138
#### Example output
139-
| search_method | model | avg_query_time | recall@k | precision | ndcg@k |
139+
| search_method | model | avg_query_time | recall | precision | ndcg |
140140
|----------------|---------------------------------------------|----------------|-----------|-----------|----------|
141141
| weighted_rrf | sentence-transformers/all-MiniLM-L6-v2 | 0.006608 | 0.156129 | 0.261056 | 0.204241 |
142142
| rerank | sentence-transformers/all-MiniLM-L6-v2 | 0.127574 | 0.156039 | 0.260437 | 0.190298 |
@@ -169,7 +169,7 @@ index_settings:
169169
optimization_settings:
170170
# defines weight of each metric in optimization function
171171
metric_weights:
172-
f1_at_k: 1
172+
f1: 1
173173
total_indexing_time: 1
174174
algorithms: ["hnsw"] # indexing algorithm to be included in the study
175175
vector_data_types: ["float16", "float32"] # data types to be included in the study
@@ -214,7 +214,7 @@ metrics = run_bayes_study(
214214
```
215215

216216
#### Example output
217-
| search_method | algorithm | vector_data_type | ef_construction | ef_runtime | m | avg_query_time | total_indexing_time | f1@k |
217+
| search_method | algorithm | vector_data_type | ef_construction | ef_runtime | m | avg_query_time | total_indexing_time | f1 |
218218
|---------------|-----------|------------------|------------------|------------|----|----------------|----------------------|---------|
219219
| hybrid | hnsw | float16 | 200 | 50 | 8 | 0.004628 | 3.559 | 0.130712|
220220
| hybrid | hnsw | float16 | 200 | 50 | 64 | 0.004498 | 4.804 | 0.130712|
@@ -296,7 +296,7 @@ cache = SemanticCache(
296296

297297
# Add some data to the cache
298298
paris_key = cache.store(
299-
prompt="what is the capital of france?",
299+
prompt="what is the capital of france?",
300300
response="paris"
301301
)
302302

@@ -307,7 +307,7 @@ test_data = [
307307
"query_match": paris_key # Expected cache hit
308308
},
309309
{
310-
"query": "What's the capital of Britain?",
310+
"query": "What's the capital of Britain?",
311311
"query_match": "" # Expected cache miss
312312
}
313313
]
@@ -337,7 +337,7 @@ routes = [
337337
distance_threshold=0.5,
338338
),
339339
Route(
340-
name="farewell",
340+
name="farewell",
341341
references=["bye", "goodbye"],
342342
metadata={"type": "farewell"},
343343
distance_threshold=0.5,
@@ -509,7 +509,7 @@ metrics = run_grid_study(
509509

510510
### Example output
511511

512-
| search_method | model | avg_query_time | recall@k | precision | ndcg@k |
512+
| search_method | model | avg_query_time | recall | precision | ndcg |
513513
|-------------------|---------------------------------------------|----------------|-----------|-----------|----------|
514514
| pre_filter_vector | sentence-transformers/all-MiniLM-L6-v2 | 0.001177 | 1.0 | 0.25 | 0.914903 |
515515
| basic_vector | sentence-transformers/all-MiniLM-L6-v2 | 0.002605 | 0.9 | 0.23 | 0.717676 |

docs/examples/bayesian_optimization/00_bayes_study.ipynb

Lines changed: 695 additions & 712 deletions
Large diffs are not rendered by default.

docs/examples/bayesian_optimization/bayes_study_config.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ index_settings:
1717
optimization_settings:
1818
# defines weight of each metric in optimization function
1919
metric_weights:
20-
f1_at_k: 1
21-
total_indexing_time: 1
20+
f1: 2
21+
total_indexing_time: 2
22+
avg_query_time: 2
23+
recall: 2
24+
ndcg: 2
25+
precision: 2
2226
algorithms: ["hnsw"] # indexing algorithm to be included in the study
2327
vector_data_types: ["float16", "float32"] # data types to be included in the study
2428
distance_metrics: ["cosine"] # distance metrics to be included in the study

docs/examples/comparison/00_comparison.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3641,7 +3641,7 @@
36413641
}
36423642
],
36433643
"source": [
3644-
"metrics[[\"search_method\", \"model\", \"model_dim\", 'total_indexing_time', \"avg_query_time\", \"recall@k\", \"precision\", \"ndcg@k\"]].sort_values(by=\"ndcg@k\", ascending=False)"
3644+
"metrics[[\"search_method\", \"model\", \"model_dim\", 'total_indexing_time', \"avg_query_time\", \"recall\", \"precision\", \"ndcg\"]].sort_values(by=\"ndcg\", ascending=False)"
36453645
]
36463646
},
36473647
{

docs/examples/grid_study/00_grid_study.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1501,7 +1501,7 @@
15011501
}
15021502
],
15031503
"source": [
1504-
"metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall@k\", \"precision\", \"ndcg@k\"]].sort_values(by=\"ndcg@k\", ascending=False)"
1504+
"metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall\", \"precision\", \"ndcg\"]].sort_values(by=\"ndcg\", ascending=False)"
15051505
]
15061506
}
15071507
],

docs/examples/grid_study/01_custom_grid_study.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@
562562
}
563563
],
564564
"source": [
565-
"metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall@k\", \"precision\", \"ndcg@k\"]].sort_values(by=\"ndcg@k\", ascending=False)"
565+
"metrics[[\"search_method\", \"model\", \"avg_query_time\", \"recall\", \"precision\", \"ndcg\"]].sort_values(by=\"ndcg\", ascending=False)"
566566
]
567567
}
568568
],

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "redis-retrieval-optimizer"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
description = "A tool to help optimize information retrieval with the Redis Query Engine."
55
authors = [ "Robert Shelton <[email protected]>" ]
66
license = "MIT"

redis_retrieval_optimizer/bayes_study.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,17 @@
2828
"model": [],
2929
"model_dim": [],
3030
"ret_k": [],
31-
"recall@k": [],
32-
"ndcg@k": [],
33-
"f1@k": [],
31+
"recall": [],
32+
"ndcg": [],
33+
"f1": [],
3434
"precision": [],
3535
"algorithm": [],
3636
"ef_construction": [],
3737
"ef_runtime": [],
3838
"m": [],
3939
"distance_metric": [],
4040
"vector_data_type": [],
41+
"objective_value": [],
4142
}
4243

4344

@@ -52,12 +53,13 @@ def update_metric_row(trial_settings: TrialSettings, trial_metrics: dict):
5253
METRICS["vector_data_type"].append(trial_settings.index_settings.vector_data_type)
5354
METRICS["model"].append(trial_settings.embedding.model)
5455
METRICS["model_dim"].append(trial_settings.embedding.dim)
55-
METRICS["recall@k"].append(trial_metrics["recall"])
56-
METRICS["ndcg@k"].append(trial_metrics["ndcg"])
56+
METRICS["recall"].append(trial_metrics["recall"])
57+
METRICS["ndcg"].append(trial_metrics["ndcg"])
5758
METRICS["precision"].append(trial_metrics["precision"])
58-
METRICS["f1@k"].append(trial_metrics["f1"])
59+
METRICS["f1"].append(trial_metrics["f1"])
5960
METRICS["total_indexing_time"].append(trial_metrics["total_indexing_time"])
6061
METRICS["avg_query_time"].append(trial_metrics["avg_query_time"])
62+
METRICS["objective_value"].append(trial_metrics["objective_value"])
6163

6264

6365
def persist_metrics(
@@ -70,17 +72,30 @@ def persist_metrics(
7072
client.json().set(f"study:{study_id}", Path.root_path(), METRICS)
7173

7274

75+
def norm_metric(value: float):
76+
"""Normalize a metric value using 1/(1+value) formula.
77+
78+
Handles edge cases:
79+
- When value is -1, returns a large positive number (infinity equivalent)
80+
- When value is below -1, returns a negative number (no special handling; only exactly -1 is guarded)
81+
- When value is very positive, returns a small positive number
82+
"""
83+
if value == -1:
84+
# Return a large positive number to represent "infinity" for optimization
85+
return 1000.0
86+
return 1 / (1 + value)
87+
88+
7389
def cost_fn(metrics: dict, weights: dict):
7490
objective = 0
7591
for key in metrics:
76-
objective += weights.get(key, 0) * metrics[key]
92+
if key == "avg_query_time" or key == "total_indexing_time":
93+
objective += weights.get(key, 0) * -norm_metric(metrics[key])
94+
else:
95+
objective += weights.get(key, 0) * metrics[key]
7796
return objective
7897

7998

80-
def norm_metric(value: float):
81-
return 1 / (1 + value)
82-
83-
8499
def objective(trial, study_config, redis_url, corpus_processor, search_method_map):
85100

86101
# optimizer will select hyperparameters from available option in study_config
@@ -152,19 +167,19 @@ def objective(trial, study_config, redis_url, corpus_processor, search_method_ma
152167
search_method_output = search_fn(search_input)
153168

154169
trial_metrics = utils.eval_trial_metrics(qrels, search_method_output.run)
155-
trial_metrics["total_indexing_time"] = -(total_indexing_time)
156-
trial_metrics["avg_query_time"] = -(
157-
utils.get_query_time_stats(search_method_output.query_metrics.query_times)[
158-
"avg_query_time"
159-
]
170+
trial_metrics["total_indexing_time"] = total_indexing_time
171+
trial_metrics["avg_query_time"] = utils.get_query_time_stats(
172+
search_method_output.query_metrics.query_times
173+
)["avg_query_time"]
174+
175+
trial_metrics["objective_value"] = cost_fn(
176+
trial_metrics, study_config.optimization_settings.metric_weights.model_dump()
160177
)
161178

162179
# save results as we go in case of failure
163180
persist_metrics(redis_url, trial_settings, trial_metrics, study_config.study_id)
164181

165-
return cost_fn(
166-
trial_metrics, study_config.optimization_settings.metric_weights.model_dump()
167-
)
182+
return trial_metrics["objective_value"]
168183

169184

170185
def run_bayes_study(

redis_retrieval_optimizer/grid_study.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ def update_metric_row(
2929
)
3030
metrics["model"].append(embedding_settings.model)
3131
metrics["model_dim"].append(embedding_settings.dim)
32-
metrics["recall@k"].append(trial_metrics["recall"])
33-
metrics["ndcg@k"].append(trial_metrics["ndcg"])
32+
metrics["recall"].append(trial_metrics["recall"])
33+
metrics["ndcg"].append(trial_metrics["ndcg"])
3434
metrics["precision"].append(trial_metrics["precision"])
35-
metrics["f1@k"].append(trial_metrics["f1"])
35+
metrics["f1"].append(trial_metrics["f1"])
3636
metrics["total_indexing_time"].append(trial_metrics["total_indexing_time"])
3737
metrics["avg_query_time"].append(trial_metrics["query_stats"]["avg_query_time"])
3838
return metrics
@@ -125,9 +125,9 @@ def run_grid_study(
125125
"search_method": [],
126126
"total_indexing_time": [],
127127
"avg_query_time": [],
128-
"recall@k": [],
129-
"ndcg@k": [],
130-
"f1@k": [],
128+
"recall": [],
129+
"ndcg": [],
130+
"f1": [],
131131
"precision": [],
132132
"ret_k": [],
133133
"algorithm": [],

0 commit comments

Comments
 (0)