Skip to content

Commit 4e1e783

Browse files
authored
Fix VertexAIEmbeddings (#402)
* Fix VertexAIEmbeddings * Ruff * Undo changes to poetry.lock
1 parent e91453b commit 4e1e783

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
- Fixed documentation for PdfLoader
88
- Fixed a bug where the `format` argument for `OllamaLLM` was not propagated to the client.
9-
9+
- Fixed an import error in `VertexAIEmbeddings`.
1010

1111
## 1.9.0
1212

examples/customize/embeddings/vertexai_embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44

55
from neo4j_graphrag.embeddings import VertexAIEmbeddings
66

7-
embeder = VertexAIEmbeddings(model="text-embedding-004")
7+
embeder = VertexAIEmbeddings(model="text-embedding-005")
88
res = embeder.embed_query("my question")
99
print(res[:10])

src/neo4j_graphrag/embeddings/vertexai.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,18 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from typing import Any
17+
from typing import Any, TYPE_CHECKING
1818

1919
from neo4j_graphrag.embeddings.base import Embedder
2020

2121
try:
22-
import vertexai
23-
except ImportError:
24-
vertexai = None # type: ignore[assignment]
22+
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
23+
except (ImportError, AttributeError):
24+
TextEmbeddingModel = TextEmbeddingInput = None # type: ignore[misc, assignment]
25+
26+
27+
if TYPE_CHECKING:
28+
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
2529

2630

2731
class VertexAIEmbeddings(Embedder):
@@ -34,14 +38,12 @@ class VertexAIEmbeddings(Embedder):
3438
"""
3539

3640
def __init__(self, model: str = "text-embedding-004") -> None:
37-
if vertexai is None:
41+
if TextEmbeddingModel is None:
3842
raise ImportError(
3943
"""Could not import Vertex AI Python client.
4044
Please install it with `pip install "neo4j-graphrag[google]"`."""
4145
)
42-
self.vertexai_model = (
43-
vertexai.language_models.TextEmbeddingModel.from_pretrained(model)
44-
)
46+
self.model = TextEmbeddingModel.from_pretrained(model)
4547

4648
def embed_query(
4749
self, text: str, task_type: str = "RETRIEVAL_QUERY", **kwargs: Any
@@ -54,6 +56,7 @@ def embed_query(
5456
task_type (str): The type of the text embedding task. Defaults to "RETRIEVAL_QUERY". See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#tasktype for a full list.
5557
**kwargs (Any): Additional keyword arguments to pass to the Vertex AI client's get_embeddings method.
5658
"""
57-
inputs = [vertexai.language_models.TextEmbeddingInput(text, task_type)]
58-
embeddings = self.vertexai_model.get_embeddings(inputs, **kwargs)
59-
return embeddings[0].values # type: ignore
59+
# type annotation needed for mypy
60+
inputs: list[str | TextEmbeddingInput] = [TextEmbeddingInput(text, task_type)]
61+
embeddings = self.model.get_embeddings(inputs, **kwargs)
62+
return embeddings[0].values

tests/unit/embeddings/test_vertexai_embedder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
from neo4j_graphrag.embeddings.vertexai import VertexAIEmbeddings
1919

2020

21-
@patch("neo4j_graphrag.embeddings.vertexai.vertexai", None)
21+
@patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel", None)
2222
def test_vertexai_embedder_missing_dependency() -> None:
2323
with pytest.raises(ImportError):
2424
VertexAIEmbeddings()
2525

2626

27-
@patch("neo4j_graphrag.embeddings.vertexai.vertexai")
27+
@patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel")
2828
def test_vertexai_embedder_happy_path(mock_vertexai: Mock) -> None:
29-
mock_vertexai.language_models.TextEmbeddingModel.from_pretrained.return_value.get_embeddings.return_value = [
29+
mock_vertexai.from_pretrained.return_value.get_embeddings.return_value = [
3030
MagicMock(values=[1.0, 2.0])
3131
]
3232
embedder = VertexAIEmbeddings()

0 commit comments

Comments
 (0)