
Commit a98de0b

Regenerate course metadata when course information updates (#2419)

* ensuring that changed metadata docs get re-embedded
* adding a test
* removing existing metadata documents
* fixing a bug with a missing key sent in params
* adding types

1 parent: 8ca3a45
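In outline, the change gates course-metadata re-embedding on a checksum: each course's rendered metadata document is hashed, the hash is compared against the one stored on the first metadata chunk in Qdrant, and only on a mismatch are the stale points removed and fresh embeddings uploaded. A condensed sketch of that flow, using names from the diffs below (the "..." placeholders elide arguments shown in full in vector_search/utils.py):

    for doc in serialized_resources:
        rendered = LearningResourceMetadataDisplaySerializer(doc).render_document()
        checksum = checksum_for_content(str(rendered))    # md5 of the rendered metadata
        if not should_generate_content_embeddings(rendered, point_id):
            continue                                      # stored checksum unchanged: skip
        remove_points_matching_params({"key": key}, ...)  # drop the stale metadata points
        # ...then chunk, embed, and upload the refreshed document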

File tree

4 files changed (+141 lines, −24 lines):

  learning_resources/models.py
  main/utils.py
  vector_search/utils.py
  vector_search/utils_test.py

learning_resources/models.py
Lines changed: 2 additions & 9 deletions

@@ -3,7 +3,6 @@
 import uuid
 from abc import abstractmethod
 from functools import cached_property
-from hashlib import md5
 from typing import TYPE_CHECKING, Optional
 
 from django.conf import settings
@@ -25,6 +24,7 @@
     PrivacyLevel,
 )
 from main.models import TimestampedModel, TimestampedModelQuerySet
+from main.utils import checksum_for_content
 
 if TYPE_CHECKING:
     from django.contrib.auth import get_user_model
@@ -942,15 +942,8 @@ class ContentFile(TimestampedModel):
     summary = models.TextField(blank=True, default="")
     flashcards = models.JSONField(blank=True, default=list)
 
-    def content_checksum(self):
-        hasher = md5()  # noqa: S324
-        if self.content:
-            hasher.update(self.content.encode("utf-8"))
-            return hasher.hexdigest()
-        return None
-
     def save(self, **kwargs):
-        self.checksum = self.content_checksum()
+        self.checksum = checksum_for_content(self.content)
         super().save(**kwargs)
 
     class Meta:
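For illustration, a minimal sketch of the slimmed-down save() (values hypothetical; other required ContentFile fields elided):

    from learning_resources.models import ContentFile
    from main.utils import checksum_for_content

    cf = ContentFile(content="Week 1: dynamic programming")
    cf.save()  # hashing is now delegated to the shared helper
    assert cf.checksum == checksum_for_content(cf.content)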

main/utils.py
Lines changed: 12 additions & 0 deletions

@@ -5,6 +5,7 @@
 import os
 from enum import Flag, auto
 from functools import wraps
+from hashlib import md5
 from itertools import islice
 from urllib.parse import urljoin
 
@@ -364,3 +365,14 @@ def clear_search_cache():
     search_keys = cache.keys("views.decorators.cache.cache_header.search.*")
     cleared += cache.delete_many(search_keys) or 0
     return cleared
+
+
+def checksum_for_content(content: str) -> str:
+    """
+    Generate a checksum based on the provided content string
+    """
+    hasher = md5()  # noqa: S324
+    if content:
+        hasher.update(content.encode("utf-8"))
+        return hasher.hexdigest()
+    return None
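The helper preserves the old method's behavior: an md5 hex digest over the UTF-8 bytes when content is truthy, None otherwise. A quick usage sketch:

    from main.utils import checksum_for_content

    checksum_for_content("hello")  # "5d41402abc4b2a76b9719d911017c592", the md5 of b"hello"
    checksum_for_content("")       # falsy content, returns None

(The declared return type is str even though the empty branch returns None; Optional[str] would be the stricter annotation.)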

vector_search/utils.py
Lines changed: 39 additions & 15 deletions

@@ -1,5 +1,6 @@
 import logging
 import uuid
+from typing import Optional
 
 from django.conf import settings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -21,6 +22,7 @@
     serialize_bulk_content_files,
     serialize_bulk_learning_resources,
 )
+from main.utils import checksum_for_content
 from vector_search.constants import (
     CONTENT_FILES_COLLECTION_NAME,
     QDRANT_CONTENT_FILE_INDEXES,
@@ -253,11 +255,13 @@ def update_learning_resource_payload(serialized_document):
 
 
 def update_content_file_payload(serialized_document):
-    params = {
-        "resource_readable_id": serialized_document["resource_readable_id"],
-        "key": serialized_document["key"],
-        "run_readable_id": serialized_document["run_readable_id"],
-    }
+    search_keys = ["resource_readable_id", "key", "run_readable_id"]
+    params = {}
+    for key in search_keys:
+        if key in serialized_document:
+            params[key] = serialized_document[key]
+    if not params:
+        return
     points = [
         point.id
         for point in retrieve_points_matching_params(
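Concretely, the rewritten lookup now tolerates partial documents. A sketch of the filtering with hypothetical values:

    # Only search keys actually present in the document end up in params:
    update_content_file_payload({"key": "lecture1.txt", "summary": "..."})
    # -> params == {"key": "lecture1.txt"}; the Qdrant lookup proceeds
    update_content_file_payload({"summary": "..."})
    # -> params == {}; the function returns early instead of raising KeyError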
@@ -319,18 +323,20 @@ def should_generate_resource_embeddings(serialized_document):
     return True
 
 
-def should_generate_content_embeddings(serialized_document):
+def should_generate_content_embeddings(
+    serialized_document: dict, point_id: Optional[str] = None
+) -> bool:
     """
     Determine if we should generate embeddings for a content file
     """
     client = qdrant_client()
-
-    # we just need metadata from the first chunk
-    point_id = vector_point_id(
-        f"{serialized_document['resource_readable_id']}."
-        f"{serialized_document.get('run_readable_id', '')}."
-        f"{serialized_document['key']}.0"
-    )
+    if not point_id:
+        # we just need metadata from the first chunk
+        point_id = vector_point_id(
+            f"{serialized_document['resource_readable_id']}."
+            f"{serialized_document.get('run_readable_id', '')}."
+            f"{serialized_document['key']}.0"
+        )
     response = client.retrieve(
         collection_name=CONTENT_FILES_COLLECTION_NAME,
         ids=[point_id],
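The new optional point_id lets a caller that already knows which point carries the stored checksum (the course-metadata path below) pass it in directly; without it, the first content chunk's id is derived exactly as before. A sketch with hypothetical variables:

    should_generate_content_embeddings(serialized_document, document_point_id)  # explicit id
    should_generate_content_embeddings(serialized_document)  # derives chunk 0's id as before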
@@ -353,19 +359,37 @@ def _embed_course_metadata_as_contentfile(serialized_resources):
     ids = []
     docs = []
     for doc in serialized_resources:
-        if not should_generate_resource_embeddings(doc):
-            continue
         readable_id = doc["readable_id"]
         resource_vector_point_id = str(vector_point_id(readable_id))
         serializer = LearningResourceMetadataDisplaySerializer(doc)
+        serialized_document = serializer.render_document()
+        checksum = checksum_for_content(str(serialized_document))
+        key = f"{doc['readable_id']}.course_metadata"
+        serialized_document["checksum"] = checksum
+        serialized_document["key"] = key
+        document_point_id = vector_point_id(
+            f"{doc['readable_id']}.course_information.0"
+        )
+        if not should_generate_content_embeddings(
+            serialized_document, document_point_id
+        ):
+            continue
+        # remove existing course info docs
+        remove_points_matching_params(
+            {"key": key}, collection_name=CONTENT_FILES_COLLECTION_NAME
+        )
         split_texts = serializer.render_chunks()
         split_metadatas = [
             {
                 "resource_point_id": str(resource_vector_point_id),
                 "chunk_number": chunk_id,
                 "chunk_content": chunk_content,
                 "resource_readable_id": doc["readable_id"],
+                "run_readable_id": doc["readable_id"],
                 "file_extension": ".txt",
+                "file_type": "course_metadata",
+                "key": key,
+                "checksum": checksum,
                 **{key: doc[key] for key in ["offered_by", "platform"]},
             }
             for chunk_id, chunk_content in enumerate(split_texts)
vector_search/utils_test.py
Lines changed: 88 additions & 0 deletions

@@ -12,10 +12,12 @@
     LearningResourceRunFactory,
 )
 from learning_resources.models import LearningResource
+from learning_resources.serializers import LearningResourceMetadataDisplaySerializer
 from learning_resources_search.serializers import (
     serialize_bulk_content_files,
     serialize_bulk_learning_resources,
 )
+from main.utils import checksum_for_content
 from vector_search.constants import (
     CONTENT_FILES_COLLECTION_NAME,
     QDRANT_CONTENT_FILE_PARAM_MAP,
@@ -635,3 +637,89 @@ def test_embed_learning_resources_summarizes_only_contentfiles_with_summary(mock
     # Only contentfiles with summary should be passed
     expected_ids = [cf.id for cf in contentfiles_with_summary]
     summarize_mock.assert_called_once_with(expected_ids, True)  # noqa: FBT003
+
+
+@pytest.mark.django_db
+def test_embed_course_metadata_as_contentfile_uploads_points_on_change(mocker):
+    """
+    Test that _embed_course_metadata_as_contentfile uploads points to Qdrant
+    if any property of a serialized_resource has changed
+    """
+    mock_client = mocker.patch("vector_search.utils.qdrant_client").return_value
+    mock_encoder = mocker.patch("vector_search.utils.dense_encoder").return_value
+    mock_encoder.model_short_name.return_value = "test-model"
+    mock_encoder.embed_documents.return_value = [[0.1, 0.2, 0.3]]
+    resource = LearningResourceFactory.create()
+    serialized_resource = next(serialize_bulk_learning_resources([resource.id]))
+    serializer = LearningResourceMetadataDisplaySerializer(serialized_resource)
+    rendered_document = serializer.render_document()
+    resource_checksum = checksum_for_content(str(rendered_document))
+
+    # Simulate qdrant returning a checksum for the existing record
+    # that does not match the checksum of the metadata doc
+    mock_point = mocker.Mock()
+    mock_point.payload = {"checksum": "checksum2"}
+    mock_client.retrieve.return_value = [mock_point]
+
+    _embed_course_metadata_as_contentfile([serialized_resource])
+
+    # Assert upload_points was called
+    assert mock_client.upload_points.called
+    args, kwargs = mock_client.upload_points.call_args
+    assert args[0] == CONTENT_FILES_COLLECTION_NAME
+    points = list(kwargs["points"])
+    assert len(points) == 1
+    assert points[0].payload["resource_readable_id"] == resource.readable_id
+    assert points[0].payload["checksum"] == resource_checksum
+
+    # simulate qdrant returning the same checksum for the metadata doc
+    mock_point.payload = {"checksum": resource_checksum}
+    mock_client.upload_points.reset_mock()
+    _embed_course_metadata_as_contentfile([serialized_resource])
+
+    # nothing has changed - no updates to make
+    assert not mock_client.upload_points.called
+
+
+@pytest.mark.parametrize(
+    ("serialized_document", "expected_params"),
+    [
+        (
+            {"resource_readable_id": "r1", "key": "k1", "run_readable_id": "run1"},
+            {"resource_readable_id": "r1", "key": "k1", "run_readable_id": "run1"},
+        ),
+        (
+            {"resource_readable_id": "r2", "key": "k2"},
+            {"resource_readable_id": "r2", "key": "k2"},
+        ),
+        (
+            {"run_readable_id": "run3"},
+            {"run_readable_id": "run3"},
+        ),
+        ({"test": "run3"}, None),
+    ],
+)
+def test_update_content_file_payload_only_includes_existing_keys(
+    mocker, serialized_document, expected_params
+):
+    """
+    Test that params only includes keys
+    that are defined in the input document
+    """
+    mock_retrieve = mocker.patch(
+        "vector_search.utils.retrieve_points_matching_params", return_value=[]
+    )
+    mocker.patch("vector_search.utils._set_payload")
+
+    update_content_file_payload(serialized_document)
+    if expected_params:
+        # Check that retrieve_points_matching_params was called with the expected keys
+        mock_retrieve.assert_called_once_with(
+            expected_params,
+            collection_name=CONTENT_FILES_COLLECTION_NAME,
+        )
+    else:
+        mock_retrieve.assert_not_called()
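To run just the new tests, something like this should work under the repo's usual pytest setup (the command is an assumption, not part of the commit):

    pytest vector_search/utils_test.py -k "course_metadata or update_content_file_payload"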
