Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.reader.markdown.MarkdownDocumentReader;
import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.pg.vectorstore.PgVectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.pg.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.pg.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
Expand Down Expand Up @@ -62,18 +62,21 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed

var initializeSchema = properties.isInitializeSchema();

return new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withSchemaName(properties.getSchemaName())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General thought: given that every vector store requires a client (a JdbcTemplate in this case) and an embedding model, can we make both of these default arguments to the vector store builder?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The design we are using is to have a zero-arg builder, consistently throughout the project. The reason: imagine a class that needs 5 required strings; if the builder takes 5 args, we haven't achieved any simplification by adding a fluent API.

.withVectorTableName(properties.getTableName())
.withVectorTableValidationsEnabled(properties.isSchemaValidation())
.withDimensions(properties.getDimensions())
.withDistanceType(properties.getDistanceType())
.withRemoveExistingVectorStoreTable(properties.isRemoveExistingVectorStoreTable())
.withIndexType(properties.getIndexType())
.withInitializeSchema(initializeSchema)
.withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null))
.withBatchingStrategy(batchingStrategy)
.withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize())
return PgVectorStore.builder()
.jdbcTemplate(jdbcTemplate)
.embeddingModel(embeddingModel)
.schemaName(properties.getSchemaName())
.vectorTableName(properties.getTableName())
.vectorTableValidationsEnabled(properties.isSchemaValidation())
.dimensions(properties.getDimensions())
.distanceType(properties.getDistanceType())
.removeExistingVectorStoreTable(properties.isRemoveExistingVectorStoreTable())
.indexType(properties.getIndexType())
.initializeSchema(initializeSchema)
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
.batchingStrategy(batchingStrategy)
.maxDocumentBatchSize(properties.getMaxDocumentBatchSize())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
package org.springframework.ai.autoconfigure.vectorstore.pgvector;

import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.PgVectorStore.PgDistanceType;
import org.springframework.ai.vectorstore.PgVectorStore.PgIndexType;
import org.springframework.ai.pg.vectorstore.PgVectorStore;
import org.springframework.ai.pg.vectorstore.PgVectorStore.PgDistanceType;
import org.springframework.ai.pg.vectorstore.PgVectorStore.PgIndexType;
import org.springframework.boot.context.properties.ConfigurationProperties;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.pg.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.boot.autoconfigure.AutoConfigurations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import org.junit.jupiter.api.Test;

import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.PgVectorStore.PgDistanceType;
import org.springframework.ai.vectorstore.PgVectorStore.PgIndexType;
import org.springframework.ai.pg.vectorstore.PgVectorStore;
import org.springframework.ai.pg.vectorstore.PgVectorStore.PgDistanceType;
import org.springframework.ai.pg.vectorstore.PgVectorStore.PgIndexType;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.pg.vectorstore;

import java.util.List;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.pg.vectorstore;

import java.util.ArrayList;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.springframework.ai.vectorstore;
package org.springframework.ai.pg.vectorstore;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
Expand Down Expand Up @@ -43,6 +43,9 @@
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.util.JacksonUtils;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
Expand All @@ -54,6 +57,7 @@
import org.springframework.jdbc.core.SqlTypeValue;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
Expand Down Expand Up @@ -99,8 +103,6 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini

private final JdbcTemplate jdbcTemplate;

private final EmbeddingModel embeddingModel;

private final String schemaName;

private final boolean schemaValidation;
Expand All @@ -123,15 +125,18 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini

private final int maxDocumentBatchSize;

/**
 * Creates a store by delegating to the seven-argument constructor with
 * {@code INVALID_EMBEDDING_DIMENSION} as the dimension count,
 * {@code PgDistanceType.COSINE_DISTANCE} as the distance type,
 * {@code PgIndexType.NONE} as the index method, and both the
 * remove-existing-table and initialize-schema flags set to {@code false}.
 * @param jdbcTemplate the JDBC template forwarded to the delegate constructor
 * @param embeddingModel the embedding model forwarded to the delegate constructor
 * @deprecated since 1.0.0-M5 and marked for removal; use {@link #builder()} instead
 */
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false,
PgIndexType.NONE, false);
}

