2020class RerankingEvaluator (Evaluator ):
2121 """This class evaluates a SentenceTransformer model for the task of re-ranking.
2222 Given a query and a list of documents, it computes the score [query, doc_i] for all possible
23- documents and sorts them in decreasing order. Then, MRR@10 and MAP is compute to measure the quality of the ranking.
23+ documents and sorts them in decreasing order. Then, MRR@k, MAP, and Recall@k are computed to measure the quality of the ranking.
2424 :param samples: Must be a list and each element is of the form:
2525 - {'query': '', 'positive': [], 'negative': []}. Query is the search query, positive is a list of positive
2626 (relevant) documents, negative is a list of negative (irrelevant) documents.
@@ -143,6 +143,7 @@ def compute_metrics_individual(self, model: Encoder):
143143 def _encode_candidates (self , model : Encoder , batched : bool , all_query_embs = None ):
144144 all_mrr_scores = []
145145 all_ap_scores = []
146+ all_recall_scores = []
146147 all_conf_scores = []
147148 logger .info ("Encoding candidates..." )
148149 if batched :
@@ -151,16 +152,18 @@ def _encode_candidates(self, model: Encoder, batched: bool, all_query_embs=None)
151152 all_query_embs = all_query_embs ,
152153 all_mrr_scores = all_mrr_scores ,
153154 all_ap_scores = all_ap_scores ,
155+ all_recall_scores = all_recall_scores ,
154156 all_conf_scores = all_conf_scores ,
155157 )
156158 else :
157159 self ._encode_candidates_individual (
158160 model = model ,
159161 all_mrr_scores = all_mrr_scores ,
160162 all_ap_scores = all_ap_scores ,
163+ all_recall_scores = all_recall_scores ,
161164 all_conf_scores = all_conf_scores ,
162165 )
163- scores = self ._collect_results (all_mrr_scores , all_ap_scores , all_conf_scores )
166+ scores = self ._collect_results (all_mrr_scores , all_ap_scores , all_recall_scores , all_conf_scores )
164167 return scores
165168
166169 def _encode_candidates_batched (
@@ -169,6 +172,7 @@ def _encode_candidates_batched(
169172 model : Encoder ,
170173 all_mrr_scores ,
171174 all_ap_scores ,
175+ all_recall_scores ,
172176 all_conf_scores ,
173177 ):
174178 all_docs = []
@@ -208,6 +212,7 @@ def _encode_candidates_batched(
208212 is_relevant ,
209213 all_mrr_scores ,
210214 all_ap_scores ,
215+ all_recall_scores ,
211216 all_conf_scores ,
212217 model ,
213218 )
@@ -217,6 +222,7 @@ def _encode_candidates_individual(
217222 model : Encoder ,
218223 all_mrr_scores ,
219224 all_ap_scores ,
225+ all_recall_scores ,
220226 all_conf_scores ,
221227 ):
222228 for instance in tqdm .tqdm (self .samples , desc = "Samples" ):
@@ -255,19 +261,22 @@ def _encode_candidates_individual(
255261 is_relevant ,
256262 all_mrr_scores ,
257263 all_ap_scores ,
264+ all_recall_scores ,
258265 all_conf_scores ,
259266 model ,
260267 )
261268
262- def _collect_results (self , all_mrr_scores , all_ap_scores , all_conf_scores ):
269+ def _collect_results (self , all_mrr_scores , all_ap_scores , all_recall_scores , all_conf_scores ):
263270 mean_ap = np .mean (all_ap_scores )
264271 mean_mrr = np .mean (all_mrr_scores )
272+ mean_recall = np .mean (all_recall_scores )
265273
266274 # Compute nAUCs
267275 naucs_map = self .nAUC_scores (all_conf_scores , all_ap_scores , "map" )
268276 naucs_mrr = self .nAUC_scores (all_conf_scores , all_mrr_scores , "mrr" )
277+ naucs_recall = self .nAUC_scores (all_conf_scores , all_recall_scores , f"recall_at_{ self .mrr_at_k } " )
269278
270- return {** {"map" : mean_ap , "mrr" : mean_mrr } , ** naucs_map , ** naucs_mrr }
279+ return {** {"map" : mean_ap , "mrr" : mean_mrr , f"recall_at_ { self . mrr_at_k } " : mean_recall } , ** naucs_map , ** naucs_mrr , ** naucs_recall }
271280
272281 def _encode_candidates_miracl (
273282 self ,
@@ -408,6 +417,7 @@ def _apply_sim_scores(
408417 is_relevant ,
409418 all_mrr_scores ,
410419 all_ap_scores ,
420+ all_recall_scores ,
411421 all_conf_scores ,
412422 model : Encoder ,
413423 ):
@@ -417,6 +427,7 @@ def _apply_sim_scores(
417427
418428 all_mrr_scores .append (scores ["mrr" ])
419429 all_ap_scores .append (scores ["ap" ])
430+ all_recall_scores .append (scores ["recall" ])
420431 all_conf_scores .append (conf_scores )
421432
422433 @staticmethod
@@ -483,11 +494,13 @@ def _compute_metrics_instance(
483494 scores:
484495 - `mrr`: Mean Reciprocal Rank @ `self.mrr_at_k`
485496 - `ap`: Average Precision
497+ - `recall`: Recall @ `self.mrr_at_k`
486498 """
487499 pred_scores_argsort = torch .argsort (- sim_scores ) # Sort in decreasing order
488500 mrr = self .mrr_at_k_score (is_relevant , pred_scores_argsort , self .mrr_at_k )
489501 ap = self .ap_score (is_relevant , sim_scores .cpu ().tolist ())
490- return {"mrr" : mrr , "ap" : ap }
502+ recall = self .recall_at_k_score (is_relevant , pred_scores_argsort , self .mrr_at_k )
503+ return {"mrr" : mrr , "ap" : ap , "recall" : recall }
491504
492505 @staticmethod
493506 def conf_scores (sim_scores : torch .Tensor ) -> dict [str , float ]:
@@ -570,3 +583,29 @@ def ap_score(is_relevant, pred_scores):
570583 # ap = np.mean([np.mean(preds[: k + 1]) for k in range(len(preds)) if preds[k]])
571584 ap = average_precision_score (is_relevant , pred_scores )
572585 return ap
586+
587+ @staticmethod
def recall_at_k_score(
    is_relevant: list[bool], pred_ranking: list[int], k: int
) -> float:
    """Computes Recall@k score

    Args:
        is_relevant: True if the document is relevant
        pred_ranking: Indices of the documents sorted in decreasing order
            of the similarity score
        k: Top-k documents to consider

    Returns:
        The Recall@k score: the fraction of all relevant documents whose
        index appears among the first `k` entries of `pred_ranking`.
        Returns 0.0 when no document is relevant.
    """
    total_relevant = sum(is_relevant)
    if total_relevant == 0:
        # No relevant documents: return 0.0 instead of dividing by zero.
        return 0.0

    # Only the document indices matter here; the rank position is irrelevant.
    relevant_retrieved = sum(1 for index in pred_ranking[:k] if is_relevant[index])
    return relevant_retrieved / total_relevant
0 commit comments