From 92036e67e9b7987766d64f6077d1bfd0caedbd78 Mon Sep 17 00:00:00 2001
From: portgas37
Date: Tue, 14 Oct 2025 18:40:57 +0000
Subject: [PATCH 1/2] rag test jsonl

---
 tests/test_rag.py | 157 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 157 insertions(+)
 create mode 100644 tests/test_rag.py

diff --git a/tests/test_rag.py b/tests/test_rag.py
new file mode 100644
index 00000000..7291ef03
--- /dev/null
+++ b/tests/test_rag.py
@@ -0,0 +1,157 @@
import pytest
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from unittest.mock import MagicMock, patch
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_milvus.utils.sparse import BaseSparseEmbedding
from pymilvus import MilvusClient

from mmore.rag.retriever import Retriever, RetrieverConfig

# Mock classes standing in for the Retriever's external dependencies

class MockEmbeddings(Embeddings):
    """Dense embedder stub returning a fixed two-dimensional vector."""
    def embed_query(self, text): return [0.1, 0.2]
    def embed_documents(self, texts): return [[0.1, 0.2] for _ in texts]

class MockSparse(BaseSparseEmbedding):
    """Sparse embedder stub returning a fixed single-term vector."""
    def embed_query(self, text): return {0: 1.0}
    def embed_documents(self, texts): return [{0: 1.0} for _ in texts]

class MockMilvus(MilvusClient):
    """Milvus client stub that skips any real connection setup."""
    def __init__(self): pass

class MockModel(PreTrainedModel):
    """Reranker model stub that always emits the same two logits."""
    def __init__(self):
        from transformers import PretrainedConfig
        config = PretrainedConfig()
        super().__init__(config)
        self.logits = torch.tensor([[0.1], [2.0]])

    def forward(self, **kwargs):
        class Output:
            def __init__(self, logits):
                self.logits = logits
        return Output(self.logits)

class MockBatch:
    """Minimal tokenizer output: dict-style access plus a no-op .to(device)."""
    def __init__(self, data):
        self.data = data
    def to(self, device): return self
    def __getitem__(self, k): return self.data[k]

class MockTokenizer(PreTrainedTokenizerBase):
    """Reranker tokenizer stub returning a fixed two-row batch."""
    def __call__(self, queries, docs, **kwargs):
        return MockBatch({
            "input_ids": torch.tensor([[1, 2], [3, 4]]),
            "attention_mask": torch.tensor([[1, 1], [1, 1]])
        })


# Tests


def test_retriever_initialization():
    """Test that Retriever initializes correctly with mocked components."""
    retriever = Retriever(
        dense_model=MockEmbeddings(),
        sparse_model=MockSparse(),
        client=MockMilvus(),
        hybrid_search_weight=0.5,
        k=2,
        use_web=False,
        reranker_model=MockModel(),
        reranker_tokenizer=MockTokenizer(),
    )
    assert isinstance(retriever, Retriever)


@patch("mmore.rag.retriever.Retriever.rerank")
def test_rerank_batch(mock_rerank):
    """Test the rerank contract: documents come back sorted by score, with the
    score attached as `similarity` metadata (rerank is patched, so this checks
    the expected interface rather than the real reranker model)."""

    docs = [
        Document(page_content="doc1", metadata={"id": "1"}),
        Document(page_content="doc2", metadata={"id": "2"}),
    ]

    def mock_rerank_side_effect(query, docs):
        scores = [0.1, 2.0]
        scored_docs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
        reranked_docs = []
        for doc, score in scored_docs:
            new_doc = doc.copy()
            new_doc.metadata["similarity"] = score
            reranked_docs.append(new_doc)
        return reranked_docs

    mock_rerank.side_effect = mock_rerank_side_effect

    retriever = Retriever(
        dense_model=MockEmbeddings(),
        sparse_model=MockSparse(),
        client=MockMilvus(),
        hybrid_search_weight=0.5,
        k=2,
        use_web=False,
        reranker_model=MockModel(),
        reranker_tokenizer=MockTokenizer(),
    )

    reranked = retriever.rerank("test query", docs)

    # Assertions
    assert isinstance(reranked, list)
    assert reranked[0].page_content == "doc2"
    assert reranked[1].page_content == "doc1"
    assert reranked[0].metadata["similarity"] == pytest.approx(2.0)
    mock_rerank.assert_called_once()


@patch("mmore.rag.retriever.Retriever.retrieve")
@patch("mmore.rag.retriever.Retriever.rerank")
def test_get_relevant_documents(mock_rerank, mock_retrieve):
    """Test that _get_relevant_documents integrates retrieval and reranking and
    transforms raw Milvus results into Documents."""

    # 1. Set up mocks for the dependencies
    mock_retrieve.return_value = [
        {"id": "1", "distance": 0.1, "entity": {"text": "doc1 content"}},
        {"id": "2", "distance": 0.3, "entity": {"text": "doc2 content"}},
    ]

    def mock_rerank_side_effect(query, docs, **kwargs):
        assert all(isinstance(d, Document) for d in docs)
        docs[0].metadata["similarity"] = 0.95
        docs[1].metadata["similarity"] = 0.85
        return [docs[0], docs[1]]

    mock_rerank.side_effect = mock_rerank_side_effect

    # 2. Initialize the Retriever (the real class under test)
    retriever = Retriever(
        dense_model=MockEmbeddings(),
        sparse_model=MockSparse(),
        client=MockMilvus(),
        hybrid_search_weight=0.5,
        k=2,
        use_web=False,
        reranker_model=MockModel(),
        reranker_tokenizer=MockTokenizer(),
    )

    # 3. Call the actual method
    docs = retriever._get_relevant_documents("query", run_manager=MagicMock())

    # 4. Assertions
    assert len(docs) == 2
    assert all(isinstance(d, Document) for d in docs)
    mock_retrieve.assert_called_once()
    mock_rerank.assert_called_once()

    assert docs[0].page_content == "doc1 content"
    assert docs[0].metadata["similarity"] == pytest.approx(0.95)
    assert docs[1].page_content == "doc2 content"
    assert docs[1].metadata["similarity"] == pytest.approx(0.85)
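The three tests above build the same mocked Retriever by hand. A shared pytest fixture could remove that duplication; a minimal sketch, not part of the patch, assuming the mock classes and the Retriever constructor signature used above (the fixture name mock_retriever is illustrative):

@pytest.fixture
def mock_retriever():
    # Hypothetical shared fixture: same arguments as the inline constructions above.
    return Retriever(
        dense_model=MockEmbeddings(),
        sparse_model=MockSparse(),
        client=MockMilvus(),
        hybrid_search_weight=0.5,
        k=2,
        use_web=False,
        reranker_model=MockModel(),
        reranker_tokenizer=MockTokenizer(),
    )

def test_retriever_initialization(mock_retriever):
    # Each test then simply requests the fixture instead of rebuilding the mocks.
    assert isinstance(mock_retriever, Retriever)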
From a4396cc671c4ebf27cc8b4bfc100ced8f45dbe64 Mon Sep 17 00:00:00 2001
From: portgas37
Date: Tue, 14 Oct 2025 20:41:43 +0000
Subject: [PATCH 2/2] test for the api endpoint

---
 tests/test_rag_api.py | 77 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)
 create mode 100644 tests/test_rag_api.py

diff --git a/tests/test_rag_api.py b/tests/test_rag_api.py
new file mode 100644
index 00000000..a8b4c7d7
--- /dev/null
+++ b/tests/test_rag_api.py
@@ -0,0 +1,77 @@
import pytest
from fastapi import FastAPI, Body
from fastapi.testclient import TestClient
from mmore.rag.pipeline import RAGPipeline, RAGConfig
from mmore.rag.retriever import RetrieverConfig, DBConfig
from mmore.rag.llm import LLMConfig
from pydantic import BaseModel, Field
from unittest.mock import patch, MagicMock


class RAGInput(BaseModel):
    """
    Defines the expected input structure for the /rag endpoint.
    This structure must align with the MMOREInput model used internally
    by the RAGPipeline.
    """
    input: str = Field(..., description="The user query or question.")
    collection_name: str = Field(..., description="The Milvus collection name to search.")


@pytest.fixture(scope="module")
def app():
    retriever_cfg = RetrieverConfig(
        db=DBConfig(uri="./proc_demo.db", name="my_db"),
        hybrid_search_weight=0.5,
        k=2
    )
    llm_cfg = LLMConfig(llm_name="gpt2")
    rag_cfg = RAGConfig(retriever=retriever_cfg, llm=llm_cfg)

    with patch("mmore.rag.pipeline.RAGPipeline.from_config") as mock_from_config:

        def mock_runnable(input_data, return_dict=False):
            # The pipeline is mocked: echo the query back in a fixed answer shape.
            return [{"answer": f"Mocked answer for query: {input_data['input']}"}]

        # Create a mock RAGPipeline instance that delegates calls to mock_runnable
        mock_pipeline = MagicMock(spec=RAGPipeline)
        mock_pipeline.side_effect = mock_runnable
        mock_from_config.return_value = mock_pipeline

        rag_pipeline = RAGPipeline.from_config(rag_cfg)

        api = FastAPI()

        @api.post("/rag")
        def rag_endpoint(input_data: RAGInput):
            return rag_pipeline(input_data.model_dump(), return_dict=True)[0]

        return api


@pytest.fixture(scope="module")
def client(app):
    return TestClient(app)


def test_rag_endpoint(client):
    """Test that the /rag endpoint returns a valid response structure."""

    response = client.post(
        "/rag",
        json={"input": "What is RAG?", "collection_name": "my_docs"}
    )

    assert response.status_code == 200

    data = response.json()
    assert "answer" in data
    assert isinstance(data["answer"], str)
    assert data["answer"].startswith("Mocked answer")
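A negative-path check could complement the happy-path test above. A minimal sketch, not part of the patch, reusing the client fixture it defines; 422 is FastAPI's standard status for request-validation failures, and the test name is illustrative:

def test_rag_endpoint_missing_field(client):
    """A request missing a required field should be rejected before the handler runs."""
    # collection_name is omitted, so Pydantic validation of RAGInput fails
    # and FastAPI returns 422 Unprocessable Entity.
    response = client.post("/rag", json={"input": "What is RAG?"})
    assert response.status_code == 422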