Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import pytest
import torch
import torch.nn as nn

Check failure on line 3 in tests/test_rag.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (F401)

tests/test_rag.py:3:20: F401 `torch.nn` imported but unused
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from unittest.mock import MagicMock, patch
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_milvus.utils.sparse import BaseSparseEmbedding
from pymilvus import MilvusClient


from mmore.rag.retriever import Retriever, RetrieverConfig

Check failure on line 12 in tests/test_rag.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (W291)

tests/test_rag.py:12:59: W291 Trailing whitespace

Check failure on line 12 in tests/test_rag.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (F401)

tests/test_rag.py:12:44: F401 `mmore.rag.retriever.RetrieverConfig` imported but unused

Check failure on line 12 in tests/test_rag.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (I001)

tests/test_rag.py:1:1: I001 Import block is un-sorted or un-formatted

# Mock Classes

class MockEmbeddings(Embeddings):
    """Deterministic dense embedder: every text maps to the fixed vector [0.1, 0.2]."""

    def embed_query(self, text):
        return [0.1, 0.2]

    def embed_documents(self, texts):
        return [[0.1, 0.2] for _ in texts]

class MockSparse(BaseSparseEmbedding):
    """Deterministic sparse embedder: every text maps to the single-term vector {0: 1.0}."""

    def embed_query(self, text):
        return {0: 1.0}

    def embed_documents(self, texts):
        return [{0: 1.0} for _ in texts]

class MockMilvus(MilvusClient):
    """MilvusClient stand-in used purely for isinstance checks."""

    def __init__(self):
        # Deliberately skip super().__init__() so no real server connection is attempted.
        pass

class MockModel(PreTrainedModel):
    """Reranker-model stub whose forward pass always emits the same two logits."""

    def __init__(self):
        from transformers import PretrainedConfig

        super().__init__(PretrainedConfig())
        # One score per candidate document: the second candidate outranks the first.
        self.logits = torch.tensor([[0.1], [2.0]])

    def forward(self, **kwargs):
        # Mimic a HF model output object exposing a .logits attribute.
        class Output:
            def __init__(self, logits):
                self.logits = logits

        return Output(self.logits)

class MockBatch:
    """Minimal BatchEncoding stand-in: dict-style item access plus a no-op .to()."""

    def __init__(self, data):
        # Mapping of field name (e.g. "input_ids") -> tensor.
        self.data = data

    def to(self, device):
        # Device placement is irrelevant for mocks; return self unchanged.
        return self

    def __getitem__(self, key):
        return self.data[key]

class MockTokenizer(PreTrainedTokenizerBase):
    """Tokenizer stub: ignores its inputs and returns a fixed two-row batch."""

    def __call__(self, queries, docs, **kwargs):
        ids = torch.tensor([[1, 2], [3, 4]])
        mask = torch.tensor([[1, 1], [1, 1]])
        return MockBatch({"input_ids": ids, "attention_mask": mask})


# Tests


def test_retriever_initialization():
    """Retriever can be constructed directly with fully mocked components."""
    components = dict(
        dense_model=MockEmbeddings(),
        sparse_model=MockSparse(),
        client=MockMilvus(),
        hybrid_search_weight=0.5,
        k=2,
        use_web=False,
        reranker_model=MockModel(),
        reranker_tokenizer=MockTokenizer(),
    )
    retriever = Retriever(**components)
    assert isinstance(retriever, Retriever)


@patch("mmore.rag.retriever.Retriever.rerank")
def test_rerank_batch(mock_rerank):
    """Test the reranking logic and ensure docs are sorted correctly by mock model scores."""
    # NOTE(review): rerank itself is patched, so this exercises the stand-in's
    # sorting logic rather than Retriever.rerank — consider testing the real method.
    docs = [
        Document(page_content="doc1", metadata={"id": "1"}),
        Document(page_content="doc2", metadata={"id": "2"}),
    ]

    def fake_rerank(query, candidates):
        # Fixed scores: second candidate wins, so ordering must flip.
        scores = [0.1, 2.0]
        ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
        results = []
        for doc, score in ranked:
            copied = doc.copy()
            copied.metadata["similarity"] = score
            results.append(copied)
        return results

    mock_rerank.side_effect = fake_rerank

    retriever = Retriever(
        dense_model=MockEmbeddings(),
        sparse_model=MockSparse(),
        client=MockMilvus(),
        hybrid_search_weight=0.5,
        k=2,
        use_web=False,
        reranker_model=MockModel(),
        reranker_tokenizer=MockTokenizer(),
    )

    reranked = retriever.rerank("test query", docs)

    # Highest-scored document comes first and carries its similarity score.
    assert isinstance(reranked, list)
    assert reranked[0].page_content == "doc2"
    assert reranked[1].page_content == "doc1"
    assert reranked[0].metadata["similarity"] == pytest.approx(2.0)
    mock_rerank.assert_called_once()


