Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.ElasticsearchVectorStoreOptions;
import org.springframework.ai.elasticsearch.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.elasticsearch.vectorstore.ElasticsearchVectorStoreOptions;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
Expand Down Expand Up @@ -73,9 +73,15 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti
elasticsearchVectorStoreOptions.setSimilarity(properties.getSimilarity());
}

return new ElasticsearchVectorStore(elasticsearchVectorStoreOptions, restClient, embeddingModel,
properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
return ElasticsearchVectorStore.builder()
.restClient(restClient)
.options(elasticsearchVectorStoreOptions)
.embeddingModel(embeddingModel)
.initializeSchema(properties.isInitializeSchema())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
.batchingStrategy(batchingStrategy)
.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.springframework.ai.autoconfigure.vectorstore.elasticsearch;

import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties;
import org.springframework.ai.vectorstore.SimilarityFunction;
import org.springframework.ai.elasticsearch.vectorstore.SimilarityFunction;
import org.springframework.boot.context.properties.ConfigurationProperties;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.document.Document;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.elasticsearch.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SimilarityFunction;
import org.springframework.ai.elasticsearch.vectorstore.SimilarityFunction;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.elasticsearch.ElasticsearchRestClientAutoConfiguration;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import java.text.ParseException;
import java.text.SimpleDateFormat;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -48,11 +48,12 @@
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext.Builder;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -83,8 +84,6 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp
SimilarityFunction.cosine, VectorStoreSimilarityMetric.COSINE, SimilarityFunction.l2_norm,
VectorStoreSimilarityMetric.EUCLIDEAN, SimilarityFunction.dot_product, VectorStoreSimilarityMetric.DOT);

private final EmbeddingModel embeddingModel;

private final ElasticsearchClient elasticsearchClient;

private final ElasticsearchVectorStoreOptions options;
Expand All @@ -95,34 +94,47 @@ public class ElasticsearchVectorStore extends AbstractObservationVectorStore imp

private final BatchingStrategy batchingStrategy;

@Deprecated(since = "1.0.0-M5", forRemoval = true)
public ElasticsearchVectorStore(RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) {
this(new ElasticsearchVectorStoreOptions(), restClient, embeddingModel, initializeSchema);
}

@Deprecated(since = "1.0.0-M5", forRemoval = true)
public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestClient restClient,
EmbeddingModel embeddingModel, boolean initializeSchema) {
this(options, restClient, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null,
new TokenCountBatchingStrategy());
}

@Deprecated(since = "1.0.0-M5", forRemoval = true)
public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestClient restClient,
EmbeddingModel embeddingModel, boolean initializeSchema, ObservationRegistry observationRegistry,
VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) {

super(observationRegistry, customObservationConvention);
this(builder().restClient(restClient)
.options(options)
.embeddingModel(embeddingModel)
.initializeSchema(initializeSchema)
.observationRegistry(observationRegistry)
.customObservationConvention(customObservationConvention)
.batchingStrategy(batchingStrategy));
}

