diff --git a/libs/langchain-mongodb/langchain_mongodb/index.py b/libs/langchain-mongodb/langchain_mongodb/index.py index ada1df0..0918900 100644 --- a/libs/langchain-mongodb/langchain_mongodb/index.py +++ b/libs/langchain-mongodb/langchain_mongodb/index.py @@ -2,7 +2,7 @@ import logging from time import monotonic, sleep -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from pymongo.collection import Collection from pymongo.operations import SearchIndexModel @@ -202,7 +202,7 @@ def _wait_for_predicate( def create_fulltext_search_index( collection: Collection, index_name: str, - field: str, + field: Union[str, List[str]], *, wait_until_complete: Optional[float] = None, **kwargs: Any, @@ -222,9 +222,11 @@ def create_fulltext_search_index( if collection.name not in collection.database.list_collection_names(): collection.database.create_collection(collection.name) - definition = { - "mappings": {"dynamic": False, "fields": {field: [{"type": "string"}]}} - } + if isinstance(field, str): + fields_definition = {field: [{"type": "string"}]} + else: + fields_definition = {f: [{"type": "string"}] for f in field} + definition = {"mappings": {"dynamic": False, "fields": fields_definition}} result = collection.create_search_index( SearchIndexModel( definition=definition, diff --git a/libs/langchain-mongodb/langchain_mongodb/pipelines.py b/libs/langchain-mongodb/langchain_mongodb/pipelines.py index 85ada4f..a4baf40 100644 --- a/libs/langchain-mongodb/langchain_mongodb/pipelines.py +++ b/libs/langchain-mongodb/langchain_mongodb/pipelines.py @@ -7,12 +7,12 @@ - `Filter Example `_ """ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union def text_search_stage( query: str, - search_field: str, + search_field: Union[str, List[str]], index_name: str, limit: Optional[int] = None, filter: Optional[Dict[str, Any]] = None, diff --git a/libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py b/libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py index bdedb51..11a4f12 100644 --- a/libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py +++ b/libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, Dict, List, Optional +from typing import Annotated, Any, Dict, List, Optional, Union from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun from langchain_core.documents import Document @@ -17,7 +17,7 @@ class MongoDBAtlasFullTextSearchRetriever(BaseRetriever): """MongoDB Collection on an Atlas cluster""" search_index_name: str """Atlas Search Index name""" - search_field: str + search_field: Union[str, List[str]] """Collection field that contains the text to be searched. It must be indexed""" k: Optional[int] = None """Number of documents to return. Default is no limit""" @@ -61,7 +61,11 @@ def _get_relevant_documents( # Formatting docs = [] for res in cursor: - text = res.pop(self.search_field) + text = ( + res.pop(self.search_field) + if isinstance(self.search_field, str) + else res.pop(self.search_field[0]) + ) make_serializable(res) docs.append(Document(page_content=text, metadata=res)) return docs diff --git a/libs/langchain-mongodb/langchain_mongodb/vectorstores.py b/libs/langchain-mongodb/langchain_mongodb/vectorstores.py index 43cb4d8..5924c88 100644 --- a/libs/langchain-mongodb/langchain_mongodb/vectorstores.py +++ b/libs/langchain-mongodb/langchain_mongodb/vectorstores.py @@ -204,7 +204,7 @@ def __init__( collection: Collection[Dict[str, Any]], embedding: Embeddings, index_name: str = "vector_index", - text_key: str = "text", + text_key: Union[str, List[str]] = "text", embedding_key: str = "embedding", relevance_score_fn: str = "cosine", dimensions: int = -1, @@ -216,7 +216,8 @@ def __init__( Args: collection: MongoDB collection to add the texts to embedding: Text embedding model to use - text_key: MongoDB field that will contain the text for each document + text_key: MongoDB field that will contain the text for each document. It is possible to parse a list of fields.\ + The first one will be used as text key. Default: 'text' index_name: Existing Atlas Vector Search Index embedding_key: Field that will contain the embedding for each document relevance_score_fn: The similarity score used for the index @@ -229,7 +230,7 @@ def __init__( self._collection = collection self._embedding = embedding self._index_name = index_name - self._text_key = text_key + self._text_key = text_key if isinstance(text_key, str) else text_key[0] self._embedding_key = embedding_key self._relevance_score_fn = relevance_score_fn diff --git a/libs/langchain-mongodb/tests/integration_tests/test_retrievers_multi_field.py b/libs/langchain-mongodb/tests/integration_tests/test_retrievers_multi_field.py new file mode 100644 index 0000000..0b0bbe8 --- /dev/null +++ b/libs/langchain-mongodb/tests/integration_tests/test_retrievers_multi_field.py @@ -0,0 +1,279 @@ +from time import sleep, time +from typing import Generator, List + +import pytest +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from pymongo import MongoClient +from pymongo.collection import Collection + +from langchain_mongodb import MongoDBAtlasVectorSearch +from langchain_mongodb.index import ( + create_fulltext_search_index, + create_vector_search_index, +) +from langchain_mongodb.retrievers import ( + MongoDBAtlasFullTextSearchRetriever, + MongoDBAtlasHybridSearchRetriever, +) + +from ..utils import DB_NAME, PatchedMongoDBAtlasVectorSearch + +COLLECTION_NAME = "langchain_test_retrievers" +COLLECTION_NAME_NESTED = "langchain_test_retrievers_nested" +VECTOR_INDEX_NAME = "vector_index" +EMBEDDING_FIELD = "embedding" +PAGE_CONTENT_FIELD = ["text", "keywords"] +PAGE_CONTENT_FIELD_NESTED = "title.text" +SEARCH_INDEX_NAME = "text_index_multi" +SEARCH_INDEX_NAME_NESTED = "text_index_nested" + +TIMEOUT = 60.0 +INTERVAL = 0.5 + + +@pytest.fixture(scope="module") +def example_documents() -> List[Document]: + return [ + Document( + page_content="In 2023, I visited Paris", metadata={"keywords": "MongoDB"} + ), + Document( + page_content="In 2022, I visited New York", + metadata={"keywords": "Atlas"}, + ), + Document( + page_content="In 2021, I visited New Orleans", + metadata={"keywords": "Search"}, + ), + Document( + page_content="Sandwiches are beautiful. Sandwiches are fine.", + metadata={"keywords": "is awesome"}, + ), + ] + + +@pytest.fixture(scope="module") +def collection(client: MongoClient, dimensions: int) -> Collection: + """A Collection with both a Vector and a Full-text Search Index""" + if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): + clxn = client[DB_NAME].create_collection(COLLECTION_NAME) + else: + clxn = client[DB_NAME][COLLECTION_NAME] + + clxn.delete_many({}) + + if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): + create_vector_search_index( + collection=clxn, + index_name=VECTOR_INDEX_NAME, + dimensions=dimensions, + path="embedding", + similarity="cosine", + wait_until_complete=TIMEOUT, + ) + + if not any([SEARCH_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): + create_fulltext_search_index( + collection=clxn, + index_name=SEARCH_INDEX_NAME, + field=PAGE_CONTENT_FIELD, + wait_until_complete=TIMEOUT, + ) + + return clxn + + +@pytest.fixture(scope="module") +def collection_nested(client: MongoClient, dimensions: int) -> Collection: + """A Collection with both a Vector and a Full-text Search Index""" + if COLLECTION_NAME_NESTED not in client[DB_NAME].list_collection_names(): + clxn = client[DB_NAME].create_collection(COLLECTION_NAME_NESTED) + else: + clxn = client[DB_NAME][COLLECTION_NAME_NESTED] + + clxn.delete_many({}) + + if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): + create_vector_search_index( + collection=clxn, + index_name=VECTOR_INDEX_NAME, + dimensions=dimensions, + path="embedding", + similarity="cosine", + wait_until_complete=TIMEOUT, + ) + + if not any( + [SEARCH_INDEX_NAME_NESTED == ix["name"] for ix in clxn.list_search_indexes()] + ): + create_fulltext_search_index( + collection=clxn, + index_name=SEARCH_INDEX_NAME_NESTED, + field=PAGE_CONTENT_FIELD_NESTED, + wait_until_complete=TIMEOUT, + ) + + return clxn + + +@pytest.fixture(scope="module") +def indexed_vectorstore( + collection: Collection, + example_documents: List[Document], + embedding: Embeddings, +) -> Generator[MongoDBAtlasVectorSearch, None, None]: + """Return a VectorStore with example document embeddings indexed.""" + + vectorstore = PatchedMongoDBAtlasVectorSearch( + collection=collection, + embedding=embedding, + index_name=VECTOR_INDEX_NAME, + text_key=PAGE_CONTENT_FIELD, + ) + + vectorstore.add_documents(example_documents) + + yield vectorstore + + vectorstore.collection.delete_many({}) + + +@pytest.fixture(scope="module") +def indexed_nested_vectorstore( + collection_nested: Collection, + example_documents: List[Document], + embedding: Embeddings, +) -> Generator[MongoDBAtlasVectorSearch, None, None]: + """Return a VectorStore with example document embeddings indexed.""" + + vectorstore = PatchedMongoDBAtlasVectorSearch( + collection=collection_nested, + embedding=embedding, + index_name=VECTOR_INDEX_NAME, + text_key=PAGE_CONTENT_FIELD_NESTED, + ) + + vectorstore.add_documents(example_documents) + + yield vectorstore + + vectorstore.collection.delete_many({}) + + +def test_vector_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None: + """Test VectorStoreRetriever""" + retriever = indexed_vectorstore.as_retriever() + + query1 = "When did I visit France?" + results = retriever.invoke(query1) + assert len(results) == 4 + assert "Paris" in results[0].page_content + assert "MongoDB" == results[0].metadata["keywords"] + + query2 = "When was the last time I visited new orleans?" + results = retriever.invoke(query2) + assert "New Orleans" in results[0].page_content + assert "Search" == results[0].metadata["keywords"] + + +def test_hybrid_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None: + """Test basic usage of MongoDBAtlasHybridSearchRetriever""" + + retriever = MongoDBAtlasHybridSearchRetriever( + vectorstore=indexed_vectorstore, + search_index_name=SEARCH_INDEX_NAME, + k=3, + ) + + query1 = "When did I visit France?" + results = retriever.invoke(query1) + assert len(results) == 3 + assert "Paris" in results[0].page_content + + query2 = "When was the last time I visited new orleans?" + results = retriever.invoke(query2) + assert "New Orleans" in results[0].page_content + + +def test_hybrid_retriever_deprecated_top_k( + indexed_vectorstore: PatchedMongoDBAtlasVectorSearch, +) -> None: + """Test basic usage of MongoDBAtlasHybridSearchRetriever""" + retriever = MongoDBAtlasHybridSearchRetriever( + vectorstore=indexed_vectorstore, + search_index_name=SEARCH_INDEX_NAME, + top_k=3, + ) + + query1 = "When did I visit France?" + results = retriever.invoke(query1) + assert len(results) == 3 + assert "Paris" in results[0].page_content + + query2 = "When was the last time I visited new orleans?" + results = retriever.invoke(query2) + assert "New Orleans" in results[0].page_content + + +def test_hybrid_retriever_nested( + indexed_nested_vectorstore: PatchedMongoDBAtlasVectorSearch, +) -> None: + """Test basic usage of MongoDBAtlasHybridSearchRetriever""" + retriever = MongoDBAtlasHybridSearchRetriever( + vectorstore=indexed_nested_vectorstore, + search_index_name=SEARCH_INDEX_NAME_NESTED, + k=3, + ) + + query1 = "What did I visit France?" + results = retriever.invoke(query1) + assert len(results) == 3 + assert "Paris" in results[0].page_content + + query2 = "When was the last time I visited new orleans?" + results = retriever.invoke(query2) + assert "New Orleans" in results[0].page_content + + +def test_fulltext_retriever( + indexed_vectorstore: PatchedMongoDBAtlasVectorSearch, +) -> None: + """Test result of performing fulltext search. + + The Retriever is independent of the VectorStore. + We use it here only to get the Collection, which we know to be indexed. + """ + + collection: Collection = indexed_vectorstore.collection + + retriever = MongoDBAtlasFullTextSearchRetriever( + collection=collection, + search_index_name=SEARCH_INDEX_NAME, + search_field=PAGE_CONTENT_FIELD, + ) + + # Wait for the search index to complete. + search_content = dict( + index=SEARCH_INDEX_NAME, + wildcard=dict(query="*", path=PAGE_CONTENT_FIELD, allowAnalyzedField=True), + ) + n_docs = collection.count_documents({}) + t0 = time() + while True: + if (time() - t0) > TIMEOUT: + raise TimeoutError( + f"Search index {SEARCH_INDEX_NAME} did not complete in {TIMEOUT}" + ) + cursor = collection.aggregate([{"$search": search_content}]) + if len(list(cursor)) == n_docs: + break + sleep(INTERVAL) + + query = "What is MongoDB" + results = retriever.invoke(query) + print(results) + print(list(collection.list_search_indexes())) + # assert "New Orleans" in results[0].page_content + assert "MongoDB" in results[0].metadata["keywords"] + assert "score" in results[0].metadata