Skip to content

Commit 2f477a6

Browse files
committed
fix: use run_in_executor instead of raising error
1 parent c0f0648 commit 2f477a6

File tree

3 files changed

+99
-54
lines changed

3 files changed

+99
-54
lines changed

libs/weaviate/langchain_weaviate/vectorstores.py

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import weaviate # type: ignore
2424
from langchain_core.documents import Document
2525
from langchain_core.embeddings import Embeddings
26+
from langchain_core.runnables.config import run_in_executor
2627
from langchain_core.vectorstores import VectorStore
2728

2829
from langchain_weaviate.utils import maximal_marginal_relevance
@@ -570,7 +571,10 @@ async def aadd_texts(
570571
) -> List[str]:
571572
"""Add texts to Weaviate asynchronously."""
572573
if self._client_async is None:
573-
raise ValueError("client_async must be an instance of WeaviateAsyncClient")
574+
logger.warning("client_async is None, using synchronous client instead")
575+
return await run_in_executor(
576+
None, self.add_texts, texts, metadatas, tenant, **kwargs
577+
)
574578
from weaviate.util import get_valid_uuid # type: ignore
575579

576580
if tenant and not await self._adoes_tenant_exist(tenant):
@@ -678,7 +682,7 @@ async def _perform_asearch(
678682
ValueError: If _embedding is None or an invalid search method is provided.
679683
"""
680684
if self._client_async is None:
681-
raise ValueError("client_async must be an instance of WeaviateAsyncClient")
685+
raise ValueError("cannot perform asearch with synchronous client")
682686
if self._embedding is None:
683687
raise ValueError("_embedding cannot be None for similarity_search")
684688

@@ -739,8 +743,9 @@ async def _perform_asearch(
739743

740744
async def _adoes_tenant_exist(self, tenant: str) -> bool:
741745
"""Check if tenant exists in Weaviate asynchronously."""
742-
if self._client_async is None:
743-
raise ValueError("client_async must be an instance of WeaviateAsyncClient")
746+
assert (
747+
self._client_async is not None
748+
), "client_async must be an instance of WeaviateAsyncClient"
744749
assert (
745750
self._multi_tenancy_enabled
746751
), "Cannot check for tenant existence when multi-tenancy is not enabled"
@@ -763,6 +768,10 @@ async def asimilarity_search(
763768
Returns:
764769
List of Documents most similar to the query.
765770
"""
771+
if self._client_async is None:
772+
return await run_in_executor(
773+
None, self.similarity_search, query, k, **kwargs
774+
)
766775
result = await self._perform_asearch(query, k, **kwargs)
767776
return result
768777

@@ -791,6 +800,16 @@ async def amax_marginal_relevance_search(
791800
Returns:
792801
List of Documents selected by maximal marginal relevance.
793802
"""
803+
if self._client_async is None:
804+
return await run_in_executor(
805+
None,
806+
self.max_marginal_relevance_search,
807+
query,
808+
k,
809+
fetch_k,
810+
lambda_mult,
811+
**kwargs,
812+
)
794813
if self._embedding is not None:
795814
embedding = await self._embedding.aembed_query(query)
796815
else:
@@ -827,6 +846,16 @@ async def amax_marginal_relevance_search_by_vector(
827846
Returns:
828847
List of Documents selected by maximal marginal relevance.
829848
"""
849+
if self._client_async is None:
850+
return await run_in_executor(
851+
None,
852+
self.max_marginal_relevance_search_by_vector,
853+
embedding,
854+
k,
855+
fetch_k,
856+
lambda_mult,
857+
**kwargs,
858+
)
830859
results = await self._perform_asearch(
831860
query=None,
832861
k=fetch_k,
@@ -857,9 +886,36 @@ async def asimilarity_search_with_score(
857886
text and cosine distance in float for each.
858887
Lower score represents more similarity.
859888
"""
889+
if self._client_async is None:
890+
return await run_in_executor(
891+
None, self.similarity_search_with_score, query, k, **kwargs
892+
)
860893
results = await self._perform_asearch(query, k, return_score=True, **kwargs)
861894
return results
862895

896+
async def asimilarity_search_by_vector(
897+
self, embedding: List[float], k: int = 4, **kwargs: Any
898+
) -> List[Document]:
899+
"""Return docs most similar to embedding vector asynchronously.
900+
901+
Args:
902+
embedding: Embedding vector to look up documents similar to.
903+
k: Number of Documents to return. Defaults to 4.
904+
**kwargs: Additional keyword arguments will be passed to the `hybrid()`
905+
function of the weaviate client.
906+
907+
Returns:
908+
List of Documents most similar to the embedding.
909+
"""
910+
if self._client_async is None:
911+
return await run_in_executor(
912+
None, self.similarity_search_by_vector, embedding, k, **kwargs
913+
)
914+
result = await self._perform_asearch(
915+
query=None, k=k, vector=embedding, **kwargs
916+
)
917+
return result
918+
863919
@classmethod
864920
async def afrom_texts(
865921
cls,
@@ -917,8 +973,6 @@ async def afrom_texts(
917973

918974
if client is None:
919975
raise ValueError("client must be an instance of WeaviateClient")
920-
if client_async is None:
921-
raise ValueError("client_async must be an instance of WeaviateAsyncClient")
922976

923977
weaviate_vector_store = cls(
924978
client,
@@ -962,22 +1016,3 @@ async def _atenant_context(
9621016
yield collection
9631017
finally:
9641018
pass
965-
966-
async def asimilarity_search_by_vector(
967-
self, embedding: List[float], k: int = 4, **kwargs: Any
968-
) -> List[Document]:
969-
"""Return docs most similar to embedding vector asynchronously.
970-
971-
Args:
972-
embedding: Embedding vector to look up documents similar to.
973-
k: Number of Documents to return. Defaults to 4.
974-
**kwargs: Additional keyword arguments will be passed to the `hybrid()`
975-
function of the weaviate client.
976-
977-
Returns:
978-
List of Documents most similar to the embedding.
979-
"""
980-
result = await self._perform_asearch(
981-
query=None, k=k, vector=embedding, **kwargs
982-
)
983-
return result

libs/weaviate/tests/unit_tests/test_vectorstores_additional.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,19 +1035,9 @@ async def test_atenant_context_missing_tenant(
10351035
async with docsearch._atenant_context(tenant=None):
10361036
pass
10371037

1038-
1039-
@pytest.mark.asyncio
1040-
async def test_warning_when_client_async_not_provided(
1041-
embedding: Any, mock_weaviate_client: MagicMock
1042-
) -> None:
1043-
"""Test that a warning is logged when client_async is not provided."""
1044-
# Create store with client_async not provided
1038+
# Test that _perform_asearch raises ValueError when client_async is None
1039+
docsearch._client_async = None
10451040
with pytest.raises(
1046-
ValueError, match="client_async must be an instance of WeaviateAsyncClient"
1041+
ValueError, match="cannot perform asearch with synchronous client"
10471042
):
1048-
await WeaviateVectorStore.afrom_texts(
1049-
client=mock_weaviate_client,
1050-
texts=["test"],
1051-
embedding=embedding,
1052-
client_async=None,
1053-
)
1043+
await docsearch._perform_asearch(query="test", k=5)

libs/weaviate/tests/unit_tests/test_vectorstores_integration_async.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,31 @@ async def test_invalid_search_param(
932932
await weaviate_vector_store._perform_asearch(query=None, vector=[1, 2, 3], k=5)
933933

934934

935+
@pytest.mark.asyncio
936+
async def test_async_function_without_async_client(
937+
weaviate_client: weaviate.WeaviateClient,
938+
consistent_embedding: ConsistentFakeEmbeddings,
939+
) -> None:
940+
index_name = f"TestIndex_{uuid.uuid4().hex}"
941+
docsearch = WeaviateVectorStore(
942+
client=weaviate_client,
943+
client_async=None,
944+
index_name=index_name,
945+
text_key="text",
946+
embedding=consistent_embedding,
947+
)
948+
949+
# test all async functions without raising errors
950+
# these should be run in the executor, this is identical to how the
951+
# BaseVectorStore class runs the functions without the async method override
952+
await docsearch.asimilarity_search("foo", k=1)
953+
await docsearch.asimilarity_search_with_score("foo", k=1)
954+
await docsearch.asimilarity_search_by_vector([1, 2, 3], k=1)
955+
await docsearch.amax_marginal_relevance_search("foo", k=1)
956+
await docsearch.amax_marginal_relevance_search_by_vector([1, 2, 3], k=1)
957+
await docsearch.aadd_texts(["foo"])
958+
959+
935960
@pytest.mark.asyncio
936961
async def test_missing_coverage_edge_cases(
937962
weaviate_client: weaviate.WeaviateClient,
@@ -954,9 +979,10 @@ async def test_missing_coverage_edge_cases(
954979
text_key="text",
955980
embedding=embedding,
956981
)
957-
958-
with pytest.raises(ValueError, match="client_async must be an instance"):
982+
# check if warning is logged
983+
with caplog.at_level(logging.WARNING, logger="langchain_weaviate.vectorstores"):
959984
await docsearch_no_async.aadd_texts(["test"])
985+
assert "client_async is None, using synchronous client instead" in caplog.text
960986

961987
# Test line 355: Test max_marginal_relevance_search without embedding
962988
docsearch_no_embedding = WeaviateVectorStore(
@@ -970,10 +996,6 @@ async def test_missing_coverage_edge_cases(
970996
with pytest.raises(ValueError, match="max_marginal_relevance_search requires"):
971997
await docsearch_no_embedding.amax_marginal_relevance_search("test")
972998

973-
# Test lines 706-707: Test _perform_asearch without client_async or embedding
974-
with pytest.raises(ValueError, match="client_async must be an instance"):
975-
await docsearch_no_async._perform_asearch(query="test", k=1)
976-
977999
docsearch = WeaviateVectorStore(
9781000
client=weaviate_client,
9791001
client_async=weaviate_client_async,
@@ -1003,15 +1025,6 @@ async def patched_perform_asearch(
10031025
with pytest.raises(ValueError, match="Either query or vector must be provided"):
10041026
await docsearch._perform_asearch(query=None, vector=None, k=1)
10051027

1006-
# Test line 735: Test _adoes_tenant_exist with no client_async
1007-
with pytest.raises(ValueError, match="client_async must be an instance"):
1008-
await docsearch_no_async._adoes_tenant_exist("test_tenant")
1009-
1010-
# Test lines 938, 946: Test _atenant_context with errors
1011-
with pytest.raises(ValueError, match="client_async must be an instance"):
1012-
async with docsearch_no_async._atenant_context("test_tenant"):
1013-
pass
1014-
10151028
# Create a docsearch with client_async but without multi-tenancy
10161029
docsearch_no_mt = WeaviateVectorStore(
10171030
client=weaviate_client,
@@ -1043,3 +1056,10 @@ async def patched_perform_asearch(
10431056
with pytest.raises(ValueError, match=msg):
10441057
async with docsearch_with_mt._atenant_context(tenant=None):
10451058
pass
1059+
1060+
# Test that an error is raised without the async client
1061+
with pytest.raises(
1062+
ValueError, match="client_async must be an instance of WeaviateAsyncClient"
1063+
):
1064+
async with docsearch_no_async._atenant_context(tenant="test_tenant"):
1065+
pass

0 commit comments

Comments
 (0)