protected ElasticsearchVectorStore(ElasticsearchBuilder builder) {
super(builder);

Assert.notNull(builder.restClient, "RestClient must not be null");

this.initializeSchema = builder.initializeSchema;
this.options = builder.options;
this.filterExpressionConverter = builder.filterExpressionConverter;
this.batchingStrategy = builder.batchingStrategy;

this.initializeSchema = initializeSchema;
Objects.requireNonNull(embeddingModel, "RestClient must not be null");
Objects.requireNonNull(embeddingModel, "EmbeddingModel must not be null");
String version = Version.VERSION == null ? "Unknown" : Version.VERSION.toString();
this.elasticsearchClient = new ElasticsearchClient(new RestClientTransport(restClient,
this.elasticsearchClient = new ElasticsearchClient(new RestClientTransport(builder.restClient,
new JacksonJsonpMapper(
new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false))))
.withTransportOptions(t -> t.addHeader("user-agent", "spring-ai elastic-java/" + version));
this.embeddingModel = embeddingModel;
this.options = options;
this.filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter();
this.batchingStrategy = batchingStrategy;
}

@Override
Expand Down Expand Up @@ -297,4 +309,93 @@ private String getSimilarityMetric() {
public record ElasticSearchDocument(String id, String content, Map<String, Object> metadata, float[] embedding) {
}

/**
* Creates a new builder instance for ElasticsearchVectorStore.
* @return a new ElasticsearchBuilder instance
*/
public static ElasticsearchBuilder builder() {
return new ElasticsearchBuilder();
}

public static class ElasticsearchBuilder extends AbstractVectorStoreBuilder<ElasticsearchBuilder> {

private RestClient restClient;

private ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();

private boolean initializeSchema = false;

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

private FilterExpressionConverter filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter();

/**
* @param restClient the Elasticsearch REST client
* @throws IllegalArgumentException if restClient is null
*/
public ElasticsearchBuilder restClient(RestClient restClient) {
Assert.notNull(restClient, "RestClient must not be null");
this.restClient = restClient;
return this;
}

/**
* Sets the Elasticsearch vector store options.
* @param options the vector store options to use
* @return the builder instance
* @throws IllegalArgumentException if options is null
*/
public ElasticsearchBuilder options(ElasticsearchVectorStoreOptions options) {
Assert.notNull(options, "options must not be null");
this.options = options;
return this;
}

/**
* Sets whether to initialize the schema.
* @param initializeSchema true to initialize schema, false otherwise
* @return the builder instance
*/
public ElasticsearchBuilder initializeSchema(boolean initializeSchema) {
this.initializeSchema = initializeSchema;
return this;
}

/**
* Sets the batching strategy for vector operations.
* @param batchingStrategy the batching strategy to use
* @return the builder instance
* @throws IllegalArgumentException if batchingStrategy is null
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do we throw IllegalArgumentException? Same question for other places we specify that the exception is thrown for invalid values.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the method implementation, the Assert.notNull(..) call would throw the IAE.

*/
public ElasticsearchBuilder batchingStrategy(BatchingStrategy batchingStrategy) {
Assert.notNull(batchingStrategy, "batchingStrategy must not be null");
this.batchingStrategy = batchingStrategy;
return this;
}

/**
* Sets the filter expression converter.
* @param converter the filter expression converter to use
* @return the builder instance
* @throws IllegalArgumentException if converter is null
*/
public ElasticsearchBuilder filterExpressionConverter(FilterExpressionConverter converter) {
Assert.notNull(converter, "filterExpressionConverter must not be null");
this.filterExpressionConverter = converter;
return this;
}

/**
* Builds the ElasticsearchVectorStore instance.
* @return a new ElasticsearchVectorStore instance
* @throws IllegalStateException if the builder is in an invalid state
*/
@Override
public ElasticsearchVectorStore build() {
validate();
return new ElasticsearchVectorStore(this);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

/**
* Provided Elasticsearch vector option configuration.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

/**
* https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import java.util.Date;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import org.testcontainers.utility.DockerImageName;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -51,6 +51,7 @@
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
Expand Down Expand Up @@ -376,23 +377,37 @@ public static class TestApplication {

@Bean("vectorStore_cosine")
public ElasticsearchVectorStore vectorStoreDefault(EmbeddingModel embeddingModel, RestClient restClient) {
return new ElasticsearchVectorStore(restClient, embeddingModel, true);
return ElasticsearchVectorStore.builder()
.restClient(restClient)
.embeddingModel(embeddingModel)
.initializeSchema(true)
.build();
}

@Bean("vectorStore_l2_norm")
public ElasticsearchVectorStore vectorStoreL2(EmbeddingModel embeddingModel, RestClient restClient) {
ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();
options.setIndexName("index_l2");
options.setSimilarity(SimilarityFunction.l2_norm);
return new ElasticsearchVectorStore(options, restClient, embeddingModel, true);
return ElasticsearchVectorStore.builder()
.restClient(restClient)
.embeddingModel(embeddingModel)
.initializeSchema(true)
.options(options)
.build();
}

@Bean("vectorStore_dot_product")
public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingModel embeddingModel, RestClient restClient) {
ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();
options.setIndexName("index_dot_product");
options.setSimilarity(SimilarityFunction.dot_product);
return new ElasticsearchVectorStore(options, restClient, embeddingModel, true);
return ElasticsearchVectorStore.builder()
.restClient(restClient)
.embeddingModel(embeddingModel)
.initializeSchema(true)
.options(options)
.build();
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.elasticsearch.vectorstore;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -51,6 +51,8 @@
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.LowCardinalityKeyNames;
Expand All @@ -67,6 +69,7 @@
/**
* @author Christian Tzolov
* @author Thomas Vitale
* @author Soby Chacko
*/
@Testcontainers
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
Expand Down Expand Up @@ -205,8 +208,15 @@ public TestObservationRegistry observationRegistry() {
@Bean
public ElasticsearchVectorStore vectorStoreDefault(EmbeddingModel embeddingModel, RestClient restClient,
ObservationRegistry observationRegistry) {
return new ElasticsearchVectorStore(new ElasticsearchVectorStoreOptions(), restClient, embeddingModel, true,
observationRegistry, null, new TokenCountBatchingStrategy());
return ElasticsearchVectorStore.builder()
.restClient(restClient)
.embeddingModel(embeddingModel)
.initializeSchema(true)
.options(new ElasticsearchVectorStoreOptions())
.observationRegistry(observationRegistry)
.customObservationConvention(null)
.batchingStrategy(new TokenCountBatchingStrategy())
.build();
}

@Bean
Expand Down
Loading