diff --git a/vector-stores/spring-ai-s3-vector-store/src/main/java/org/springframework/ai/vectorstore/s3/S3VectorStore.java b/vector-stores/spring-ai-s3-vector-store/src/main/java/org/springframework/ai/vectorstore/s3/S3VectorStore.java index f475923b1f3..95d20db20ee 100644 --- a/vector-stores/spring-ai-s3-vector-store/src/main/java/org/springframework/ai/vectorstore/s3/S3VectorStore.java +++ b/vector-stores/spring-ai-s3-vector-store/src/main/java/org/springframework/ai/vectorstore/s3/S3VectorStore.java @@ -59,6 +59,8 @@ public class S3VectorStore extends AbstractObservationVectorStore implements Ini private final S3VectorFilterExpressionConverter filterExpressionConverter; + private final String contentMetadataKeyName; + /** * Creates a new S3VectorStore instance with the specified builder settings. * Initializes observation-related components and the embedding model. @@ -70,11 +72,13 @@ protected S3VectorStore(Builder builder) { Assert.notNull(builder.vectorBucketName, "vectorBucketName must not be null"); Assert.notNull(builder.indexName, "indexName must not be null"); Assert.notNull(builder.s3VectorsClient, "S3VectorsClient must not be null"); + Assert.notNull(builder.contentMetadataKeyName, "contentMetadataKeyName must not be null"); this.s3VectorsClient = builder.s3VectorsClient; this.indexName = builder.indexName; this.filterExpressionConverter = builder.filterExpressionConverter; this.vectorBucketName = builder.vectorBucketName; + this.contentMetadataKeyName = builder.contentMetadataKeyName; } @Override @@ -89,13 +93,16 @@ public void doAdd(List documents) { for (Document document : documents) { float[] embs = embedding.get(documents.indexOf(document)); VectorData vectorData = constructVectorData(embs); + Map metadataWithText = new HashMap<>(document.getMetadata()); + metadataWithText.put(this.contentMetadataKeyName, document.getText()); vectors.add(PutInputVector.builder() .data(vectorData) .key(document.getId()) - .metadata(constructMetadata(document.getMetadata())) + .metadata(constructMetadata(metadataWithText)) .build()); } requestBuilder.vectors(vectors); + this.s3VectorsClient.putVectors(requestBuilder.build()); } @@ -163,10 +170,14 @@ private Document toDocument(QueryOutputVector vector) { if (metadata == null) { metadata = new HashMap<>(); } + String text = (String) metadata.remove(this.contentMetadataKeyName); + if (text == null) { + text = ""; + } if (vector.distance() != null) { metadata.put("SPRING_AI_S3_DISTANCE", vector.distance()); } - return Document.builder().metadata(metadata).text(vector.key()).build(); + return Document.builder().id(vector.key()).metadata(metadata).text(text).build(); } private static software.amazon.awssdk.core.document.Document constructMetadata( @@ -207,6 +218,8 @@ public Optional getNativeClient() { public static class Builder extends AbstractVectorStoreBuilder { + private String contentMetadataKeyName = "SPRING_AI_VECTOR_CONTENT_KEY"; + private final S3VectorsClient s3VectorsClient; private @Nullable String vectorBucketName; @@ -227,6 +240,13 @@ public Builder vectorBucketName(String vectorBucketName) { return this; } + public Builder contentMetadataKeyName(String contentMetadataKeyName) { + Assert.notNull(contentMetadataKeyName, "contentMetadataKeyName must not be null"); + this.contentMetadataKeyName = contentMetadataKeyName; + return this; + + } + public Builder indexName(String indexName) { Assert.notNull(indexName, "indexName must not be null"); this.indexName = indexName; diff --git a/vector-stores/spring-ai-s3-vector-store/src/test/java/org/springframework/ai/vectorstore/s3/S3VectorStoreIT.java b/vector-stores/spring-ai-s3-vector-store/src/test/java/org/springframework/ai/vectorstore/s3/S3VectorStoreIT.java new file mode 100644 index 00000000000..c89ec566449 --- /dev/null +++ b/vector-stores/spring-ai-s3-vector-store/src/test/java/org/springframework/ai/vectorstore/s3/S3VectorStoreIT.java @@ -0,0 +1,197 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.s3; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.services.s3vectors.S3VectorsClient; +import software.amazon.awssdk.services.s3vectors.model.PutInputVector; +import software.amazon.awssdk.services.s3vectors.model.PutVectorsRequest; +import software.amazon.awssdk.services.s3vectors.model.QueryOutputVector; +import software.amazon.awssdk.services.s3vectors.model.QueryVectorsRequest; +import software.amazon.awssdk.services.s3vectors.model.QueryVectorsResponse; +import software.amazon.awssdk.services.s3vectors.model.VectorData; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.SearchRequest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Integration tests for S3VectorStore. + * + * @author Matej Nedic + */ +class S3VectorStoreIT { + + @Test + void testAddDocumentStoresTextInMetadata() { + S3VectorsClient mockClient = mock(S3VectorsClient.class); + EmbeddingModel mockEmbedding = mock(EmbeddingModel.class); + + when(mockEmbedding.embed(any(List.class), any(), any())).thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f })); + when(mockEmbedding.dimensions()).thenReturn(3); + + S3VectorStore vectorStore = new S3VectorStore.Builder(mockClient, mockEmbedding).vectorBucketName("test-bucket") + .indexName("test-index") + .build(); + + Document doc = new Document("test-id", "Test document content", Map.of("key1", "value1")); + + vectorStore.add(List.of(doc)); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(PutVectorsRequest.class); + verify(mockClient).putVectors(requestCaptor.capture()); + + PutVectorsRequest request = requestCaptor.getValue(); + assertThat(request.vectors()).hasSize(1); + + PutInputVector vector = request.vectors().get(0); + assertThat(vector.key()).isEqualTo("test-id"); + + software.amazon.awssdk.core.document.Document metadata = vector.metadata(); + assertThat(metadata.asMap()).containsEntry("SPRING_AI_VECTOR_CONTENT_KEY", + software.amazon.awssdk.core.document.Document.fromString("Test document content")); + assertThat(metadata.asMap()).containsEntry("key1", + software.amazon.awssdk.core.document.Document.fromString("value1")); + } + + @Test + void testSearchReturnsDocumentText() { + S3VectorsClient mockClient = mock(S3VectorsClient.class); + EmbeddingModel mockEmbedding = mock(EmbeddingModel.class); + + when(mockEmbedding.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f }); + when(mockEmbedding.dimensions()).thenReturn(3); + + software.amazon.awssdk.core.document.Document metadata = software.amazon.awssdk.core.document.Document + .fromMap(Map.of("SPRING_AI_VECTOR_CONTENT_KEY", + software.amazon.awssdk.core.document.Document.fromString("Retrieved content"), "key1", + software.amazon.awssdk.core.document.Document.fromString("value1"))); + + QueryOutputVector outputVector = QueryOutputVector.builder() + .key("doc-id") + .metadata(metadata) + .distance(0.95f) + .data(VectorData.builder().float32(List.of(0.1f, 0.2f, 0.3f)).build()) + .build(); + + QueryVectorsResponse response = QueryVectorsResponse.builder().vectors(List.of(outputVector)).build(); + + when(mockClient.queryVectors(any(QueryVectorsRequest.class))).thenReturn(response); + + S3VectorStore vectorStore = new S3VectorStore.Builder(mockClient, mockEmbedding).vectorBucketName("test-bucket") + .indexName("test-index") + .build(); + + List results = vectorStore + .similaritySearch(SearchRequest.builder().query("test query").topK(1).build()); + + assertThat(results).hasSize(1); + Document result = results.get(0); + assertThat(result.getId()).isEqualTo("doc-id"); + assertThat(result.getText()).isEqualTo("Retrieved content"); + assertThat(result.getMetadata()).containsEntry("key1", "value1"); + assertThat(result.getMetadata()).containsEntry("SPRING_AI_S3_DISTANCE", 0.95f); + assertThat(result.getMetadata()).doesNotContainKey("SPRING_AI_VECTOR_CONTENT_KEY"); + } + + @Test + void testSearchWithNullContentReturnsEmptyString() { + S3VectorsClient mockClient = mock(S3VectorsClient.class); + EmbeddingModel mockEmbedding = mock(EmbeddingModel.class); + + when(mockEmbedding.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f }); + when(mockEmbedding.dimensions()).thenReturn(3); + + software.amazon.awssdk.core.document.Document metadata = software.amazon.awssdk.core.document.Document + .fromMap(Map.of("key1", software.amazon.awssdk.core.document.Document.fromString("value1"))); + + QueryOutputVector outputVector = QueryOutputVector.builder() + .key("doc-id") + .metadata(metadata) + .data(VectorData.builder().float32(List.of(0.1f, 0.2f, 0.3f)).build()) + .build(); + + QueryVectorsResponse response = QueryVectorsResponse.builder().vectors(List.of(outputVector)).build(); + + when(mockClient.queryVectors(any(QueryVectorsRequest.class))).thenReturn(response); + + S3VectorStore vectorStore = new S3VectorStore.Builder(mockClient, mockEmbedding).vectorBucketName("test-bucket") + .indexName("test-index") + .build(); + + List results = vectorStore + .similaritySearch(SearchRequest.builder().query("test query").topK(1).build()); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getText()).isEqualTo(""); + } + + @Test + void testCustomContentMetadataKeyName() { + S3VectorsClient mockClient = mock(S3VectorsClient.class); + EmbeddingModel mockEmbedding = mock(EmbeddingModel.class); + + when(mockEmbedding.embed(any(List.class), any(), any())).thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f })); + when(mockEmbedding.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f }); + when(mockEmbedding.dimensions()).thenReturn(3); + + S3VectorStore vectorStore = new S3VectorStore.Builder(mockClient, mockEmbedding).vectorBucketName("test-bucket") + .indexName("test-index") + .contentMetadataKeyName("content") + .build(); + + Document doc = new Document("test-id", "Custom key text", Map.of("key1", "value1")); + vectorStore.add(List.of(doc)); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(PutVectorsRequest.class); + verify(mockClient).putVectors(requestCaptor.capture()); + + software.amazon.awssdk.core.document.Document metadata = requestCaptor.getValue().vectors().get(0).metadata(); + assertThat(metadata.asMap()).containsEntry("content", + software.amazon.awssdk.core.document.Document.fromString("Custom key text")); + + software.amazon.awssdk.core.document.Document responseMetadata = software.amazon.awssdk.core.document.Document + .fromMap(Map.of("content", software.amazon.awssdk.core.document.Document.fromString("Retrieved text"), + "key1", software.amazon.awssdk.core.document.Document.fromString("value1"))); + + QueryOutputVector outputVector = QueryOutputVector.builder() + .key("doc-id") + .metadata(responseMetadata) + .data(VectorData.builder().float32(List.of(0.1f, 0.2f, 0.3f)).build()) + .build(); + + when(mockClient.queryVectors(any(QueryVectorsRequest.class))) + .thenReturn(QueryVectorsResponse.builder().vectors(List.of(outputVector)).build()); + + List results = vectorStore + .similaritySearch(SearchRequest.builder().query("test query").topK(1).build()); + + assertThat(results.get(0).getText()).isEqualTo("Retrieved text"); + assertThat(results.get(0).getMetadata()).doesNotContainKey("content"); + } + +}