Skip to content

Commit e13ddd8

Browse files
committed
OpenConceptLab/ocl_issues#2163 | Optional use of reranker for authorised users
1 parent 4fc3e00 commit e13ddd8

File tree

10 files changed

+159
-47
lines changed

10 files changed

+159
-47
lines changed

core/common/search.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pydash import compact, get
77

88
from core.common.constants import ES_REQUEST_TIMEOUT
9-
from core.common.utils import is_url_encoded_string
9+
from core.common.utils import is_url_encoded_string, get_cross_encoder
1010

1111

1212
class CustomESFacetedSearch(FacetedSearch):
@@ -203,28 +203,28 @@ def apply_aggregation_score_histogram(self):
203203
def apply_aggregation_score_stats(self):
204204
self._dsl_search.aggs.bucket("score", "stats", script="_score")
205205

206-
def to_queryset(self, keep_order=True, normalized_score=False, exact_count=True): # pylint:disable=too-many-locals
206+
def to_queryset(self, keep_order=True, normalized_score=False, exact_count=True, txt=None, encoder_model=None): # pylint:disable=too-many-locals,too-many-arguments
207207
"""
208208
This method returns a Django queryset from an Elasticsearch result.
209209
It costs one query to the SQL database.
210210
"""
211-
import time
212-
start_time = time.time()
213-
s, hits, total = self.__get_response(exact_count)
214-
print("ES query execute", time.time() - start_time)
211+
encoder = bool(txt)
212+
s, hits, total = self.__get_response(exact_count, encoder)
215213
max_score = hits.max_score or 1
216214

217-
start_time = time.time()
218-
for result in hits.hits:
215+
hits = get_cross_encoder(txt, hits.hits, encoder_model) if encoder else hits.hits
216+
for result in hits:
219217
_id = get(result, '_id')
218+
rerank_score = get(result, '_rerank_score')
219+
raw_score = get(result, '_score') or 0
220220
self.scores[int(_id)] = {
221-
'raw': get(result, '_score'),
222-
'normalized': (get(result, '_score') or 0) / max_score
223-
} if normalized_score else get(result, '_score')
221+
'raw': raw_score,
222+
'rerank': rerank_score,
223+
'normalized': rerank_score if encoder else (raw_score / max_score)
224+
} if normalized_score else raw_score
224225
highlight = get(result, 'highlight')
225226
if highlight:
226227
self.highlights[int(_id)] = highlight.to_dict()
227-
print("Highlights/Score", time.time() - start_time)
228228
if self.document and self.document.__name__ == 'RepoDocument':
229229
from core.sources.models import Source
230230
from core.collections.models import Collection
@@ -308,12 +308,14 @@ def append_to_bucket(_bucket, _score, count):
308308

309309
return [build_confidence(high), build_confidence(medium), build_confidence(low)]
310310

311-
def __get_response(self, exact_count=True):
311+
def __get_response(self, exact_count=True, load_fields=False):
312312
# Do not query again if the es result is already cached
313313
total = None
314314
if not hasattr(self._dsl_search, '_response'):
315315
# We only need the meta fields with the models ids
316-
s = self._dsl_search.source(False)
316+
s = self._dsl_search.source(
317+
excludes=['_embeddings', '_synonyms_embeddings']
318+
) if load_fields else self._dsl_search.source(False)
317319
s = s.params(request_timeout=ES_REQUEST_TIMEOUT)
318320
if exact_count:
319321
total = s.count()

core/common/serializers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def validate_identifier(value):
173173
class SearchResultSerializer(Serializer): # pylint: disable=abstract-method
174174
match_type = CharField(source='_match_type', allow_null=True, allow_blank=True)
175175
search_score = FloatField(source='_score', allow_null=True)
176+
search_rerank_score = FloatField(source='_rerank_score', allow_null=True)
176177
search_confidence = CharField(source='_confidence', allow_null=True, allow_blank=True)
177178
search_highlight = SerializerMethodField()
178179

core/common/utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from requests import ConnectTimeout
2828
from requests.auth import HTTPBasicAuth
2929
from rest_framework.utils import encoders
30+
from sentence_transformers import CrossEncoder
3031

