Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20250721095433798586.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Add ElasticSearch vector store support with full compatibility with existing vector store interface"
}
221 changes: 221 additions & 0 deletions graphrag/vector_stores/elasticsearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The ElasticSearch vector storage implementation package."""

import json
from typing import Any

import numpy as np
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch.helpers import bulk

from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import (
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
)


def _create_index_settings(vector_dim: int) -> dict:
"""Create ElasticSearch index settings with dynamic vector dimensions."""
return {
"settings": {
"number_of_shards": 1, # Single shard for development/local setup
"number_of_replicas": 0, # No replicas for development/local setup
},
"mappings": {
"properties": {
"id": {"type": "keyword"},
"text": {"type": "text"},
"vector": {
"type": "dense_vector",
"dims": vector_dim,
"index": True,
"similarity": "cosine",
},
"attributes": {"type": "text"},
}
},
}


class ElasticSearchVectorStore(BaseVectorStore):
    """ElasticSearch vector storage implementation."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def connect(self, **kwargs: Any) -> None:
        """Connect to the vector storage.

        Args:
            **kwargs: Connection options. Supports ``url`` (defaults to
                ``http://localhost:9200``).
        """
        self.db_connection = Elasticsearch(
            hosts=[kwargs.get("url", "http://localhost:9200")]
        )

    def _ensure_index(self, vector_dim: int) -> None:
        """Create the collection index if it does not already exist."""
        if not self.db_connection.indices.exists(index=self.collection_name):
            self.db_connection.indices.create(
                index=self.collection_name,
                body=_create_index_settings(vector_dim),
            )

    def _bulk_insert(self, data: list[dict]) -> None:
        """Bulk-index documents and refresh the index.

        ElasticSearch is near real-time by default; the explicit refresh
        makes the documents immediately visible to searches.
        """
        actions = [
            {
                "_index": self.collection_name,
                "_id": str(doc["id"]),
                "_source": doc,
            }
            for doc in data
        ]
        bulk(self.db_connection, actions)
        self.db_connection.indices.refresh(index=self.collection_name)

    def load_documents(
        self, documents: list[VectorStoreDocument], overwrite: bool = True
    ) -> None:
        """Load documents into vector storage.

        Args:
            documents: Documents to index. Documents without an embedding
                vector are skipped.
            overwrite: When True, any existing index with this collection
                name is deleted and recreated before loading.

        Raises:
            RuntimeError: If connect() has not been called.
        """
        if self.db_connection is None:
            msg = "Must connect to ElasticSearch before loading documents"
            raise RuntimeError(msg)

        data = [
            {
                "id": document.id,
                "text": document.text,
                "vector": document.vector,
                "attributes": json.dumps(document.attributes),
            }
            for document in documents
            if document.vector is not None  # Skip documents without embeddings
        ]

        if overwrite and self.db_connection.indices.exists(
            index=self.collection_name
        ):
            self.db_connection.indices.delete(index=self.collection_name)

        # Infer vector dimensions from the first document; with no data,
        # default to OpenAI text-embedding-3-small dimensions (1536).
        vector_dim = len(data[0]["vector"]) if data else 1536
        self._ensure_index(vector_dim)

        if data:
            self._bulk_insert(data)

    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
        """Build a query filter to filter documents by id.

        An empty id list clears the filter; otherwise a ``terms`` filter on
        the stringified ids is stored and returned.
        """
        if len(include_ids) == 0:
            self.query_filter = None
        else:
            self.query_filter = {
                "terms": {"id": [str(doc_id) for doc_id in include_ids]}
            }
        return self.query_filter

    def similarity_search_by_vector(
        self, query_embedding: list[float], k: int = 10, **kwargs: Any
    ) -> list[VectorStoreSearchResult]:
        """Perform a vector-based similarity search.

        Args:
            query_embedding: Query vector; must match the index's dimensions.
            k: Number of nearest neighbours to return.

        Raises:
            RuntimeError: If connect() has not been called.
        """
        if self.db_connection is None:
            msg = "Must connect to ElasticSearch before searching"
            raise RuntimeError(msg)

        knn_clause: dict[str, Any] = {
            "field": "vector",
            "query_vector": query_embedding,
            "k": k,
            # Search more candidates for better recall in approximate KNN.
            # 10x multiplier balances accuracy vs performance, capped at 10k
            # for memory limits.
            "num_candidates": min(k * 10, 10000),
        }
        # Apply any id filter as a KNN pre-filter. Placing it at the
        # top-level "query" key would COMBINE its scores with the knn hits
        # (ES 8 semantics) rather than restrict the neighbour search.
        if self.query_filter:
            knn_clause["filter"] = self.query_filter

        response = self.db_connection.search(
            index=self.collection_name,
            body={
                "knn": knn_clause,
                "_source": ["id", "text", "vector", "attributes"],
            },
        )

        return [
            VectorStoreSearchResult(
                document=VectorStoreDocument(
                    id=hit["_source"]["id"],
                    text=hit["_source"]["text"],
                    vector=hit["_source"]["vector"],
                    attributes=json.loads(hit["_source"]["attributes"]),
                ),
                score=hit["_score"],
            )
            for hit in response["hits"]["hits"]
        ]

    def similarity_search_by_text(
        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
    ) -> list[VectorStoreSearchResult]:
        """Perform a similarity search using a given input text.

        Returns an empty list when the embedder produces no embedding.
        """
        query_embedding = text_embedder(text)
        if query_embedding is not None:
            if isinstance(query_embedding, np.ndarray):
                # Convert NumPy array to list for ElasticSearch JSON API compatibility
                # Unlike LanceDB which supports NumPy natively, ElasticSearch requires JSON serialization
                query_embedding = query_embedding.tolist()
            return self.similarity_search_by_vector(query_embedding, k)
        return []

    def search_by_id(self, id: str) -> VectorStoreDocument:
        """Search for a document by id.

        Returns an empty VectorStoreDocument (null-object) when the id is
        not present in the index, for consistency with other stores.

        Raises:
            RuntimeError: If connect() has not been called.
        """
        if self.db_connection is None:
            msg = "Must connect to ElasticSearch before searching"
            raise RuntimeError(msg)

        try:
            response = self.db_connection.get(
                index=self.collection_name,
                id=str(id),
            )
        except NotFoundError:
            return VectorStoreDocument(id=id, text=None, vector=None)

        source = response["_source"]
        return VectorStoreDocument(
            id=source["id"],
            text=source["text"],
            vector=source["vector"],
            attributes=json.loads(source["attributes"]),
        )
4 changes: 4 additions & 0 deletions graphrag/vector_stores/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
from graphrag.vector_stores.elasticsearch import ElasticSearchVectorStore
from graphrag.vector_stores.lancedb import LanceDBVectorStore


Expand All @@ -18,6 +19,7 @@ class VectorStoreType(str, Enum):
LanceDB = "lancedb"
AzureAISearch = "azure_ai_search"
CosmosDB = "cosmosdb"
ElasticSearch = "elasticsearch"


class VectorStoreFactory:
Expand Down Expand Up @@ -45,6 +47,8 @@ def create_vector_store(
return AzureAISearchVectorStore(**kwargs)
case VectorStoreType.CosmosDB:
return CosmosDBVectorStore(**kwargs)
case VectorStoreType.ElasticSearch:
return ElasticSearchVectorStore(**kwargs)
case _:
if vector_store_type in cls.vector_store_types:
return cls.vector_store_types[vector_store_type](**kwargs)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ environs = "^11.0.0"

# Vector Stores
azure-search-documents = "^11.5.2"
elasticsearch = "^8.0.0"
lancedb = "^0.17.0"

# Async IO
Expand Down
Loading