@patch("mmore.rag.retriever.Retriever.retrieve")
@patch("mmore.rag.retriever.Retriever.rerank")
def test_get_relevant_documents(mock_rerank, mock_retrieve):
    """Test that _get_relevant_documents integrates retrieval + reranking and transforms Milvus results to Documents."""
    # 1. Canned Milvus-style hits for the patched retrieve().
    mock_retrieve.return_value = [
        {"id": "1", "distance": 0.1, "entity": {"text": "doc1 content"}},
        {"id": "2", "distance": 0.3, "entity": {"text": "doc2 content"}},
    ]

    def fake_rerank(query, candidates, **kwargs):
        # The retriever must hand rerank() real Document objects.
        assert all(isinstance(item, Document) for item in candidates)
        candidates[0].metadata["similarity"] = 0.95
        candidates[1].metadata["similarity"] = 0.85
        return [candidates[0], candidates[1]]

    mock_rerank.side_effect = fake_rerank

    # 2. Real Retriever wired with mocked collaborators.
    retriever = Retriever(
        dense_model=MockEmbeddings(),
        sparse_model=MockSparse(),
        client=MockMilvus(),
        hybrid_search_weight=0.5,
        k=2,
        use_web=False,
        reranker_model=MockModel(),
        reranker_tokenizer=MockTokenizer(),
    )

    # 3. Exercise the LangChain retriever entry point.
    results = retriever._get_relevant_documents("query", run_manager=MagicMock())

    # 4. Both stages ran once and raw hits became Documents with scores attached.
    assert len(results) == 2
    assert all(isinstance(item, Document) for item in results)
    mock_retrieve.assert_called_once()
    mock_rerank.assert_called_once()

    assert results[0].page_content == "doc1 content"
    assert results[0].metadata["similarity"] == pytest.approx(0.95)
    assert results[1].page_content == "doc2 content"
    assert results[1].metadata["similarity"] == pytest.approx(0.85)
77 changes: 77 additions & 0 deletions tests/test_rag_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import pytest
from fastapi import FastAPI, Body
from fastapi.testclient import TestClient
from mmore.rag.pipeline import RAGPipeline, RAGConfig
from mmore.rag.retriever import RetrieverConfig, DBConfig
from mmore.rag.llm import LLMConfig
from pydantic import BaseModel, Field
from unittest.mock import patch, MagicMock

class RAGInput(BaseModel):
    """
    Defines the expected input structure for the /rag endpoint.
    This structure must align with the MMOREInput model used internally
    by the RAGPipeline.
    """
    # User query forwarded to the pipeline as input_data["input"].
    input: str = Field(..., description="The user query or question.")
    # Target Milvus collection; presumably consumed by the retriever — TODO confirm against RAGPipeline.
    collection_name: str = Field(..., description="The Milvus collection name to search.")


@pytest.fixture(scope="module")
def app():
    """Build a FastAPI app whose /rag endpoint is backed by a fully mocked RAGPipeline."""
    config = RAGConfig(
        retriever=RetrieverConfig(
            db=DBConfig(uri="./proc_demo.db", name="my_db"),
            hybrid_search_weight=0.5,
            k=2,
        ),
        llm=LLMConfig(llm_name="gpt2"),
    )

    with patch("mmore.rag.pipeline.RAGPipeline.from_config") as mock_from_config:

        def fake_pipeline(input_data, return_dict=False):
            # Same canned payload regardless of return_dict.
            return [{"answer": f"Mocked answer for query: {input_data['input']}"}]

        # from_config() hands back a callable MagicMock that behaves like the pipeline.
        mock_pipeline = MagicMock(spec=RAGPipeline)
        mock_pipeline.side_effect = fake_pipeline
        mock_from_config.return_value = mock_pipeline

        rag_pipeline = RAGPipeline.from_config(config)

        api = FastAPI()

        @api.post("/rag")
        def rag_endpoint(input_data: RAGInput):
            return rag_pipeline(input_data.model_dump(), return_dict=True)[0]

        return api


@pytest.fixture(scope="module")
def client(app):
    """Return a TestClient bound to the mocked app, shared by all tests in this module."""
    return TestClient(app)


def test_rag_endpoint(client):
    """Test that the /rag endpoint returns a valid response structure."""
    payload = {"input": "What is RAG?", "collection_name": "my_docs"}
    response = client.post("/rag", json=payload)

    assert response.status_code == 200

    body = response.json()
    # The mocked pipeline always produces a string answer with a known prefix.
    assert "answer" in body
    assert isinstance(body["answer"], str)
    assert body["answer"].startswith("Mocked answer")
Loading