3132
from core.common.constants import UPDATED_SINCE_PARAM, BULK_IMPORT_QUEUES_COUNT, CURRENT_USER, REQUEST_URL, \
3233
TEMP_PREFIX
@@ -927,3 +928,51 @@ def get_embeddings(txt):
927928
from sentence_transformers import SentenceTransformer
928929
model = SentenceTransformer(settings.LM_MODEL_NAME)
929930
return model.encode(str(txt))
931+
932+
933+
ENCODERS = [
    # Default reranker: lightweight multilingual cross-encoder.
    # Training includes clinical, medical and question-answering datasets;
    # outputs positive similarity scores (not raw logits).
    # https://huggingface.co/BAAI/bge-reranker-v2-m3
    "BAAI/bge-reranker-v2-m3",

    # Evaluated alternatives, kept for reference:
    # - jinhybr/OA-MedBERT-cross-encoder (~110M, PubMed abstracts / biomedical
    #   QA, binary classifier logits) -- not available as a Hugging Face model.
    # - microsoft/BioLinkBERT-base (~120M, UMLS/PubMed/MeSH/SNOMED -- domain
    #   closest to OCL) -- does not work with sentence_transformers CrossEncoder.

    # General-purpose fallback, 22.7M params.
    # https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2
    # Outputs raw logits, so scores are not bounded to 0-1.
    "cross-encoder/ms-marco-MiniLM-L-6-v2",
]

# Default encoder, loaded eagerly at import time (CPU only).
# NOTE(review): importing this module loads the model from disk/network --
# consider lazy initialization if import cost becomes a problem.
ENCODER = CrossEncoder(ENCODERS[0], device="cpu")

# Per-process cache so a non-default whitelisted model is only loaded once.
_ENCODER_CACHE = {}


def get_encoder(model):
    """Return the CrossEncoder for `model` if whitelisted, else the default.

    Loading a CrossEncoder is expensive, so non-default models are cached
    per process instead of being re-instantiated on every call.
    """
    if model in ENCODERS:
        if model not in _ENCODER_CACHE:
            _ENCODER_CACHE[model] = CrossEncoder(model, device="cpu")
        return _ENCODER_CACHE[model]
    return ENCODER


def get_cross_encoder(txt, hits, model=None):
    """Score each hit's `name` against `txt` and attach `_rerank_score`.

    Args:
        txt: the query text to rerank against.
        hits: iterable of ES hit dicts, each carrying a `_source` mapping.
        model: optional whitelisted model name; falls back to the default.

    Returns the same `hits`, mutated in place with a float `_rerank_score`.
    """
    # Guard against hits whose source has no `name` -- the original passed
    # None into encoder.predict(), which fails inside tokenization.
    docs = [get(dict(hit["_source"]), 'name') or '' for hit in hits]
    encoder = get_encoder(model) if model else ENCODER
    scores = encoder.predict([(txt, doc) for doc in docs])

    for hit, score in zip(hits, scores):
        hit["_rerank_score"] = float(score)
    return hits

core/concepts/views.py

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ def get_serializer_class(self):
797797

798798
return ConceptListSerializer
799799

