3
3
import asyncio
4
4
import contextlib
5
5
import functools
6
+ import itertools
7
+ import logging
6
8
from abc import ABC , abstractmethod
7
9
from collections .abc import (
8
10
AsyncGenerator ,
43
45
44
46
MODEL_COST_MAP = litellm .get_model_cost_map ("" )
45
47
48
+ logger = logging .getLogger (__name__ )
49
+
46
50
47
51
def prepare_args (func : Callable , chunk : str , name : str | None ) -> tuple [tuple , dict ]:
48
52
with contextlib .suppress (TypeError ):
@@ -802,8 +806,35 @@ async def similarity_search(
802
806
def clear (self ) -> None :
803
807
self .texts_hashes = set ()
804
808
809
async def partitioned_similarity_search(
    self,
    query: str,
    k: int,
    embedding_model: EmbeddingModel,
    partitioning_fn: Callable[[Embeddable], int],
) -> tuple[Sequence[Embeddable], list[float]]:
    """Group stored documents via a partition function, searching per group.

    Concrete vector stores may override this to run a similarity search
    inside each partition; this base implementation is intentionally
    unimplemented.

    Args:
        query: query string
        k: Number of results to return
        embedding_model: model used to embed the query
        partitioning_fn: function to partition the documents into different groups.

    Returns:
        Tuple of lists of Embeddables and scores of length k.

    Raises:
        NotImplementedError: always, unless a subclass provides an override.
    """
    raise NotImplementedError(
        "partitioned_similarity_search is not implemented for this VectorStore."
    )
830
+
805
831
async def max_marginal_relevance_search (
806
- self , query : str , k : int , fetch_k : int , embedding_model : EmbeddingModel
832
+ self ,
833
+ query : str ,
834
+ k : int ,
835
+ fetch_k : int ,
836
+ embedding_model : EmbeddingModel ,
837
+ partitioning_fn : Callable [[Embeddable ], int ] | None = None ,
807
838
) -> tuple [Sequence [Embeddable ], list [float ]]:
808
839
"""Vectorized implementation of Maximal Marginal Relevance (MMR) search.
809
840
@@ -812,14 +843,24 @@ async def max_marginal_relevance_search(
812
843
k: Number of results to return.
813
844
fetch_k: Number of results to fetch from the vector store.
814
845
embedding_model: model used to embed the query
846
+ partitioning_fn: optional function to partition the documents into
847
+ different groups, performing MMR within each group.
815
848
816
849
Returns:
817
850
List of tuples (doc, score) of length k.
818
851
"""
819
852
if fetch_k < k :
820
853
raise ValueError ("fetch_k must be greater or equal to k" )
821
854
822
- texts , scores = await self .similarity_search (query , fetch_k , embedding_model )
855
+ if partitioning_fn is None :
856
+ texts , scores = await self .similarity_search (
857
+ query , fetch_k , embedding_model
858
+ )
859
+ else :
860
+ texts , scores = await self .partitioned_similarity_search (
861
+ query , fetch_k , embedding_model , partitioning_fn
862
+ )
863
+
823
864
if len (texts ) <= k or self .mmr_lambda >= 1.0 :
824
865
return texts , scores
825
866
@@ -852,6 +893,7 @@ async def max_marginal_relevance_search(
852
893
class NumpyVectorStore (VectorStore ):
853
894
texts : list [Embeddable ] = Field (default_factory = list )
854
895
_embeddings_matrix : np .ndarray | None = None
896
+ _texts_filter : np .ndarray | None = None
855
897
856
898
def __eq__ (self , other ) -> bool :
857
899
if not isinstance (other , type (self )):
@@ -875,12 +917,47 @@ def clear(self) -> None:
875
917
super ().clear ()
876
918
self .texts = []
877
919
self ._embeddings_matrix = None
920
+ self ._texts_filter = None
878
921
879
922
def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None:
    """Add pre-embedded texts to the store and rebuild the embeddings matrix.

    Args:
        texts: Embeddables (each carrying an `embedding`) to append.
            Materialized into a list up front because the iterable is
            consumed twice (once by the superclass bookkeeping, once by
            `self.texts.extend`) — without this, a one-shot generator
            would be silently exhausted after the super() call.
    """
    texts = list(texts)
    super().add_texts_and_embeddings(texts)
    self.texts.extend(texts)
    # Rebuild the whole matrix so row i always corresponds to self.texts[i].
    self._embeddings_matrix = np.array([t.embedding for t in self.texts])
883
926
927
async def partitioned_similarity_search(
    self,
    query: str,
    k: int,
    embedding_model: EmbeddingModel,
    partitioning_fn: Callable[[Embeddable], int],
) -> tuple[Sequence[Embeddable], list[float]]:
    """Run a similarity search inside each partition, then interleave results.

    Args:
        query: query string
        k: Number of results to return per partition before interleaving.
        embedding_model: model used to embed the query
        partitioning_fn: maps each stored Embeddable to an integer group id.

    Returns:
        Tuple of (texts, scores), round-robin interleaved across partitions
        and truncated to length k.
    """
    scores: list[list[float]] = []
    texts: list[Sequence[Embeddable]] = []

    text_partitions = np.array([partitioning_fn(t) for t in self.texts])
    # CPU bound so replacing w a gather wouldn't get us anything,
    # plus we need to reset self._texts_filter each iteration.
    try:
        for partition in np.unique(text_partitions):
            self._texts_filter = text_partitions == partition
            _texts, _scores = await self.similarity_search(query, k, embedding_model)
            texts.append(_texts)
            scores.append(_scores)
    finally:
        # Always clear the filter, even if similarity_search raises,
        # so a failed call cannot leave the store permanently filtered.
        self._texts_filter = None

    # Round-robin interleave the per-partition results (zip_longest pads the
    # shorter partitions with None, which we drop), then truncate to k.
    return (
        [
            t
            for t in itertools.chain.from_iterable(itertools.zip_longest(*texts))
            if t is not None
        ][:k],
        [
            s
            for s in itertools.chain.from_iterable(itertools.zip_longest(*scores))
            if s is not None
        ][:k],
    )
960
+
884
961
async def similarity_search (
885
962
self , query : str , k : int , embedding_model : EmbeddingModel
886
963
) -> tuple [Sequence [Embeddable ], list [float ]]:
@@ -895,16 +972,24 @@ async def similarity_search(
895
972
896
973
embedding_model .set_mode (EmbeddingModes .DOCUMENT )
897
974
975
+ embedding_matrix = self ._embeddings_matrix
976
+
977
+ if self ._texts_filter is not None :
978
+ original_indices = np .where (self ._texts_filter )[0 ]
979
+ embedding_matrix = embedding_matrix [self ._texts_filter ] # type: ignore[index]
980
+ else :
981
+ original_indices = np .arange (len (self .texts ))
982
+
898
983
similarity_scores = cosine_similarity (
899
- np_query .reshape (1 , - 1 ), self . _embeddings_matrix
984
+ np_query .reshape (1 , - 1 ), embedding_matrix
900
985
)[0 ]
901
986
similarity_scores = np .nan_to_num (similarity_scores , nan = - np .inf )
902
987
# minus so descending
903
988
# we could use arg-partition here
904
989
# but a lot of algorithms expect a sorted list
905
990
sorted_indices = np .argsort (- similarity_scores )
906
991
return (
907
- [self .texts [i ] for i in sorted_indices [:k ]],
992
+ [self .texts [i ] for i in original_indices [ sorted_indices ] [:k ]],
908
993
[similarity_scores [i ] for i in sorted_indices [:k ]],
909
994
)
910
995
0 commit comments