3 changes: 2 additions & 1 deletion examples/customize/embeddings/custom_embeddings.py
@@ -6,9 +6,10 @@

 class CustomEmbeddings(Embedder):
     def __init__(self, dimension: int = 10, **kwargs: Any):
+        super().__init__(**kwargs)
         self.dimension = dimension

-    def _embed_query(self, input: str) -> list[float]:
+    def embed_query(self, input: str) -> list[float]:
         return [random.random() for _ in range(self.dimension)]
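For user code, this rename is the whole visible change: custom embedders now override embed_query directly instead of the former _embed_query hook. A minimal usage sketch of the class above; the query string and assertion are illustrative, not part of the PR:

# Usage sketch for the CustomEmbeddings class above (illustrative).
embedder = CustomEmbeddings(dimension=10)
vector = embedder.embed_query("what is a knowledge graph?")
assert len(vector) == 10  # one random float per dimension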
@@ -20,7 +20,7 @@

 # Create Embedder object
 class CustomEmbedder(Embedder):
-    def _embed_query(self, text: str) -> list[float]:
+    def embed_query(self, text: str) -> list[float]:
         return [random() for _ in range(DIMENSION)]
@@ -20,7 +20,7 @@

 # Create Embedder object
 class CustomEmbedder(Embedder):
-    def _embed_query(self, text: str) -> list[float]:
+    def embed_query(self, text: str) -> list[float]:
         return [random() for _ in range(DIMENSION)]
15 changes: 1 addition & 14 deletions src/neo4j_graphrag/embeddings/base.py
@@ -20,7 +20,6 @@
 from neo4j_graphrag.utils.rate_limit import (
     DEFAULT_RATE_LIMIT_HANDLER,
     RateLimitHandler,
-    rate_limit_handler,
 )

@@ -39,20 +38,8 @@ def __init__(self, rate_limit_handler: Optional[RateLimitHandler] = None):
         else:
             self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER

-    @rate_limit_handler
-    def embed_query(self, text: str) -> list[float]:
-        """Embed query text.
-
-        Args:
-            text (str): Text to convert to vector embedding
-
-        Returns:
-            list[float]: A vector embedding.
-        """
-        return self._embed_query(text)
-
     @abstractmethod
-    def _embed_query(self, text: str) -> list[float]:
+    def embed_query(self, text: str) -> list[float]:
         """Embed query text.

         Args:
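This is the heart of the PR: the base class no longer wraps a private _embed_query hook in rate limiting. Instead, embed_query itself becomes the abstract method, and each provider opts in to rate limiting by decorating its own override (see the provider diffs below). A minimal sketch of the new subclass contract, with a hypothetical provider name and body; the imports are the real ones used throughout this PR:

# Sketch of the new contract: subclasses implement embed_query directly
# and apply the decorator themselves if they want retry-on-rate-limit.
from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.utils.rate_limit import rate_limit_handler

class MyEmbedder(Embedder):  # hypothetical provider, not in the PR
    @rate_limit_handler  # each provider now opts in to retries itself
    def embed_query(self, text: str) -> list[float]:
        # A real provider would call its embedding API here.
        return [0.0] * 1536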
5 changes: 3 additions & 2 deletions src/neo4j_graphrag/embeddings/cohere.py
@@ -18,7 +18,7 @@

 from neo4j_graphrag.embeddings.base import Embedder
 from neo4j_graphrag.exceptions import EmbeddingsGenerationError
-from neo4j_graphrag.utils.rate_limit import RateLimitHandler
+from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler

 try:
     import cohere
@@ -42,7 +42,8 @@ def __init__(
         self.model = model
         self.client = cohere.Client(**kwargs)

-    def _embed_query(self, text: str, **kwargs: Any) -> list[float]:
+    @rate_limit_handler
+    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
         try:
             response = self.client.embed(
                 texts=[text],
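Each remote provider now applies the rate_limit_handler decorator to embed_query directly, as above. Judging from the tests removed at the bottom of this diff (three attempts, tenacity's RetryError, a "too many requests" message pattern), the decorator behaves roughly like the sketch below; this is an inference, not the actual implementation in neo4j_graphrag.utils.rate_limit:

# Rough sketch of the decorator's behavior, inferred from the test
# expectations removed below; the real decorator presumably delegates to
# self._rate_limit_handler and may differ in detail.
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

def looks_like_rate_limit(exc: BaseException) -> bool:
    return "too many requests" in str(exc).lower()

def rate_limit_handler_sketch(func):
    # Retry rate-limit-looking failures up to 3 attempts (the default for
    # RetryRateLimitHandler per the removed tests), then raise RetryError.
    return retry(
        retry=retry_if_exception(looks_like_rate_limit),
        stop=stop_after_attempt(3),
        wait=wait_exponential(),
    )(func)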
5 changes: 3 additions & 2 deletions src/neo4j_graphrag/embeddings/mistral.py
@@ -20,7 +20,7 @@

 from neo4j_graphrag.embeddings.base import Embedder
 from neo4j_graphrag.exceptions import EmbeddingsGenerationError
-from neo4j_graphrag.utils.rate_limit import RateLimitHandler
+from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler

 try:
     from mistralai import Mistral
@@ -55,7 +55,8 @@ def __init__(
         self.model = model
         self.mistral_client = Mistral(api_key=api_key, **kwargs)

-    def _embed_query(self, text: str, **kwargs: Any) -> list[float]:
+    @rate_limit_handler
+    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
         """
         Generate embeddings for a given query using a Mistral AI text embedding model.
5 changes: 3 additions & 2 deletions src/neo4j_graphrag/embeddings/ollama.py
@@ -19,7 +19,7 @@

 from neo4j_graphrag.embeddings.base import Embedder
 from neo4j_graphrag.exceptions import EmbeddingsGenerationError
-from neo4j_graphrag.utils.rate_limit import RateLimitHandler
+from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler


 class OllamaEmbeddings(Embedder):
@@ -48,7 +48,8 @@ def __init__(
         self.model = model
         self.client = ollama.Client(**kwargs)

-    def _embed_query(self, text: str, **kwargs: Any) -> list[float]:
+    @rate_limit_handler
+    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
         """
         Generate embeddings for a given query using an Ollama text embedding model.
5 changes: 3 additions & 2 deletions src/neo4j_graphrag/embeddings/openai.py
@@ -20,7 +20,7 @@

 from neo4j_graphrag.embeddings.base import Embedder
 from neo4j_graphrag.exceptions import EmbeddingsGenerationError
-from neo4j_graphrag.utils.rate_limit import RateLimitHandler
+from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler

 if TYPE_CHECKING:
     import openai
@@ -59,7 +59,8 @@ def _initialize_client(self, **kwargs: Any) -> Any:
         """
         pass

-    def _embed_query(self, text: str, **kwargs: Any) -> list[float]:
+    @rate_limit_handler
+    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
         """
         Generate embeddings for a given query using an OpenAI text embedding model.
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/embeddings/sentence_transformers.py
@@ -42,7 +42,7 @@ def __init__(
         self.np = np
         self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)

-    def _embed_query(self, text: str) -> Any:
+    def embed_query(self, text: str) -> Any:
         try:
             result = self.model.encode([text])
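Note that SentenceTransformerEmbeddings receives only the rename, with no rate_limit_handler decorator: the model runs locally, so there is presumably no remote API worth retrying, which is consistent with the retry tests deleted further down.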
5 changes: 3 additions & 2 deletions src/neo4j_graphrag/embeddings/vertexai.py
@@ -18,7 +18,7 @@

 from neo4j_graphrag.embeddings.base import Embedder
 from neo4j_graphrag.exceptions import EmbeddingsGenerationError
-from neo4j_graphrag.utils.rate_limit import RateLimitHandler
+from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler

 try:
     from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
@@ -52,7 +52,8 @@ def __init__(
         super().__init__(rate_limit_handler)
         self.model = TextEmbeddingModel.from_pretrained(model)

-    def _embed_query(
+    @rate_limit_handler
+    def embed_query(
         self, text: str, task_type: str = "RETRIEVAL_QUERY", **kwargs: Any
     ) -> list[float]:
         """
4 changes: 2 additions & 2 deletions tests/e2e/conftest.py
@@ -58,12 +58,12 @@ def embedder() -> Embedder:


 class RandomEmbedder(Embedder):
-    def _embed_query(self, text: str) -> list[float]:
+    def embed_query(self, text: str) -> list[float]:
         return [random.random() for _ in range(1536)]


 class BiologyEmbedder(Embedder):
-    def _embed_query(self, text: str) -> list[float]:
+    def embed_query(self, text: str) -> list[float]:
         if text == "biology":
             return EMBEDDING_BIOLOGY
         raise ValueError(f"Unknown embedding text: {text}")
49 changes: 0 additions & 49 deletions tests/unit/embeddings/test_sentence_transformers.py
@@ -3,7 +3,6 @@
 import numpy as np
 import pytest
 import torch
-from tenacity import RetryError
 from neo4j_graphrag.embeddings.base import Embedder
 from neo4j_graphrag.embeddings.sentence_transformers import (
     SentenceTransformerEmbeddings,
@@ -76,51 +75,3 @@ def test_embed_query_non_retryable_error_handling(mock_import: Mock) -> None:

     # Verify the model was called only once (no retries for non-rate-limit errors)
     assert mock_model.encode.call_count == 1
-
-
-@patch("builtins.__import__")
-def test_embed_query_rate_limit_error_retries(mock_import: Mock) -> None:
-    """Test that rate limit errors are retried the expected number of times."""
-    MockSentenceTransformer = get_mock_sentence_transformers()
-    mock_import.return_value = MockSentenceTransformer
-    mock_model = MockSentenceTransformer.SentenceTransformer.return_value
-
-    # Rate limit error that should trigger retries (matches "too many requests" pattern)
-    # Create separate exception instances for each retry attempt
-    mock_model.encode.side_effect = [
-        Exception("too many requests - please wait"),
-        Exception("too many requests - please wait"),
-        Exception("too many requests - please wait"),
-    ]
-
-    instance = SentenceTransformerEmbeddings()
-
-    # After exhausting retries, tenacity raises RetryError (since retries should work)
-    with pytest.raises(RetryError):
-        instance.embed_query("test query")
-
-    # Verify the model was called 3 times (default max_attempts for RetryRateLimitHandler)
-    assert mock_model.encode.call_count == 3
-
-
-@patch("builtins.__import__")
-def test_embed_query_rate_limit_error_eventual_success(mock_import: Mock) -> None:
-    """Test that rate limit errors eventually succeed after retries."""
-    MockSentenceTransformer = get_mock_sentence_transformers()
-    mock_import.return_value = MockSentenceTransformer
-    mock_model = MockSentenceTransformer.SentenceTransformer.return_value
-
-    # First two calls fail with rate limit, third succeeds
-    mock_model.encode.side_effect = [
-        Exception("too many requests - please wait"),
-        Exception("too many requests - please wait"),
-        np.array([[0.1, 0.2, 0.3]]),
-    ]
-
-    instance = SentenceTransformerEmbeddings()
-    result = instance.embed_query("test query")
-
-    # Verify successful result
-    assert result == [0.1, 0.2, 0.3]
-    # Verify the model was called 3 times before succeeding
-    assert mock_model.encode.call_count == 3
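Since SentenceTransformerEmbeddings no longer retries, these two tests are dropped rather than moved. An equivalent check would now target a provider that keeps the decorator; a hypothetical sketch against CohereEmbeddings (the mock setup is illustrative, and it assumes the cohere package is importable and a three-attempt default, as the removed tests asserted):

# Hypothetical retry test against a provider that still applies
# @rate_limit_handler; not part of this PR.
from unittest.mock import MagicMock

import pytest
from tenacity import RetryError

from neo4j_graphrag.embeddings.cohere import CohereEmbeddings

def test_cohere_embed_query_rate_limit_error_retries() -> None:
    embedder = CohereEmbeddings()
    embedder.client = MagicMock()
    embedder.client.embed.side_effect = Exception("too many requests - please wait")

    with pytest.raises(RetryError):
        embedder.embed_query("test query")

    # Three attempts, matching the default RetryRateLimitHandler behavior.
    assert embedder.client.embed.call_count == 3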