Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public class S3VectorStore extends AbstractObservationVectorStore implements Ini

private final S3VectorFilterExpressionConverter filterExpressionConverter;

private final String contentMetadataKeyName;

/**
* Creates a new S3VectorStore instance with the specified builder settings.
* Initializes observation-related components and the embedding model.
Expand All @@ -70,11 +72,13 @@ protected S3VectorStore(Builder builder) {
Assert.notNull(builder.vectorBucketName, "vectorBucketName must not be null");
Assert.notNull(builder.indexName, "indexName must not be null");
Assert.notNull(builder.s3VectorsClient, "S3VectorsClient must not be null");
Assert.notNull(builder.contentMetadataKeyName, "contentMetadataKeyName must not be null");

this.s3VectorsClient = builder.s3VectorsClient;
this.indexName = builder.indexName;
this.filterExpressionConverter = builder.filterExpressionConverter;
this.vectorBucketName = builder.vectorBucketName;
this.contentMetadataKeyName = builder.contentMetadataKeyName;
}

@Override
Expand All @@ -89,13 +93,16 @@ public void doAdd(List<Document> documents) {
for (Document document : documents) {
float[] embs = embedding.get(documents.indexOf(document));
VectorData vectorData = constructVectorData(embs);
Map<String, Object> metadataWithText = new HashMap<>(document.getMetadata());
metadataWithText.put(this.contentMetadataKeyName, document.getText());
vectors.add(PutInputVector.builder()
.data(vectorData)
.key(document.getId())
.metadata(constructMetadata(document.getMetadata()))
.metadata(constructMetadata(metadataWithText))
.build());
}
requestBuilder.vectors(vectors);

this.s3VectorsClient.putVectors(requestBuilder.build());
}

Expand Down Expand Up @@ -163,10 +170,14 @@ private Document toDocument(QueryOutputVector vector) {
if (metadata == null) {
metadata = new HashMap<>();
}
String text = (String) metadata.remove(this.contentMetadataKeyName);
if (text == null) {
text = "";
}
if (vector.distance() != null) {
metadata.put("SPRING_AI_S3_DISTANCE", vector.distance());
}
return Document.builder().metadata(metadata).text(vector.key()).build();
return Document.builder().id(vector.key()).metadata(metadata).text(text).build();
}

private static software.amazon.awssdk.core.document.Document constructMetadata(
Expand Down Expand Up @@ -207,6 +218,8 @@ public <T> Optional<T> getNativeClient() {

public static class Builder extends AbstractVectorStoreBuilder<Builder> {

private String contentMetadataKeyName = "SPRING_AI_VECTOR_CONTENT_KEY";

private final S3VectorsClient s3VectorsClient;

private @Nullable String vectorBucketName;
Expand All @@ -227,6 +240,13 @@ public Builder vectorBucketName(String vectorBucketName) {
return this;
}

public Builder contentMetadataKeyName(String contentMetadataKeyName) {
Assert.notNull(contentMetadataKeyName, "contentMetadataKeyName must not be null");
this.contentMetadataKeyName = contentMetadataKeyName;
return this;

}

public Builder indexName(String indexName) {
Assert.notNull(indexName, "indexName must not be null");
this.indexName = indexName;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.vectorstore.s3;

import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import software.amazon.awssdk.services.s3vectors.S3VectorsClient;
import software.amazon.awssdk.services.s3vectors.model.PutInputVector;
import software.amazon.awssdk.services.s3vectors.model.PutVectorsRequest;
import software.amazon.awssdk.services.s3vectors.model.QueryOutputVector;
import software.amazon.awssdk.services.s3vectors.model.QueryVectorsRequest;
import software.amazon.awssdk.services.s3vectors.model.QueryVectorsResponse;
import software.amazon.awssdk.services.s3vectors.model.VectorData;

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.SearchRequest;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
* Integration tests for S3VectorStore.
*
* @author Matej Nedic
*/
class S3VectorStoreIT {

@Test
void testAddDocumentStoresTextInMetadata() {
S3VectorsClient mockClient = mock(S3VectorsClient.class);
EmbeddingModel mockEmbedding = mock(EmbeddingModel.class);

when(mockEmbedding.embed(any(List.class), any(), any())).thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }));
when(mockEmbedding.dimensions()).thenReturn(3);

S3VectorStore vectorStore = new S3VectorStore.Builder(mockClient, mockEmbedding).vectorBucketName("test-bucket")
.indexName("test-index")
.build();

Document doc = new Document("test-id", "Test document content", Map.of("key1", "value1"));

vectorStore.add(List.of(doc));

ArgumentCaptor<PutVectorsRequest> requestCaptor = ArgumentCaptor.forClass(PutVectorsRequest.class);
verify(mockClient).putVectors(requestCaptor.capture());

PutVectorsRequest request = requestCaptor.getValue();
assertThat(request.vectors()).hasSize(1);

PutInputVector vector = request.vectors().get(0);
assertThat(vector.key()).isEqualTo("test-id");

software.amazon.awssdk.core.document.Document metadata = vector.metadata();
assertThat(metadata.asMap()).containsEntry("SPRING_AI_VECTOR_CONTENT_KEY",
software.amazon.awssdk.core.document.Document.fromString("Test document content"));
assertThat(metadata.asMap()).containsEntry("key1",
software.amazon.awssdk.core.document.Document.fromString("value1"));
}

