INTPYTHON-687 Add multi_field support for $search operation #166

Merged (3 commits) on Jul 18, 2025
12 changes: 7 additions & 5 deletions libs/langchain-mongodb/langchain_mongodb/index.py
@@ -2,7 +2,7 @@

import logging
from time import monotonic, sleep
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union

from pymongo.collection import Collection
from pymongo.operations import SearchIndexModel
@@ -202,7 +202,7 @@ def _wait_for_predicate(
def create_fulltext_search_index(
collection: Collection,
index_name: str,
-field: str,
+field: Union[str, List[str]],
*,
wait_until_complete: Optional[float] = None,
**kwargs: Any,
@@ -222,9 +222,11 @@ def create_fulltext_search_index(
if collection.name not in collection.database.list_collection_names():
collection.database.create_collection(collection.name)

-definition = {
-    "mappings": {"dynamic": False, "fields": {field: [{"type": "string"}]}}
-}
+if isinstance(field, str):
+    fields_definition = {field: [{"type": "string"}]}
+else:
+    fields_definition = {f: [{"type": "string"}] for f in field}
+definition = {"mappings": {"dynamic": False, "fields": fields_definition}}
result = collection.create_search_index(
SearchIndexModel(
definition=definition,
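
With this change, `create_fulltext_search_index` accepts either a single field name or a list of field names; each listed field is mapped to a `string` type in the index definition. A minimal usage sketch, assuming an existing Atlas cluster (the connection string, database, and collection names below are placeholders):

```python
from pymongo import MongoClient

from langchain_mongodb.index import create_fulltext_search_index

# Placeholder connection details; substitute your own Atlas cluster.
client = MongoClient("mongodb+srv://<user>:<password>@<cluster>/")
collection = client["my_db"]["my_collection"]

# Index several string fields under a single Atlas Search index.
create_fulltext_search_index(
    collection=collection,
    index_name="text_index_multi",
    field=["text", "keywords"],
    wait_until_complete=60.0,
)
```

For a list of fields, the generated definition is simply `{"mappings": {"dynamic": False, "fields": {"text": [{"type": "string"}], "keywords": [{"type": "string"}]}}}`.
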
4 changes: 2 additions & 2 deletions libs/langchain-mongodb/langchain_mongodb/pipelines.py
@@ -7,12 +7,12 @@
- `Filter Example <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
"""

-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union


def text_search_stage(
query: str,
-search_field: str,
+search_field: Union[str, List[str]],
index_name: str,
limit: Optional[int] = None,
filter: Optional[Dict[str, Any]] = None,
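
`text_search_stage` takes the same union type. Atlas Search's `text` operator accepts either a single field name or an array of field names as its `path`, so a multi-field stage built by this helper would look roughly like the following (index and field names are illustrative):

```python
# Approximate shape of the $search stage for a multi-field query.
stage = {
    "$search": {
        "index": "text_index_multi",
        "text": {
            "query": "MongoDB Atlas",
            "path": ["text", "keywords"],  # path may be a str or a list of str
        },
    }
}
```
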
@@ -1,4 +1,4 @@
-from typing import Annotated, Any, Dict, List, Optional
+from typing import Annotated, Any, Dict, List, Optional, Union

from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
@@ -17,7 +17,7 @@ class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
"""MongoDB Collection on an Atlas cluster"""
search_index_name: str
"""Atlas Search Index name"""
-search_field: str
+search_field: Union[str, List[str]]
"""Collection field that contains the text to be searched. It must be indexed"""
k: Optional[int] = None
"""Number of documents to return. Default is no limit"""
@@ -61,7 +61,11 @@ def _get_relevant_documents(
# Formatting
docs = []
for res in cursor:
-text = res.pop(self.search_field)
+text = (
+    res.pop(self.search_field)
+    if isinstance(self.search_field, str)
+    else res.pop(self.search_field[0])
+)
make_serializable(res)
docs.append(Document(page_content=text, metadata=res))
return docs
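
On the retriever side, `search_field` may now be a list; when it is, the first listed field supplies the `page_content` of each returned `Document`, and the remaining fields stay in its metadata. A rough usage sketch, assuming `collection` is a `pymongo` Collection already covered by the search index:

```python
from langchain_mongodb.retrievers import MongoDBAtlasFullTextSearchRetriever

retriever = MongoDBAtlasFullTextSearchRetriever(
    collection=collection,          # assumed to exist and be indexed
    search_index_name="text_index_multi",
    search_field=["text", "keywords"],
    k=3,
)

docs = retriever.invoke("What is MongoDB?")
# page_content is taken from "text", the first field in the list;
# "keywords" remains available in each document's metadata.
```
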
7 changes: 4 additions & 3 deletions libs/langchain-mongodb/langchain_mongodb/vectorstores.py
@@ -204,7 +204,7 @@ def __init__(
collection: Collection[Dict[str, Any]],
embedding: Embeddings,
index_name: str = "vector_index",
-text_key: str = "text",
+text_key: Union[str, List[str]] = "text",
embedding_key: str = "embedding",
relevance_score_fn: str = "cosine",
dimensions: int = -1,
@@ -216,7 +216,8 @@
Args:
collection: MongoDB collection to add the texts to
embedding: Text embedding model to use
-text_key: MongoDB field that will contain the text for each document
+text_key: MongoDB field that will contain the text for each document. A list of fields may also be
+    passed; the first one will be used as the text key. Default: 'text'
index_name: Existing Atlas Vector Search Index
embedding_key: Field that will contain the embedding for each document
relevance_score_fn: The similarity score used for the index
@@ -229,7 +230,7 @@
self._collection = collection
self._embedding = embedding
self._index_name = index_name
-self._text_key = text_key
+self._text_key = text_key if isinstance(text_key, str) else text_key[0]
self._embedding_key = embedding_key
self._relevance_score_fn = relevance_score_fn

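
`MongoDBAtlasVectorSearch` mirrors this behaviour: `text_key` may be a list, but only the first element is stored as the text key. A sketch, assuming `collection` and an `Embeddings` instance `embedding_model` already exist:

```python
from langchain_mongodb import MongoDBAtlasVectorSearch

vectorstore = MongoDBAtlasVectorSearch(
    collection=collection,               # assumed to exist
    embedding=embedding_model,           # any langchain_core Embeddings
    index_name="vector_index",
    text_key=["text", "keywords"],       # only "text" is used as the text key
)
```

The new integration tests below exercise both a multi-field index (`["text", "keywords"]`) and a nested field path (`"title.text"`).
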
@@ -0,0 +1,279 @@
from time import sleep, time
from typing import Generator, List

import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
from pymongo.collection import Collection

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.index import (
create_fulltext_search_index,
create_vector_search_index,
)
from langchain_mongodb.retrievers import (
MongoDBAtlasFullTextSearchRetriever,
MongoDBAtlasHybridSearchRetriever,
)

from ..utils import DB_NAME, PatchedMongoDBAtlasVectorSearch

COLLECTION_NAME = "langchain_test_retrievers"
COLLECTION_NAME_NESTED = "langchain_test_retrievers_nested"
VECTOR_INDEX_NAME = "vector_index"
EMBEDDING_FIELD = "embedding"
PAGE_CONTENT_FIELD = ["text", "keywords"]
PAGE_CONTENT_FIELD_NESTED = "title.text"
SEARCH_INDEX_NAME = "text_index_multi"
SEARCH_INDEX_NAME_NESTED = "text_index_nested"

TIMEOUT = 60.0
INTERVAL = 0.5


@pytest.fixture(scope="module")
def example_documents() -> List[Document]:
return [
Document(
page_content="In 2023, I visited Paris", metadata={"keywords": "MongoDB"}
),
Document(
page_content="In 2022, I visited New York",
metadata={"keywords": "Atlas"},
),
Document(
page_content="In 2021, I visited New Orleans",
metadata={"keywords": "Search"},
),
Document(
page_content="Sandwiches are beautiful. Sandwiches are fine.",
metadata={"keywords": "is awesome"},
),
]


@pytest.fixture(scope="module")
def collection(client: MongoClient, dimensions: int) -> Collection:
"""A Collection with both a Vector and a Full-text Search Index"""
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
clxn = client[DB_NAME][COLLECTION_NAME]

clxn.delete_many({})

if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
create_vector_search_index(
collection=clxn,
index_name=VECTOR_INDEX_NAME,
dimensions=dimensions,
path="embedding",
similarity="cosine",
wait_until_complete=TIMEOUT,
)

if not any([SEARCH_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
create_fulltext_search_index(
collection=clxn,
index_name=SEARCH_INDEX_NAME,
field=PAGE_CONTENT_FIELD,
wait_until_complete=TIMEOUT,
)

return clxn


@pytest.fixture(scope="module")
def collection_nested(client: MongoClient, dimensions: int) -> Collection:
"""A Collection with both a Vector and a Full-text Search Index"""
if COLLECTION_NAME_NESTED not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME_NESTED)
else:
clxn = client[DB_NAME][COLLECTION_NAME_NESTED]

clxn.delete_many({})

if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
create_vector_search_index(
collection=clxn,
index_name=VECTOR_INDEX_NAME,
dimensions=dimensions,
path="embedding",
similarity="cosine",
wait_until_complete=TIMEOUT,
)

if not any(
[SEARCH_INDEX_NAME_NESTED == ix["name"] for ix in clxn.list_search_indexes()]
):
create_fulltext_search_index(
collection=clxn,
index_name=SEARCH_INDEX_NAME_NESTED,
field=PAGE_CONTENT_FIELD_NESTED,
wait_until_complete=TIMEOUT,
)

return clxn


@pytest.fixture(scope="module")
def indexed_vectorstore(
collection: Collection,
example_documents: List[Document],
embedding: Embeddings,
) -> Generator[MongoDBAtlasVectorSearch, None, None]:
"""Return a VectorStore with example document embeddings indexed."""

vectorstore = PatchedMongoDBAtlasVectorSearch(
collection=collection,
embedding=embedding,
index_name=VECTOR_INDEX_NAME,
text_key=PAGE_CONTENT_FIELD,
)

vectorstore.add_documents(example_documents)

yield vectorstore

vectorstore.collection.delete_many({})


@pytest.fixture(scope="module")
def indexed_nested_vectorstore(
collection_nested: Collection,
example_documents: List[Document],
embedding: Embeddings,
) -> Generator[MongoDBAtlasVectorSearch, None, None]:
"""Return a VectorStore with example document embeddings indexed."""

vectorstore = PatchedMongoDBAtlasVectorSearch(
collection=collection_nested,
embedding=embedding,
index_name=VECTOR_INDEX_NAME,
text_key=PAGE_CONTENT_FIELD_NESTED,
)

vectorstore.add_documents(example_documents)

yield vectorstore

vectorstore.collection.delete_many({})


def test_vector_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None:
"""Test VectorStoreRetriever"""
retriever = indexed_vectorstore.as_retriever()

query1 = "When did I visit France?"
results = retriever.invoke(query1)
assert len(results) == 4
assert "Paris" in results[0].page_content
assert "MongoDB" == results[0].metadata["keywords"]

query2 = "When was the last time I visited new orleans?"
results = retriever.invoke(query2)
assert "New Orleans" in results[0].page_content
assert "Search" == results[0].metadata["keywords"]


def test_hybrid_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None:
"""Test basic usage of MongoDBAtlasHybridSearchRetriever"""

retriever = MongoDBAtlasHybridSearchRetriever(
vectorstore=indexed_vectorstore,
search_index_name=SEARCH_INDEX_NAME,
k=3,
)

query1 = "When did I visit France?"
results = retriever.invoke(query1)
assert len(results) == 3
assert "Paris" in results[0].page_content

query2 = "When was the last time I visited new orleans?"
results = retriever.invoke(query2)
assert "New Orleans" in results[0].page_content


def test_hybrid_retriever_deprecated_top_k(
indexed_vectorstore: PatchedMongoDBAtlasVectorSearch,
) -> None:
"""Test basic usage of MongoDBAtlasHybridSearchRetriever"""
retriever = MongoDBAtlasHybridSearchRetriever(
vectorstore=indexed_vectorstore,
search_index_name=SEARCH_INDEX_NAME,
top_k=3,
)

query1 = "When did I visit France?"
results = retriever.invoke(query1)
assert len(results) == 3
assert "Paris" in results[0].page_content

query2 = "When was the last time I visited new orleans?"
results = retriever.invoke(query2)
assert "New Orleans" in results[0].page_content


def test_hybrid_retriever_nested(
indexed_nested_vectorstore: PatchedMongoDBAtlasVectorSearch,
) -> None:
"""Test basic usage of MongoDBAtlasHybridSearchRetriever"""
retriever = MongoDBAtlasHybridSearchRetriever(
vectorstore=indexed_nested_vectorstore,
search_index_name=SEARCH_INDEX_NAME_NESTED,
k=3,
)

query1 = "What did I visit France?"
results = retriever.invoke(query1)
assert len(results) == 3
assert "Paris" in results[0].page_content

query2 = "When was the last time I visited new orleans?"
results = retriever.invoke(query2)
assert "New Orleans" in results[0].page_content


def test_fulltext_retriever(
indexed_vectorstore: PatchedMongoDBAtlasVectorSearch,
) -> None:
"""Test result of performing fulltext search.

The Retriever is independent of the VectorStore.
We use it here only to get the Collection, which we know to be indexed.
"""

collection: Collection = indexed_vectorstore.collection

retriever = MongoDBAtlasFullTextSearchRetriever(
collection=collection,
search_index_name=SEARCH_INDEX_NAME,
search_field=PAGE_CONTENT_FIELD,
)

# Wait for the search index to complete.
search_content = dict(
index=SEARCH_INDEX_NAME,
wildcard=dict(query="*", path=PAGE_CONTENT_FIELD, allowAnalyzedField=True),
)
n_docs = collection.count_documents({})
t0 = time()
while True:
if (time() - t0) > TIMEOUT:
raise TimeoutError(
f"Search index {SEARCH_INDEX_NAME} did not complete in {TIMEOUT}"
)
cursor = collection.aggregate([{"$search": search_content}])
if len(list(cursor)) == n_docs:
break
sleep(INTERVAL)

query = "What is MongoDB"
results = retriever.invoke(query)
assert "MongoDB" in results[0].metadata["keywords"]
assert "score" in results[0].metadata