800-
def filter_queryset(self, _=None): # pylint:disable=too-many-locals,too-many-statements
800+
def filter_queryset(self, _=None): # pylint: disable=too-many-locals
801801
rows = self.request.data.get('rows')
802802
target_repo_url = self.request.data.get('target_repo_url')
803803
target_repo_params = self.request.data.get('target_repo')
@@ -808,8 +808,8 @@ def filter_queryset(self, _=None): # pylint:disable=too-many-locals,too-many-st
808808
map_config = self.request.data.get('map_config', [])
809809
filters = self.request.data.get('filter', {})
810810
include_retired = self.request.query_params.get(INCLUDE_RETIRED_PARAM) in get_truthy_values()
811-
num_candidates = min(to_int(self.request.query_params.get('numCandidates', 0), 2000), 2000)
812-
k_nearest = min(to_int(self.request.query_params.get('kNearest', 0), 50), 50)
811+
num_candidates = min(to_int(self.request.query_params.get('numCandidates', 0), 3000), 3000)
812+
k_nearest = min(to_int(self.request.query_params.get('kNearest', 0), 100), 100)
813813
offset = max(to_int(self.request.GET.get('offset'), 0), 0)
814814
limit = max(to_int(self.request.GET.get('limit'), 0), 0) or self.default_limit
815815
page = max(to_int(self.request.GET.get('page'), 1), 1)
@@ -823,57 +823,82 @@ def filter_queryset(self, _=None): # pylint:disable=too-many-locals,too-many-st
823823
locale_filter = filters.pop('locale', None) if is_semantic else get(filters, 'locale', None)
824824
faceted_criterion = self.get_faceted_criterion(False, filters, minimum_should_match=1) if filters else None
825825
apply_for_name_locale = locale_filter and isinstance(locale_filter, str) and len(locale_filter.split(',')) == 1
826+
encoder_model = self.request.GET.get('encoder_model', None)
827+
reranker = self.request.GET.get('reranker', None) in get_truthy_values() # enables reranker
828+
reranker = reranker and self.request.user.is_mapper_cross_encoder_group
829+
score_to_sort = 'search_rerank_score' if reranker else 'search_normalized_score'
826830
results = []
827-
import time
828831
for row in rows:
829-
start_time = time.time()
830832
search = ConceptFuzzySearch.search(
831833
row, target_repo_url, repo_params, include_retired,
832834
is_semantic, num_candidates, k_nearest, map_config, faceted_criterion, locale_filter
833835
)
834-
print("Search Query", time.time() - start_time)
835-
start_time = time.time()
836836
search = search.params(track_total_hits=False, request_cache=True)
837837
es_search = CustomESSearch(search[start:end], ConceptDocument)
838-
es_search.to_queryset(False, True, False)
839-
print("Search to Queryset", time.time() - start_time)
838+
name = row.get('name') or row.get('Name') if reranker else None
839+
es_search.to_queryset(False, True, False, name, encoder_model)
840840
result = {'row': row, 'results': [], 'map_config': map_config, 'filter': filters}
841-
start_time = time.time()
842841
for concept in es_search.queryset:
843842
concept._highlight = es_search.highlights.get(concept.id, {}) # pylint:disable=protected-access
844843
score_info = es_search.scores.get(concept.id, {})
845-
score = get(score_info, 'raw') or None
846-
normalized_score = get(score_info, 'normalized') or None
847-
concept._score = score # pylint:disable=protected-access
848-
concept._normalized_score = normalized_score # pylint:disable=protected-access
849-
if limit > 1:
850-
concept._match_type = 'low' # pylint:disable=protected-access
851-
score_to_check = normalized_score if normalized_score is not None else score
852-
if concept._highlight.get('name', None) or (is_semantic and score_to_check >= score_threshold): # pylint:disable=protected-access
853-
concept._match_type = 'very_high' # pylint:disable=protected-access
854-
elif concept._highlight.get('synonyms', None): # pylint:disable=protected-access
855-
concept._match_type = 'high' # pylint:disable=protected-access
856-
elif concept._highlight: # pylint:disable=protected-access
857-
concept._match_type = 'medium' # pylint:disable=protected-access
858-
else:
859-
concept._match_type = 'very_high' # pylint:disable=protected-access
844+
normalized_score = get(score_info, 'normalized') or 0
845+
self.apply_score(concept, is_semantic, score_info, score_threshold, reranker, limit)
860846
if not best_match or concept._match_type in ['medium', 'high', 'very_high']: # pylint:disable=protected-access
861847
if apply_for_name_locale:
862848
concept._requested_locale = locale_filter # pylint:disable=protected-access
863849
serializer = ConceptDetailSerializer if self.is_verbose() else ConceptMinimalSerializer
864850
data = serializer(concept, context={'request': self.request}).data
865851
data['search_meta']['search_normalized_score'] = normalized_score * 100
866852
result['results'].append(data)
867-
print("Queryset to Serializer", time.time() - start_time)
868-
start_time = time.time()
869853
if 'results' in result:
870854
result['results'] = sorted(
871-
result['results'], key=lambda res: get(res, 'search_meta.search_normalized_score'), reverse=True)
855+
result['results'], key=lambda res: get(res, f'search_meta.{score_to_sort}'), reverse=True)
872856
results.append(result)
873-
print("Sorting", time.time() - start_time)
874857

