1616
1717QUERIES_FILENAME : str = "queries.json.bz2"
1818QUERIES_RECALL_FILENAME : str = "queries-recall.json.bz2"
19+ QUERIES_RECALL_10M_FILENAME : str = "queries-recall-10m.json.bz2"
1920
2021
2122def extract_vector_operations_count (knn_result ):
@@ -147,6 +148,7 @@ def params(self):
147148 "num_candidates" : self ._params .get ("num-candidates" , 100 ),
148149 "visit_percentage" : self ._params .get ("visit-percentage" , - 1 ),
149150 "oversample_rescore" : self ._params .get ("oversample-rescore" , - 1 ),
151+ "recall_doc_set" : self ._params .get ("recall-doc-set" , - 1 ),
150152 }
151153
152154
@@ -157,6 +159,7 @@ async def __call__(self, es, params):
157159 visit_percentage = params ["visit_percentage" ]
158160 index = params ["index" ]
159161 request_cache = params ["cache" ]
162+ recall_doc_set = params ["recall_doc_set" ]
160163
161164 cwd = os .path .dirname (__file__ )
162165 qrels = read_qrels (os .path .join (cwd , "qrels.tsv" ))
@@ -166,7 +169,13 @@ async def __call__(self, es, params):
166169 exact_total = 0
167170 min_recall = top_k
168171 nodes_visited = []
169- with bz2 .open (os .path .join (cwd , QUERIES_RECALL_FILENAME ), "r" ) as queries_file :
172+
173+ if recall_doc_set == "10m" :
174+ queries_recall = QUERIES_RECALL_10M_FILENAME
175+ else :
176+ queries_recall = QUERIES_RECALL_FILENAME
177+
178+ with bz2 .open (os .path .join (cwd , queries_recall ), "r" ) as queries_file :
170179 for line in queries_file :
171180 query = json .loads (line )
172181 query_id = query ["query_id" ]
0 commit comments