/**
 * Creates a store with an explicit dimension count by delegating to the
 * seven-argument constructor with {@code PgDistanceType.COSINE_DISTANCE},
 * {@code PgIndexType.NONE}, and both the remove-existing-table and
 * initialize-schema flags set to {@code false}.
 * @param jdbcTemplate the JDBC template forwarded to the delegate constructor
 * @param embeddingModel the embedding model forwarded to the delegate constructor
 * @param dimensions the dimension count forwarded to the delegate constructor
 * @deprecated since 1.0.0-M5 and marked for removal; use {@link #builder()} instead
 */
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) {
this(jdbcTemplate, embeddingModel, dimensions, PgDistanceType.COSINE_DISTANCE, false, PgIndexType.NONE, false);
}

@Deprecated(forRemoval = true, since = "1.0.0-M5")
public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions,
PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod,
boolean initializeSchema) {
Expand All @@ -140,60 +145,62 @@ public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, i
createIndexMethod, initializeSchema);
}

/**
 * Creates a store for the given table name using {@code DEFAULT_SCHEMA_NAME} and
 * {@code DEFAULT_SCHEMA_VALIDATION}, by populating a {@link PgVectorStoreBuilder}
 * and delegating to the builder-based constructor.
 * @param vectorTableName the vector table name handed to the builder
 * @param jdbcTemplate the JDBC template handed to the builder
 * @param embeddingModel the embedding model handed to the builder
 * @param dimensions the dimension count handed to the builder
 * @param distanceType the distance type handed to the builder
 * @param removeExistingVectorStoreTable value of the builder's remove-existing-table flag
 * @param createIndexMethod the index type handed to the builder
 * @param initializeSchema value of the builder's initialize-schema flag
 * @deprecated since 1.0.0-M5 and marked for removal; use {@link #builder()} instead
 */
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel,
		int dimensions, PgDistanceType distanceType, boolean removeExistingVectorStoreTable,
		PgIndexType createIndexMethod, boolean initializeSchema) {

	this(builder().jdbcTemplate(jdbcTemplate)
		// Forward the embedding model as well; without this call the embeddingModel
		// argument of this constructor is silently dropped.
		.embeddingModel(embeddingModel)
		.schemaName(DEFAULT_SCHEMA_NAME)
		.vectorTableName(vectorTableName)
		.vectorTableValidationsEnabled(DEFAULT_SCHEMA_VALIDATION)
		.dimensions(dimensions)
		.distanceType(distanceType)
		.removeExistingVectorStoreTable(removeExistingVectorStoreTable)
		.indexType(createIndexMethod)
		.initializeSchema(initializeSchema));
}

private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema) {

this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions,
distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema,
ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE);
}