875858
return results
876859

860+
@staticmethod
def apply_score(concept, is_semantic, scores, score_threshold, reranker, limit):  # pylint: disable=too-many-arguments,too-many-branches
    """Annotate `concept` with its search scores and a match-type bucket.

    Args:
        concept: the Concept instance being annotated (private `_score`,
            `_normalized_score`, `_rerank_score`, `_match_type` attributes set).
        is_semantic: whether the search ran in semantic (kNN) mode.
        scores: dict with 'raw'/'normalized'/'rerank' keys (or a bare raw
            score -- pydash `get` returns None for non-dicts).
        score_threshold: minimum normalized score for a 'very_high' match.
        reranker: whether cross-encoder rerank scores are in play.
        limit: requested result count; a single-result request (`limit <= 1`)
            is always bucketed 'very_high'.
    """
    score = get(scores, 'raw') or 0
    normalized_score = get(scores, 'normalized') or 0
    rerank_score = get(scores, 'rerank') or 0

    concept._score = score  # pylint:disable=protected-access
    concept._normalized_score = normalized_score  # pylint:disable=protected-access
    if reranker:
        concept._rerank_score = rerank_score  # pylint:disable=protected-access
    highlight = concept._highlight  # pylint:disable=protected-access

    match_type = 'low'
    if limit > 1:
        if is_semantic and reranker:
            # Cross-encoder buckets: `normalized` carries the rerank score
            # here, which the default encoder emits in [0, 1].
            if normalized_score >= 0.9:
                match_type = 'very_high'
            elif normalized_score >= 0.65:
                match_type = 'high'
            elif normalized_score >= 0.5:
                match_type = 'medium'
        elif is_semantic:
            # `normalized_score` is coerced to 0 above, never None, so the
            # original `if normalized_score is not None else score` fallback
            # was dead code -- the normalized score is always the one checked.
            if highlight.get('name', None) or normalized_score >= score_threshold:
                match_type = 'very_high'
            elif highlight.get('synonyms', None):
                match_type = 'high'
            elif highlight:
                match_type = 'medium'
        else:
            if highlight.get('name', None):
                match_type = 'very_high'
            elif highlight.get('synonyms', None):
                match_type = 'high'
            elif highlight:
                match_type = 'medium'
    else:
        # Single-result requests are treated as confident matches.
        match_type = 'very_high'

    concept._match_type = match_type  # pylint:disable=protected-access
901+
877902
@staticmethod
878903
def get_repo_params(is_semantic, target_repo_params, target_repo_url):
879904
repo = ConceptFuzzySearch.get_target_repo(target_repo_url)

core/fixtures/auth_groups.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,7 @@
4646
pk: 12
4747
fields:
4848
name: superadmin_user
49+
- model: "auth.group"
50+
pk: 13
51+
fields:
52+
name: mapper_cross_encoder
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Generated by Django 4.2.16 on 2025-12-05 03:45

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('map_projects', '0018_mapproject_candidates'),
    ]

    operations = [
        # Opt-in flag: whether this map project uses the cross-encoder
        # reranker when scoring match candidates (off by default).
        migrations.AddField(
            model_name='mapproject',
            name='reranker',
            field=models.BooleanField(default=False),
        ),
    ]

core/map_projects/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class MapProject(BaseModel):
3636
score_configuration = models.JSONField(default=default_score_configuration, null=True, blank=True)
3737
filters = models.JSONField(default=dict, null=True, blank=True)
3838
candidates = models.JSONField(default=dict, null=True, blank=True)
39+
reranker = models.BooleanField(default=False)
3940

4041
# Custom API
4142
match_api_url = models.TextField(null=True, blank=True)