@Test
void testSearchReturnsDocumentText() {
S3VectorsClient mockClient = mock(S3VectorsClient.class);
EmbeddingModel mockEmbedding = mock(EmbeddingModel.class);

when(mockEmbedding.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(mockEmbedding.dimensions()).thenReturn(3);

software.amazon.awssdk.core.document.Document metadata = software.amazon.awssdk.core.document.Document
.fromMap(Map.of("SPRING_AI_VECTOR_CONTENT_KEY",
software.amazon.awssdk.core.document.Document.fromString("Retrieved content"), "key1",
software.amazon.awssdk.core.document.Document.fromString("value1")));

QueryOutputVector outputVector = QueryOutputVector.builder()
.key("doc-id")
.metadata(metadata)
.distance(0.95f)
.data(VectorData.builder().float32(List.of(0.1f, 0.2f, 0.3f)).build())
.build();

QueryVectorsResponse response = QueryVectorsResponse.builder().vectors(List.of(outputVector)).build();

when(mockClient.queryVectors(any(QueryVectorsRequest.class))).thenReturn(response);

S3VectorStore vectorStore = new S3VectorStore.Builder(mockClient, mockEmbedding).vectorBucketName("test-bucket")
.indexName("test-index")
.build();

List<Document> results = vectorStore
.similaritySearch(SearchRequest.builder().query("test query").topK(1).build());

assertThat(results).hasSize(1);
Document result = results.get(0);
assertThat(result.getId()).isEqualTo("doc-id");
assertThat(result.getText()).isEqualTo("Retrieved content");
assertThat(result.getMetadata()).containsEntry("key1", "value1");
assertThat(result.getMetadata()).containsEntry("SPRING_AI_S3_DISTANCE", 0.95f);
assertThat(result.getMetadata()).doesNotContainKey("SPRING_AI_VECTOR_CONTENT_KEY");
}

@Test
void testSearchWithNullContentReturnsEmptyString() {
S3VectorsClient mockClient = mock(S3VectorsClient.class);
EmbeddingModel mockEmbedding = mock(EmbeddingModel.class);

when(mockEmbedding.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(mockEmbedding.dimensions()).thenReturn(3);

software.amazon.awssdk.core.document.Document metadata = software.amazon.awssdk.core.document.Document
.fromMap(Map.of("key1", software.amazon.awssdk.core.document.Document.fromString("value1")));

QueryOutputVector outputVector = QueryOutputVector.builder()
.key("doc-id")
.metadata(metadata)
.data(VectorData.builder().float32(List.of(0.1f, 0.2f, 0.3f)).build())
.build();

QueryVectorsResponse response = QueryVectorsResponse.builder().vectors(List.of(outputVector)).build();

when(mockClient.queryVectors(any(QueryVectorsRequest.class))).thenReturn(response);

S3VectorStore vectorStore = new S3VectorStore.Builder(mockClient, mockEmbedding).vectorBucketName("test-bucket")
.indexName("test-index")
.build();

List<Document> results = vectorStore
.similaritySearch(SearchRequest.builder().query("test query").topK(1).build());

assertThat(results).hasSize(1);
assertThat(results.get(0).getText()).isEqualTo("");
}

@Test
void testCustomContentMetadataKeyName() {
S3VectorsClient mockClient = mock(S3VectorsClient.class);
EmbeddingModel mockEmbedding = mock(EmbeddingModel.class);

when(mockEmbedding.embed(any(List.class), any(), any())).thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }));
when(mockEmbedding.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(mockEmbedding.dimensions()).thenReturn(3);

S3VectorStore vectorStore = new S3VectorStore.Builder(mockClient, mockEmbedding).vectorBucketName("test-bucket")
.indexName("test-index")
.contentMetadataKeyName("content")
.build();

Document doc = new Document("test-id", "Custom key text", Map.of("key1", "value1"));
vectorStore.add(List.of(doc));

ArgumentCaptor<PutVectorsRequest> requestCaptor = ArgumentCaptor.forClass(PutVectorsRequest.class);
verify(mockClient).putVectors(requestCaptor.capture());

software.amazon.awssdk.core.document.Document metadata = requestCaptor.getValue().vectors().get(0).metadata();
assertThat(metadata.asMap()).containsEntry("content",
software.amazon.awssdk.core.document.Document.fromString("Custom key text"));

software.amazon.awssdk.core.document.Document responseMetadata = software.amazon.awssdk.core.document.Document
.fromMap(Map.of("content", software.amazon.awssdk.core.document.Document.fromString("Retrieved text"),
"key1", software.amazon.awssdk.core.document.Document.fromString("value1")));

QueryOutputVector outputVector = QueryOutputVector.builder()
.key("doc-id")
.metadata(responseMetadata)
.data(VectorData.builder().float32(List.of(0.1f, 0.2f, 0.3f)).build())
.build();

when(mockClient.queryVectors(any(QueryVectorsRequest.class)))
.thenReturn(QueryVectorsResponse.builder().vectors(List.of(outputVector)).build());

List<Document> results = vectorStore
.similaritySearch(SearchRequest.builder().query("test query").topK(1).build());

assertThat(results.get(0).getText()).isEqualTo("Retrieved text");
assertThat(results.get(0).getMetadata()).doesNotContainKey("content");
}

}