private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema,
ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention,
BatchingStrategy batchingStrategy, int maxDocumentBatchSize) {
/**
 * Creates a store from the configuration collected by the given builder.
 * <p>
 * A null or empty table name falls back to {@code DEFAULT_TABLE_NAME}; any other
 * name is trimmed. The index name is {@code DEFAULT_VECTOR_INDEX_NAME} for the
 * default table and {@code <table>_index} otherwise.
 * @param builder the {@link PgVectorStoreBuilder} holding the store configuration;
 * its JdbcTemplate must not be null
 * @throws IllegalArgumentException if the builder carries no JdbcTemplate
 */
protected PgVectorStore(PgVectorStoreBuilder builder) {
	super(builder);

	Assert.notNull(builder.jdbcTemplate, "JdbcTemplate must not be null");

	this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();

	String vectorTable = builder.vectorTableName;
	// Decide emptiness on the name that was supplied, so the log below reports whether
	// the default had to be used rather than re-testing the already-defaulted field.
	boolean tableNameMissing = (null == vectorTable || vectorTable.isEmpty());
	this.vectorTableName = tableNameMissing ? DEFAULT_TABLE_NAME : vectorTable.trim();
	logger.info("Using the vector table name: {}. Is empty: {}", this.vectorTableName, tableNameMissing);

	this.vectorIndexName = this.vectorTableName.equals(DEFAULT_TABLE_NAME) ? DEFAULT_VECTOR_INDEX_NAME
			: this.vectorTableName + "_index";

	this.schemaName = builder.schemaName;
	this.schemaValidation = builder.vectorTableValidationsEnabled;

	this.jdbcTemplate = builder.jdbcTemplate;
	this.dimensions = builder.dimensions;
	this.distanceType = builder.distanceType;
	this.removeExistingVectorStoreTable = builder.removeExistingVectorStoreTable;
	this.createIndexMethod = builder.indexType;
	this.initializeSchema = builder.initializeSchema;
	this.schemaValidator = new PgVectorSchemaValidator(this.jdbcTemplate);
	this.batchingStrategy = builder.batchingStrategy;
	this.maxDocumentBatchSize = builder.maxDocumentBatchSize;
}

/**
 * Exposes the distance type held by this store.
 * @return the value of the {@code distanceType} field
 */
public PgDistanceType getDistanceType() {
	final PgDistanceType currentDistanceType = this.distanceType;
	return currentDistanceType;
}

/**
 * Entry point of the fluent API: hands out a fresh, zero-argument builder.
 * @return a new {@link PgVectorStoreBuilder}
 */
public static PgVectorStoreBuilder builder() {
	final PgVectorStoreBuilder freshBuilder = new PgVectorStoreBuilder();
	return freshBuilder;
}

@Override
public void doAdd(List<Document> documents) {
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
Expand Down Expand Up @@ -527,6 +534,94 @@ private Map<String, Object> toMap(PGobject pgObject) {

}

/**
 * Fluent, zero-argument builder for {@link PgVectorStore}.
 * <p>
 * Every option except the JdbcTemplate starts with a default value; each setter
 * stores its argument and returns this builder so calls can be chained.
 */
public static class PgVectorStoreBuilder extends AbstractVectorStoreBuilder<PgVectorStoreBuilder> {

	// --- connection -------------------------------------------------------

	/** No default; must be supplied through {@link #jdbcTemplate(JdbcTemplate)}. */
	private JdbcTemplate jdbcTemplate;

	// --- schema and table layout -------------------------------------------

	private String schemaName = PgVectorStore.DEFAULT_SCHEMA_NAME;

	private String vectorTableName = PgVectorStore.DEFAULT_TABLE_NAME;

	private boolean vectorTableValidationsEnabled = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;

	private boolean initializeSchema = false;

	private boolean removeExistingVectorStoreTable = false;

	// --- vector settings ----------------------------------------------------

	private int dimensions = PgVectorStore.INVALID_EMBEDDING_DIMENSION;

	private PgDistanceType distanceType = PgDistanceType.COSINE_DISTANCE;

	private PgIndexType indexType = PgIndexType.HNSW;

	// --- batching -----------------------------------------------------------

	private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

	private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE;

	/**
	 * Sets the JdbcTemplate.
	 * @param jdbcTemplate the template to use; must not be null
	 * @return this builder
	 * @throws IllegalArgumentException if {@code jdbcTemplate} is null
	 */
	public PgVectorStoreBuilder jdbcTemplate(JdbcTemplate jdbcTemplate) {
		Assert.notNull(jdbcTemplate, "JdbcTemplate must not be null");
		this.jdbcTemplate = jdbcTemplate;
		return this;
	}

	/**
	 * Sets the schema name (default: {@code PgVectorStore.DEFAULT_SCHEMA_NAME}).
	 * @param schemaName the schema name
	 * @return this builder
	 */
	public PgVectorStoreBuilder schemaName(String schemaName) {
		this.schemaName = schemaName;
		return this;
	}

	/**
	 * Sets the vector table name (default: {@code PgVectorStore.DEFAULT_TABLE_NAME}).
	 * @param vectorTableName the table name
	 * @return this builder
	 */
	public PgVectorStoreBuilder vectorTableName(String vectorTableName) {
		this.vectorTableName = vectorTableName;
		return this;
	}

	/**
	 * Sets the table-validation flag (default:
	 * {@code PgVectorStore.DEFAULT_SCHEMA_VALIDATION}).
	 * @param vectorTableValidationsEnabled the flag value
	 * @return this builder
	 */
	public PgVectorStoreBuilder vectorTableValidationsEnabled(boolean vectorTableValidationsEnabled) {
		this.vectorTableValidationsEnabled = vectorTableValidationsEnabled;
		return this;
	}

	/**
	 * Sets the dimension count (default:
	 * {@code PgVectorStore.INVALID_EMBEDDING_DIMENSION}).
	 * @param dimensions the dimension count
	 * @return this builder
	 */
	public PgVectorStoreBuilder dimensions(int dimensions) {
		this.dimensions = dimensions;
		return this;
	}

	/**
	 * Sets the distance type (default: {@code PgDistanceType.COSINE_DISTANCE}).
	 * @param distanceType the distance type
	 * @return this builder
	 */
	public PgVectorStoreBuilder distanceType(PgDistanceType distanceType) {
		this.distanceType = distanceType;
		return this;
	}

	/**
	 * Sets the remove-existing-table flag (default: {@code false}).
	 * @param removeExistingVectorStoreTable the flag value
	 * @return this builder
	 */
	public PgVectorStoreBuilder removeExistingVectorStoreTable(boolean removeExistingVectorStoreTable) {
		this.removeExistingVectorStoreTable = removeExistingVectorStoreTable;
		return this;
	}

	/**
	 * Sets the index type (default: {@code PgIndexType.HNSW}).
	 * @param indexType the index type
	 * @return this builder
	 */
	public PgVectorStoreBuilder indexType(PgIndexType indexType) {
		this.indexType = indexType;
		return this;
	}

	/**
	 * Sets the initialize-schema flag (default: {@code false}).
	 * @param initializeSchema the flag value
	 * @return this builder
	 */
	public PgVectorStoreBuilder initializeSchema(boolean initializeSchema) {
		this.initializeSchema = initializeSchema;
		return this;
	}

	/**
	 * Sets the batching strategy (default: a new {@code TokenCountBatchingStrategy}).
	 * @param batchingStrategy the batching strategy
	 * @return this builder
	 */
	public PgVectorStoreBuilder batchingStrategy(BatchingStrategy batchingStrategy) {
		this.batchingStrategy = batchingStrategy;
		return this;
	}

	/**
	 * Sets the maximum document batch size (default: {@code MAX_DOCUMENT_BATCH_SIZE}).
	 * @param maxDocumentBatchSize the batch size limit
	 * @return this builder
	 */
	public PgVectorStoreBuilder maxDocumentBatchSize(int maxDocumentBatchSize) {
		this.maxDocumentBatchSize = maxDocumentBatchSize;
		return this;
	}

	/**
	 * Calls {@code validate()} (not declared in this class; presumably inherited
	 * from {@code AbstractVectorStoreBuilder}) and then creates the store from
	 * this builder.
	 * @return a new {@link PgVectorStore}
	 */
	public PgVectorStore build() {
		validate();
		PgVectorStore store = new PgVectorStore(this);
		return store;
	}

}

@Deprecated(forRemoval = true, since = "1.0.0-M5")
public static class Builder {

private final JdbcTemplate jdbcTemplate;
Expand Down Expand Up @@ -558,7 +653,6 @@ public static class Builder {
@Nullable
private VectorStoreObservationConvention searchObservationConvention;

// Builder constructor with mandatory parameters
public Builder(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
if (jdbcTemplate == null || embeddingModel == null) {
throw new IllegalArgumentException("JdbcTemplate and EmbeddingModel must not be null");
Expand Down Expand Up @@ -628,11 +722,20 @@ public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) {
}

/**
 * Builds the store by copying every setting of this legacy builder onto
 * {@link PgVectorStore#builder()} and invoking its {@code build()}.
 * @return the configured {@link PgVectorStore}
 */
public PgVectorStore build() {
	return PgVectorStore.builder()
		.jdbcTemplate(this.jdbcTemplate)
		.embeddingModel(this.embeddingModel)
		.schemaName(this.schemaName)
		.vectorTableName(this.vectorTableName)
		.vectorTableValidationsEnabled(this.vectorTableValidationsEnabled)
		.dimensions(this.dimensions)
		.distanceType(this.distanceType)
		.removeExistingVectorStoreTable(this.removeExistingVectorStoreTable)
		.indexType(this.indexType)
		.initializeSchema(this.initializeSchema)
		// Carry the observation settings over too; otherwise the registry and the
		// search convention held by this builder would be ignored by the new store.
		.observationRegistry(this.observationRegistry)
		.customObservationConvention(this.searchObservationConvention)
		.batchingStrategy(this.batchingStrategy)
		.maxDocumentBatchSize(this.maxDocumentBatchSize)
		.build();
}

}
Expand Down
Loading
Loading