core/map_projects/serializers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Meta:
2222
'created_by', 'updated_by', 'created_at', 'updated_at', 'url', 'is_active',
2323
'public_access', 'file', 'user_id', 'organization_id', 'description',
2424
'target_repo_url', 'matching_algorithm', 'include_retired', 'score_configuration',
25-
'match_api_url', 'match_api_token', 'batch_size', 'filters', 'candidates'
25+
'match_api_url', 'match_api_token', 'batch_size', 'filters', 'candidates', 'reranker'
2626
]
2727

2828
def prepare_object(self, validated_data, instance=None, file=None):
@@ -35,7 +35,7 @@ def prepare_object(self, validated_data, instance=None, file=None):
3535
if columns is not False:
3636
instance.columns = columns
3737
for attr in [
38-
'name', 'description', 'extras', 'target_repo_url', 'matching_algorithm', 'include_retired',
38+
'name', 'description', 'extras', 'target_repo_url', 'matching_algorithm', 'include_retired', 'reranker',
3939
'score_configuration', 'match_api_url', 'match_api_token', 'batch_size', 'filters', 'candidates'
4040
]:
4141
setattr(instance, attr, validated_data.get(attr, get(instance, attr)))
@@ -90,7 +90,8 @@ class Meta:
9090
'created_by', 'updated_by', 'created_at', 'updated_at', 'url', 'is_active',
9191
'owner', 'owner_type', 'owner_url', 'public_access',
9292
'target_repo_url', 'matching_algorithm', 'summary', 'logs', 'include_retired',
93-
'score_configuration', 'match_api_url', 'match_api_token', 'batch_size', 'filters', 'candidates'
93+
'score_configuration', 'match_api_url', 'match_api_token', 'batch_size', 'filters', 'candidates',
94+
'reranker'
9495
]
9596

9697
def __init__(self, *args, **kwargs):

core/users/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
MAPPER_AI_ASSISTANT_GROUP = 'mapper_ai_assistant'
1313
MAPPER_WAITLIST_GROUP = 'mapper-waitlist'
1414
MAPPER_APPROVED_GROUP = 'mapper-approved'
15+
MAPPER_CROSS_ENCODER_GROUP = 'mapper_cross_encoder'
1516
EARLY_ACCESS_NGO_GROUP = 'early_access_ngo'
1617
GUEST_GROUP = 'guest_user'
1718
STANDARD_GROUP = 'standard_user'
@@ -33,5 +34,6 @@
3334
PREMIUM_GROUP,
3435
STAFF_GROUP,
3536
SUPERADMIN_GROUP,
37+
MAPPER_CROSS_ENCODER_GROUP
3638
]
3739
INVALID_AUTH_GROUP_NAME = 'Invalid auth group.'

core/users/models.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from core.common.models import BaseModel, CommonLogoModel
1414
from core.common.tasks import send_user_verification_email, send_user_reset_password_email
1515
from core.common.utils import web_url
16-
from core.users.constants import AUTH_GROUPS, MAPPER_WAITLIST_GROUP, STAFF_GROUP, SUPERADMIN_GROUP, GUEST_GROUP
16+
from core.users.constants import AUTH_GROUPS, MAPPER_WAITLIST_GROUP, STAFF_GROUP, SUPERADMIN_GROUP, GUEST_GROUP, \
17+
MAPPER_APPROVED_GROUP, MAPPER_CROSS_ENCODER_GROUP
1718
from .constants import USER_OBJECT_TYPE
1819
from ..common.checksums import ChecksumModel
1920

@@ -226,6 +227,14 @@ def has_auth_group(self, group_name):
226227
def is_mapper_waitlisted(self):
227228
return self.has_auth_group(MAPPER_WAITLIST_GROUP)
228229

230+
@property
def is_mapper_approved(self):
    """True if the user belongs to the `mapper-approved` auth group."""
    return self.has_auth_group(MAPPER_APPROVED_GROUP)

@property
def is_mapper_cross_encoder_group(self):
    """True if the user may use the optional cross-encoder reranker,
    i.e. is a member of the `mapper_cross_encoder` auth group."""
    return self.has_auth_group(MAPPER_CROSS_ENCODER_GROUP)
237+
229238
@property
230239
def is_guest_group(self):
231240
return self.has_auth_group(GUEST_GROUP)

0 commit comments

Comments
 (0)