
Commit e3623ed

mskarlin and jamesbraza authored
Add partitioning func capabilities to allow doc-types-based embedding ranking (#752)
Co-authored-by: James Braza <[email protected]>
1 parent 08026d3 commit e3623ed

File tree

6 files changed: +2929 −8 lines


paperqa/agents/tools.py

Lines changed: 5 additions & 1 deletion
@@ -15,7 +15,7 @@
 from paperqa.docs import Docs
 from paperqa.llms import EmbeddingModel, LiteLLMModel
 from paperqa.settings import Settings
-from paperqa.types import DocDetails, PQASession
+from paperqa.types import DocDetails, Embeddable, PQASession
 
 from .search import get_directory_index
 
@@ -193,6 +193,7 @@ class GatherEvidence(NamedTool):
     settings: Settings
     summary_llm_model: LiteLLMModel
     embedding_model: EmbeddingModel
+    partitioning_fn: Callable[[Embeddable], int] | None = None
 
     async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
         """
@@ -236,6 +237,7 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
             settings=self.settings,
             embedding_model=self.embedding_model,
             summary_llm_model=self.summary_llm_model,
+            partitioning_fn=self.partitioning_fn,
             callbacks=self.settings.agent.callbacks.get(
                 f"{self.TOOL_FN_NAME}_aget_evidence"
             ),
@@ -275,6 +277,7 @@ class GenerateAnswer(NamedTool):
     llm_model: LiteLLMModel
     summary_llm_model: LiteLLMModel
     embedding_model: EmbeddingModel
+    partitioning_fn: Callable[[Embeddable], int] | None = None
 
     async def gen_answer(self, state: EnvironmentState) -> str:
         """
@@ -305,6 +308,7 @@ async def gen_answer(self, state: EnvironmentState) -> str:
             llm_model=self.llm_model,
             summary_llm_model=self.summary_llm_model,
             embedding_model=self.embedding_model,
+            partitioning_fn=self.partitioning_fn,
             callbacks=self.settings.agent.callbacks.get(
                 f"{self.TOOL_FN_NAME}_aget_query"
             ),
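
The partitioning_fn these tools now accept maps any Embeddable to an integer partition id, which is what enables the doc-type-based ranking named in the commit title. A minimal sketch of such a function, assuming chunks expose their parent document via a doc attribute and that DocDetails.bibtex_type is the discriminator (both are illustrative choices, not code from this commit):

# Hypothetical helper, not part of this commit: partition chunks by whether
# their parent document looks like a journal article.
from paperqa.types import DocDetails, Embeddable

def partition_by_doc_type(chunk: Embeddable) -> int:
    """Return 0 for journal-article chunks, 1 for everything else."""
    doc = getattr(chunk, "doc", None)  # Text chunks carry their parent doc
    if isinstance(doc, DocDetails) and doc.bibtex_type == "article":
        return 0
    return 1

Both GatherEvidence and GenerateAnswer simply forward the callable, so one function serves both tools.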

paperqa/docs.py

Lines changed: 19 additions & 2 deletions
@@ -38,6 +38,7 @@
     Doc,
     DocDetails,
     DocKey,
+    Embeddable,
     LLMResult,
     PQASession,
     Text,
@@ -518,6 +519,7 @@ async def retrieve_texts(
         k: int,
         settings: MaybeSettings = None,
         embedding_model: EmbeddingModel | None = None,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> list[Text]:
 
         settings = get_settings(settings)
@@ -533,7 +535,11 @@
             list[Text],
             (
                 await self.texts_index.max_marginal_relevance_search(
-                    query, k=_k, fetch_k=2 * _k, embedding_model=embedding_model
+                    query,
+                    k=_k,
+                    fetch_k=2 * _k,
+                    embedding_model=embedding_model,
+                    partitioning_fn=partitioning_fn,
                 )
             )[0],
         )
@@ -548,6 +554,7 @@ def get_evidence(
         callbacks: list[Callable] | None = None,
         embedding_model: EmbeddingModel | None = None,
         summary_llm_model: LLMModel | None = None,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> PQASession:
         return get_loop().run_until_complete(
             self.aget_evidence(
@@ -557,6 +564,7 @@
                 callbacks=callbacks,
                 embedding_model=embedding_model,
                 summary_llm_model=summary_llm_model,
+                partitioning_fn=partitioning_fn,
             )
         )
 
@@ -568,6 +576,7 @@ async def aget_evidence(
         callbacks: list[Callable] | None = None,
         embedding_model: EmbeddingModel | None = None,
         summary_llm_model: LLMModel | None = None,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> PQASession:
 
         evidence_settings = get_settings(settings)
@@ -600,7 +609,11 @@
 
         if answer_config.evidence_retrieval:
             matches = await self.retrieve_texts(
-                session.question, _k, evidence_settings, embedding_model
+                session.question,
+                _k,
+                evidence_settings,
+                embedding_model,
+                partitioning_fn=partitioning_fn,
             )
         else:
             matches = self.texts
@@ -662,6 +675,7 @@ def query(
         llm_model: LLMModel | None = None,
         summary_llm_model: LLMModel | None = None,
         embedding_model: EmbeddingModel | None = None,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> PQASession:
         return get_loop().run_until_complete(
             self.aquery(
@@ -671,6 +685,7 @@
                 llm_model=llm_model,
                 summary_llm_model=summary_llm_model,
                 embedding_model=embedding_model,
+                partitioning_fn=partitioning_fn,
             )
         )
 
@@ -682,6 +697,7 @@ async def aquery(  # noqa: PLR0912
         llm_model: LLMModel | None = None,
         summary_llm_model: LLMModel | None = None,
         embedding_model: EmbeddingModel | None = None,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> PQASession:
 
         query_settings = get_settings(settings)
@@ -709,6 +725,7 @@
             settings=settings,
             embedding_model=embedding_model,
             summary_llm_model=summary_llm_model,
+            partitioning_fn=partitioning_fn,
         )
         contexts = session.contexts
         pre_str = None
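
Taken together, the docs.py changes thread the keyword from the public entry points (query/aquery, get_evidence/aget_evidence) down through retrieve_texts to the vector store. A hedged usage sketch, reusing the hypothetical partition_by_doc_type helper above and a placeholder PDF path:

# Usage sketch under the assumptions noted above; the file path is a placeholder.
import asyncio

from paperqa import Docs

async def main() -> None:
    docs = Docs()
    await docs.aadd("my_paper.pdf")  # placeholder document, not from the commit
    session = await docs.aquery(
        "What methods does the paper use?",
        partitioning_fn=partition_by_doc_type,
    )
    print(session.formatted_answer)

asyncio.run(main())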

paperqa/llms.py

Lines changed: 89 additions & 4 deletions
@@ -3,6 +3,8 @@
 import asyncio
 import contextlib
 import functools
+import itertools
+import logging
 from abc import ABC, abstractmethod
 from collections.abc import (
     AsyncGenerator,
@@ -43,6 +45,8 @@
 
 MODEL_COST_MAP = litellm.get_model_cost_map("")
 
+logger = logging.getLogger(__name__)
+
 
 def prepare_args(func: Callable, chunk: str, name: str | None) -> tuple[tuple, dict]:
     with contextlib.suppress(TypeError):
@@ -802,8 +806,35 @@ async def similarity_search(
     def clear(self) -> None:
         self.texts_hashes = set()
 
+    async def partitioned_similarity_search(
+        self,
+        query: str,
+        k: int,
+        embedding_model: EmbeddingModel,
+        partitioning_fn: Callable[[Embeddable], int],
+    ) -> tuple[Sequence[Embeddable], list[float]]:
+        """Partition the documents into different groups and perform similarity search.
+
+        Args:
+            query: query string
+            k: Number of results to return
+            embedding_model: model used to embed the query
+            partitioning_fn: function to partition the documents into different groups.
+
+        Returns:
+            Tuple of lists of Embeddables and scores of length k.
+        """
+        raise NotImplementedError(
+            "partitioned_similarity_search is not implemented for this VectorStore."
+        )
+
     async def max_marginal_relevance_search(
-        self, query: str, k: int, fetch_k: int, embedding_model: EmbeddingModel
+        self,
+        query: str,
+        k: int,
+        fetch_k: int,
+        embedding_model: EmbeddingModel,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> tuple[Sequence[Embeddable], list[float]]:
         """Vectorized implementation of Maximal Marginal Relevance (MMR) search.
 
@@ -812,14 +843,24 @@ async def max_marginal_relevance_search(
             k: Number of results to return.
             fetch_k: Number of results to fetch from the vector store.
             embedding_model: model used to embed the query
+            partitioning_fn: optional function to partition the documents into
+                different groups, performing MMR within each group.
 
         Returns:
             List of tuples (doc, score) of length k.
         """
         if fetch_k < k:
             raise ValueError("fetch_k must be greater or equal to k")
 
-        texts, scores = await self.similarity_search(query, fetch_k, embedding_model)
+        if partitioning_fn is None:
+            texts, scores = await self.similarity_search(
+                query, fetch_k, embedding_model
+            )
+        else:
+            texts, scores = await self.partitioned_similarity_search(
+                query, fetch_k, embedding_model, partitioning_fn
+            )
+
         if len(texts) <= k or self.mmr_lambda >= 1.0:
             return texts, scores
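
The base-class change above is pure dispatch: without a partitioning_fn, MMR fetches fetch_k candidates exactly as before; with one, it delegates to partitioned_similarity_search so each partition is ranked on its own. A paperqa-free sketch of that per-partition top-k idea, with invented scores and a hypothetical helper name:

# Self-contained illustration; data and function name are invented.
import numpy as np

def per_partition_top_k(
    scores: np.ndarray, partitions: np.ndarray, k: int
) -> dict[int, np.ndarray]:
    """Return, per partition, the original indices of its k best scores."""
    out: dict[int, np.ndarray] = {}
    for part in np.unique(partitions):
        original = np.where(partitions == part)[0]  # map back to global indices
        order = np.argsort(-scores[original])[:k]  # minus so descending
        out[int(part)] = original[order]
    return out

scores = np.array([0.9, 0.2, 0.8, 0.5, 0.7])
partitions = np.array([0, 0, 1, 1, 1])
print(per_partition_top_k(scores, partitions, k=2))
# {0: array([0, 1]), 1: array([2, 4])}

The same original-index bookkeeping appears in NumpyVectorStore.similarity_search below, where a boolean mask narrows the embedding matrix but results must still point back into self.texts.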

@@ -852,6 +893,7 @@ async def max_marginal_relevance_search(
 class NumpyVectorStore(VectorStore):
     texts: list[Embeddable] = Field(default_factory=list)
     _embeddings_matrix: np.ndarray | None = None
+    _texts_filter: np.ndarray | None = None
 
     def __eq__(self, other) -> bool:
         if not isinstance(other, type(self)):
@@ -875,12 +917,47 @@ def clear(self) -> None:
         super().clear()
         self.texts = []
         self._embeddings_matrix = None
+        self._texts_filter = None
 
     def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None:
         super().add_texts_and_embeddings(texts)
         self.texts.extend(texts)
         self._embeddings_matrix = np.array([t.embedding for t in self.texts])
 
+    async def partitioned_similarity_search(
+        self,
+        query: str,
+        k: int,
+        embedding_model: EmbeddingModel,
+        partitioning_fn: Callable[[Embeddable], int],
+    ) -> tuple[Sequence[Embeddable], list[float]]:
+        scores: list[list[float]] = []
+        texts: list[Sequence[Embeddable]] = []
+
+        text_partitions = np.array([partitioning_fn(t) for t in self.texts])
+        # CPU bound so replacing w a gather wouldn't get us anything
+        # plus we need to reset self._texts_filter each iteration
+        for partition in np.unique(text_partitions):
+            self._texts_filter = text_partitions == partition
+            _texts, _scores = await self.similarity_search(query, k, embedding_model)
+            texts.append(_texts)
+            scores.append(_scores)
+        # reset the filter after running
+        self._texts_filter = None
+
+        return (
+            [
+                t
+                for t in itertools.chain.from_iterable(itertools.zip_longest(*texts))
+                if t is not None
+            ][:k],
+            [
+                s
+                for s in itertools.chain.from_iterable(itertools.zip_longest(*scores))
+                if s is not None
+            ][:k],
+        )
+
     async def similarity_search(
         self, query: str, k: int, embedding_model: EmbeddingModel
     ) -> tuple[Sequence[Embeddable], list[float]]:
@@ -895,16 +972,24 @@ async def similarity_search(
 
         embedding_model.set_mode(EmbeddingModes.DOCUMENT)
 
+        embedding_matrix = self._embeddings_matrix
+
+        if self._texts_filter is not None:
+            original_indices = np.where(self._texts_filter)[0]
+            embedding_matrix = embedding_matrix[self._texts_filter]  # type: ignore[index]
+        else:
+            original_indices = np.arange(len(self.texts))
+
         similarity_scores = cosine_similarity(
-            np_query.reshape(1, -1), self._embeddings_matrix
+            np_query.reshape(1, -1), embedding_matrix
         )[0]
         similarity_scores = np.nan_to_num(similarity_scores, nan=-np.inf)
         # minus so descending
         # we could use arg-partition here
         # but a lot of algorithms expect a sorted list
         sorted_indices = np.argsort(-similarity_scores)
         return (
-            [self.texts[i] for i in sorted_indices[:k]],
+            [self.texts[i] for i in original_indices[sorted_indices][:k]],
             [similarity_scores[i] for i in sorted_indices[:k]],
         )
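
The merge step in partitioned_similarity_search interleaves per-partition results round-robin with itertools.zip_longest, so every partition contributes to the top k before any partition contributes twice; the "is not None" filter strips only the padding zip_longest inserts for exhausted partitions. A self-contained illustration with invented string results:

# Round-robin merge of uneven per-partition result lists, then truncate to k.
import itertools

per_partition = [["a1", "a2", "a3"], ["b1"], ["c1", "c2"]]
merged = [
    t
    for t in itertools.chain.from_iterable(itertools.zip_longest(*per_partition))
    if t is not None
]
print(merged[:4])  # ['a1', 'b1', 'c1', 'a2']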
