diff --git a/examples/customize/embeddings/custom_embeddings.py b/examples/customize/embeddings/custom_embeddings.py index e77127359..61517c159 100644 --- a/examples/customize/embeddings/custom_embeddings.py +++ b/examples/customize/embeddings/custom_embeddings.py @@ -6,9 +6,10 @@ class CustomEmbeddings(Embedder): def __init__(self, dimension: int = 10, **kwargs: Any): + super().__init__(**kwargs) self.dimension = dimension - def _embed_query(self, input: str) -> list[float]: + def embed_query(self, input: str) -> list[float]: return [random.random() for _ in range(self.dimension)] diff --git a/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py b/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py index 6268e82fd..e1d59e379 100644 --- a/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py +++ b/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py @@ -20,7 +20,7 @@ # Create Embedder object class CustomEmbedder(Embedder): - def _embed_query(self, text: str) -> list[float]: + def embed_query(self, text: str) -> list[float]: return [random() for _ in range(DIMENSION)] diff --git a/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py b/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py index b9b9dd792..69940596a 100644 --- a/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py +++ b/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py @@ -20,7 +20,7 @@ # Create Embedder object class CustomEmbedder(Embedder): - def _embed_query(self, text: str) -> list[float]: + def embed_query(self, text: str) -> list[float]: return [random() for _ in range(DIMENSION)] diff --git a/src/neo4j_graphrag/embeddings/base.py b/src/neo4j_graphrag/embeddings/base.py index 02e5b7a51..34c0c7b59 100644 --- a/src/neo4j_graphrag/embeddings/base.py +++ b/src/neo4j_graphrag/embeddings/base.py @@ -20,7 +20,6 @@ from neo4j_graphrag.utils.rate_limit import ( DEFAULT_RATE_LIMIT_HANDLER, RateLimitHandler, - rate_limit_handler, ) @@ -39,20 +38,8 @@ def __init__(self, rate_limit_handler: Optional[RateLimitHandler] = None): else: self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER - @rate_limit_handler - def embed_query(self, text: str) -> list[float]: - """Embed query text. - - Args: - text (str): Text to convert to vector embedding - - Returns: - list[float]: A vector embedding. - """ - return self._embed_query(text) - @abstractmethod - def _embed_query(self, text: str) -> list[float]: + def embed_query(self, text: str) -> list[float]: """Embed query text. Args: diff --git a/src/neo4j_graphrag/embeddings/cohere.py b/src/neo4j_graphrag/embeddings/cohere.py index 6d89fcca0..9e371296b 100644 --- a/src/neo4j_graphrag/embeddings/cohere.py +++ b/src/neo4j_graphrag/embeddings/cohere.py @@ -18,7 +18,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.utils.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler try: import cohere @@ -42,7 +42,8 @@ def __init__( self.model = model self.client = cohere.Client(**kwargs) - def _embed_query(self, text: str, **kwargs: Any) -> list[float]: + @rate_limit_handler + def embed_query(self, text: str, **kwargs: Any) -> list[float]: try: response = self.client.embed( texts=[text], diff --git a/src/neo4j_graphrag/embeddings/mistral.py b/src/neo4j_graphrag/embeddings/mistral.py index 2b1c3d284..6dc486f41 100644 --- a/src/neo4j_graphrag/embeddings/mistral.py +++ b/src/neo4j_graphrag/embeddings/mistral.py @@ -20,7 +20,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.utils.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler try: from mistralai import Mistral @@ -55,7 +55,8 @@ def __init__( self.model = model self.mistral_client = Mistral(api_key=api_key, **kwargs) - def _embed_query(self, text: str, **kwargs: Any) -> list[float]: + @rate_limit_handler + def embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using a Mistral AI text embedding model. diff --git a/src/neo4j_graphrag/embeddings/ollama.py b/src/neo4j_graphrag/embeddings/ollama.py index e70fe96be..88f850963 100644 --- a/src/neo4j_graphrag/embeddings/ollama.py +++ b/src/neo4j_graphrag/embeddings/ollama.py @@ -19,7 +19,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.utils.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler class OllamaEmbeddings(Embedder): @@ -48,7 +48,8 @@ def __init__( self.model = model self.client = ollama.Client(**kwargs) - def _embed_query(self, text: str, **kwargs: Any) -> list[float]: + @rate_limit_handler + def embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using an Ollama text embedding model. diff --git a/src/neo4j_graphrag/embeddings/openai.py b/src/neo4j_graphrag/embeddings/openai.py index 9bcf5df70..a987ec3ef 100644 --- a/src/neo4j_graphrag/embeddings/openai.py +++ b/src/neo4j_graphrag/embeddings/openai.py @@ -20,7 +20,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.utils.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler if TYPE_CHECKING: import openai @@ -59,7 +59,8 @@ def _initialize_client(self, **kwargs: Any) -> Any: """ pass - def _embed_query(self, text: str, **kwargs: Any) -> list[float]: + @rate_limit_handler + def embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using an OpenAI text embedding model. diff --git a/src/neo4j_graphrag/embeddings/sentence_transformers.py b/src/neo4j_graphrag/embeddings/sentence_transformers.py index 8dca9c4f6..f49619ab6 100644 --- a/src/neo4j_graphrag/embeddings/sentence_transformers.py +++ b/src/neo4j_graphrag/embeddings/sentence_transformers.py @@ -42,7 +42,7 @@ def __init__( self.np = np self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs) - def _embed_query(self, text: str) -> Any: + def embed_query(self, text: str) -> Any: try: result = self.model.encode([text]) diff --git a/src/neo4j_graphrag/embeddings/vertexai.py b/src/neo4j_graphrag/embeddings/vertexai.py index e1792816f..ab01e9ac6 100644 --- a/src/neo4j_graphrag/embeddings/vertexai.py +++ b/src/neo4j_graphrag/embeddings/vertexai.py @@ -18,7 +18,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.utils.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler try: from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel @@ -52,7 +52,8 @@ def __init__( super().__init__(rate_limit_handler) self.model = TextEmbeddingModel.from_pretrained(model) - def _embed_query( + @rate_limit_handler + def embed_query( self, text: str, task_type: str = "RETRIEVAL_QUERY", **kwargs: Any ) -> list[float]: """ diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 4ae462a12..9932f12e5 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -58,12 +58,12 @@ def embedder() -> Embedder: class RandomEmbedder(Embedder): - def _embed_query(self, text: str) -> list[float]: + def embed_query(self, text: str) -> list[float]: return [random.random() for _ in range(1536)] class BiologyEmbedder(Embedder): - def _embed_query(self, text: str) -> list[float]: + def embed_query(self, text: str) -> list[float]: if text == "biology": return EMBEDDING_BIOLOGY raise ValueError(f"Unknown embedding text: {text}") diff --git a/tests/unit/embeddings/test_sentence_transformers.py b/tests/unit/embeddings/test_sentence_transformers.py index f9c8f36bc..23cb58353 100644 --- a/tests/unit/embeddings/test_sentence_transformers.py +++ b/tests/unit/embeddings/test_sentence_transformers.py @@ -3,7 +3,6 @@ import numpy as np import pytest import torch -from tenacity import RetryError from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.embeddings.sentence_transformers import ( SentenceTransformerEmbeddings, @@ -76,51 +75,3 @@ def test_embed_query_non_retryable_error_handling(mock_import: Mock) -> None: # Verify the model was called only once (no retries for non-rate-limit errors) assert mock_model.encode.call_count == 1 - - -@patch("builtins.__import__") -def test_embed_query_rate_limit_error_retries(mock_import: Mock) -> None: - """Test that rate limit errors are retried the expected number of times.""" - MockSentenceTransformer = get_mock_sentence_transformers() - mock_import.return_value = MockSentenceTransformer - mock_model = MockSentenceTransformer.SentenceTransformer.return_value - - # Rate limit error that should trigger retries (matches "too many requests" pattern) - # Create separate exception instances for each retry attempt - mock_model.encode.side_effect = [ - Exception("too many requests - please wait"), - Exception("too many requests - please wait"), - Exception("too many requests - please wait"), - ] - - instance = SentenceTransformerEmbeddings() - - # After exhausting retries, tenacity raises RetryError (since retries should work) - with pytest.raises(RetryError): - instance.embed_query("test query") - - # Verify the model was called 3 times (default max_attempts for RetryRateLimitHandler) - assert mock_model.encode.call_count == 3 - - -@patch("builtins.__import__") -def test_embed_query_rate_limit_error_eventual_success(mock_import: Mock) -> None: - """Test that rate limit errors eventually succeed after retries.""" - MockSentenceTransformer = get_mock_sentence_transformers() - mock_import.return_value = MockSentenceTransformer - mock_model = MockSentenceTransformer.SentenceTransformer.return_value - - # First two calls fail with rate limit, third succeeds - mock_model.encode.side_effect = [ - Exception("too many requests - please wait"), - Exception("too many requests - please wait"), - np.array([[0.1, 0.2, 0.3]]), - ] - - instance = SentenceTransformerEmbeddings() - result = instance.embed_query("test query") - - # Verify successful result - assert result == [0.1, 0.2, 0.3] - # Verify the model was called 3 times before succeeding - assert mock_model.encode.call_count == 3