diff --git a/CHANGELOG.md b/CHANGELOG.md
index d1cfd9e4b..e199da610 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,11 @@
 
 ## Next
 
+### Added
+- Added support for Amazon Bedrock embeddings via the `BedrockEmbeddings` class.
+- Users can now use Bedrock-hosted embedding models for vector generation.
+- Added unit tests for `BedrockEmbeddings`.
+
 ## 1.9.0
 
 ### Fixed
diff --git a/src/neo4j_graphrag/embeddings/__init__.py b/src/neo4j_graphrag/embeddings/__init__.py
index 9398eefe7..91f475a7e 100644
--- a/src/neo4j_graphrag/embeddings/__init__.py
+++ b/src/neo4j_graphrag/embeddings/__init__.py
@@ -19,6 +19,7 @@
 from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from .sentence_transformers import SentenceTransformerEmbeddings
 from .vertexai import VertexAIEmbeddings
+from .bedrockembeddings import BedrockEmbeddings
 
 __all__ = [
     "Embedder",
@@ -29,4 +30,5 @@
     "VertexAIEmbeddings",
     "MistralAIEmbeddings",
     "CohereEmbeddings",
+    "BedrockEmbeddings",
 ]
diff --git a/src/neo4j_graphrag/embeddings/bedrockembeddings.py b/src/neo4j_graphrag/embeddings/bedrockembeddings.py
new file mode 100644
index 000000000..a91902640
--- /dev/null
+++ b/src/neo4j_graphrag/embeddings/bedrockembeddings.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+import json
+import time
+from typing import List
+
+import boto3
+
+from neo4j_graphrag.embeddings.base import Embedder
+from neo4j_graphrag.exceptions import EmbeddingsGenerationError
+
+
+class BedrockEmbeddings(Embedder):
+    """
+    Embedder implementation using Amazon Bedrock's Titan Text Embeddings model.
+
+    This class integrates with Amazon Bedrock via `boto3` and uses the Titan Text
+    Embeddings model (`amazon.titan-embed-text-v2:0`) to generate vector
+    representations for input text.
+
+    Example:
+        >>> embedder = BedrockEmbeddings()
+        >>> embedding = embedder.embed_query("Neo4j integrates well with Bedrock.")
+        >>> len(embedding)
+        1024
+
+    Notes:
+        - Titan Text Embeddings V2 returns 1024-dimensional vectors by default.
+        - A short sleep is applied after each request to reduce the risk of throttling.
+        - This class uses the default AWS credentials chain supported by `boto3`.
+
+    AWS Authentication:
+        The following authentication methods are supported through boto3:
+
+        - Environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN` (if needed)
+        - AWS credentials/config files (e.g., `~/.aws/credentials`)
+        - IAM roles (if running on EC2, Lambda, SageMaker, etc.)
+        - AWS CLI named profile via the `AWS_PROFILE` environment variable
+    """
+
+    def __init__(
+        self,
+        model_id: str = 'amazon.titan-embed-text-v2:0',
+        region: str = 'us-east-1',
+    ):
+        """
+        Initialize the BedrockEmbeddings instance.
+
+        Args:
+            model_id (str): Identifier for the Bedrock Titan embedding model.
+                Defaults to 'amazon.titan-embed-text-v2:0'.
+            region (str): AWS region where the Bedrock service is hosted.
+                Defaults to 'us-east-1'.
+        """
+        self.model_id = model_id
+        self.bedrock = boto3.client('bedrock-runtime', region_name=region)
+
+    def embed_query(self, text: str) -> List[float]:
+        """
+        Generate a vector embedding for the input text using Amazon Bedrock.
+
+        Args:
+            text (str): The input text string to be embedded.
+
+        Returns:
+            List[float]: The embedding vector (1024-dimensional by default for Titan V2).
+
+        Raises:
+            EmbeddingsGenerationError: If an error occurs during the embedding process.
+        """
+        try:
+            response = self.bedrock.invoke_model(
+                modelId=self.model_id,
+                contentType='application/json',
+                accept='application/json',
+                body=json.dumps({"inputText": text}),
+            )
+            body = json.loads(response['body'].read())
+            time.sleep(0.05)  # Short pause to reduce the risk of throttling
+            return body['embedding']
+        except Exception as e:
+            raise EmbeddingsGenerationError(f"Issue Generating Embeddings: {e}") from e
diff --git a/tests/unit/embeddings/test_bedrockembedding.py b/tests/unit/embeddings/test_bedrockembedding.py
new file mode 100644
index 000000000..aab9cc430
--- /dev/null
+++ b/tests/unit/embeddings/test_bedrockembedding.py
@@ -0,0 +1,48 @@
+from unittest.mock import patch, MagicMock
+import pytest
+import json
+
+from neo4j_graphrag.embeddings.bedrockembeddings import BedrockEmbeddings
+from neo4j_graphrag.exceptions import EmbeddingsGenerationError
+
+
+@patch("neo4j_graphrag.embeddings.bedrockembeddings.boto3.client")
+def test_bedrock_embedder_happy_path(mock_boto_client):
+    # Mock AWS response with a valid embedding
+    fake_embedding = [0.1] * 1024
+    fake_response = {
+        "embedding": fake_embedding
+    }
+
+    # Mock .read() to return the fake response as JSON bytes
+    mock_body = MagicMock()
+    mock_body.read.return_value = json.dumps(fake_response).encode("utf-8")
+
+    # Mock the Bedrock runtime client
+    mock_bedrock_client = MagicMock()
+    mock_bedrock_client.invoke_model.return_value = {"body": mock_body}
+    mock_boto_client.return_value = mock_bedrock_client
+
+    # Instantiate the embedder and run embed_query
+    embedder = BedrockEmbeddings()
+    result = embedder.embed_query("Hello, Bedrock!")
+
+    # Assertions
+    assert isinstance(result, list)
+    assert len(result) == 1024
+    assert result == fake_embedding
+
+
+@patch("neo4j_graphrag.embeddings.bedrockembeddings.boto3.client")
+def test_bedrock_embedder_error_path(mock_boto_client):
+    # Simulate the AWS client raising an exception
+    mock_bedrock_client = MagicMock()
+    mock_bedrock_client.invoke_model.side_effect = Exception("AWS error")
+    mock_boto_client.return_value = mock_bedrock_client
+
+    embedder = BedrockEmbeddings()
+
+    with pytest.raises(EmbeddingsGenerationError) as exc_info:
+        embedder.embed_query("This will fail.")
+
+    assert "Issue Generating Embeddings" in str(exc_info.value)