1 change: 1 addition & 0 deletions backend/database/schema.prisma
@@ -23,6 +23,7 @@ model Collection {
name String @unique
description String?
embedder_config Json
+  quantization_config Json?
// Collection can have multiple data sources
associated_data_sources Json @default("{}")

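For illustration, the value stored in the new quantization_config column is the JSON string produced by json.dumps in the metadata store (see the prisma_store.py changes below). A hypothetical stored payload, assuming the QuantizationConfig fields used later in this PR (a quantization type plus an always_ram flag):

    {"type": "scalar", "always_ram": true}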
8 changes: 5 additions & 3 deletions backend/modules/dataloaders/web_loader.py
@@ -42,9 +42,11 @@ async def extract_urls_from_sitemap(url: str) -> List[Tuple[str, str]]:
urls = [
(
loc.text,
-                loc.find_next_sibling("lastmod").text
-                if loc.find_next_sibling("lastmod")
-                else None,
+                (
+                    loc.find_next_sibling("lastmod").text
+                    if loc.find_next_sibling("lastmod")
+                    else None
+                ),
)
for loc in soup.find_all("loc")
]
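The change above is presentation only (the conditional expression is wrapped in parentheses, consistent with formatter output such as Black); behavior is unchanged. As a minimal sketch of what the comprehension produces, assuming BeautifulSoup with the lxml-backed "xml" parser and an illustrative sitemap:

    from bs4 import BeautifulSoup

    sitemap = (
        "<urlset>"
        "<url><loc>https://example.com/a</loc><lastmod>2024-01-01</lastmod></url>"
        "<url><loc>https://example.com/b</loc></url>"
        "</urlset>"
    )
    soup = BeautifulSoup(sitemap, "xml")
    urls = [
        (
            loc.text,
            # lastmod is optional per the sitemap spec, hence the None fallback
            (
                loc.find_next_sibling("lastmod").text
                if loc.find_next_sibling("lastmod")
                else None
            ),
        )
        for loc in soup.find_all("loc")
    ]
    # urls == [("https://example.com/a", "2024-01-01"), ("https://example.com/b", None)]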
69 changes: 39 additions & 30 deletions backend/modules/metadata_store/prisma_store.py
@@ -69,10 +69,19 @@ async def aget_collection_by_name(

async def acreate_collection(self, collection: CreateCollection) -> Collection:
logger.info(f"Creating collection: {collection.model_dump()}")
-        collection_data = collection.model_dump()
+        collection_data = collection.model_dump(exclude_unset=True)
collection_data["embedder_config"] = json.dumps(
collection_data["embedder_config"]
)
+        # Handle quantization_config - serialize to JSON if present, remove if None
+        if collection_data.get("quantization_config") is not None:
+            collection_data["quantization_config"] = json.dumps(
+                collection_data["quantization_config"]
+            )
+        elif "quantization_config" in collection_data:
+            # Remove None values to let Prisma handle the optional field
+            collection_data.pop("quantization_config")

collection: "PrismaCollection" = await self.db.collection.create(
data=collection_data
)
@@ -183,9 +192,9 @@ async def aassociate_data_sources_with_collection(
parser_config=assoc.parser_config,
data_source=DataSource.model_validate(data_source.model_dump()),
)
-            existing_associated_data_sources[
-                assoc.data_source_fqn
-            ] = data_src_to_associate
+            existing_associated_data_sources[assoc.data_source_fqn] = (
+                data_src_to_associate
+            )

# Convert the existing associated data sources to a dictionary
associated_data_sources = {
@@ -274,9 +283,9 @@ async def adelete_data_source(self, data_source_fqn: str) -> None:
)

# Delete the data source
-        deleted_datasource: Optional[
-            PrismaDataSource
-        ] = await self.db.datasource.delete(where={"fqn": data_source_fqn})
+        deleted_datasource: Optional[PrismaDataSource] = (
+            await self.db.datasource.delete(where={"fqn": data_source_fqn})
+        )

if not deleted_datasource:
raise HTTPException(
@@ -319,10 +328,10 @@ async def acreate_data_ingestion_run(
async def aget_data_ingestion_run(
self, data_ingestion_run_name: str, no_cache: bool = False
) -> Optional[DataIngestionRun]:
-        data_ingestion_run: Optional[
-            "PrismaDataIngestionRun"
-        ] = await self.db.ingestionruns.find_first(
-            where={"name": data_ingestion_run_name}
+        data_ingestion_run: Optional["PrismaDataIngestionRun"] = (
+            await self.db.ingestionruns.find_first(
+                where={"name": data_ingestion_run_name}
+            )
)
logger.info(f"Data ingestion run: {data_ingestion_run}")
if data_ingestion_run:
@@ -333,10 +342,10 @@ async def aget_data_ingestion_runs(
self, collection_name: str, data_source_fqn: str = None
) -> List[DataIngestionRun]:
"""Get all data ingestion runs for a collection"""
-        data_ingestion_runs: List[
-            "PrismaDataIngestionRun"
-        ] = await self.db.ingestionruns.find_many(
-            where={"collection_name": collection_name}, order={"id": "desc"}
+        data_ingestion_runs: List["PrismaDataIngestionRun"] = (
+            await self.db.ingestionruns.find_many(
+                where={"collection_name": collection_name}, order={"id": "desc"}
+            )
)
return [
DataIngestionRun.model_validate(data_ir.model_dump())
@@ -347,10 +356,10 @@ async def aupdate_data_ingestion_run_status(
self, data_ingestion_run_name: str, status: DataIngestionRunStatus
) -> DataIngestionRun:
"""Update the status of a data ingestion run"""
-        updated_data_ingestion_run: Optional[
-            "PrismaDataIngestionRun"
-        ] = await self.db.ingestionruns.update(
-            where={"name": data_ingestion_run_name}, data={"status": status}
+        updated_data_ingestion_run: Optional["PrismaDataIngestionRun"] = (
+            await self.db.ingestionruns.update(
+                where={"name": data_ingestion_run_name}, data={"status": status}
+            )
)
if not updated_data_ingestion_run:
raise HTTPException(
@@ -364,11 +373,11 @@ async def alog_errors_for_data_ingestion_run(
self, data_ingestion_run_name: str, errors: Dict[str, Any]
) -> None:
"""Log errors for the given data ingestion run"""
-        updated_data_ingestion_run: Optional[
-            "PrismaDataIngestionRun"
-        ] = await self.db.ingestionruns.update(
-            where={"name": data_ingestion_run_name},
-            data={"errors": json.dumps(errors)},
+        updated_data_ingestion_run: Optional["PrismaDataIngestionRun"] = (
+            await self.db.ingestionruns.update(
+                where={"name": data_ingestion_run_name},
+                data={"errors": json.dumps(errors)},
+            )
)
if not updated_data_ingestion_run:
raise HTTPException(
@@ -381,9 +390,9 @@ async def alog_errors_for_data_ingestion_run(
######
async def aget_rag_app(self, app_name: str) -> Optional[RagApplication]:
"""Get a RAG application from the metadata store"""
-        rag_app: Optional[
-            "PrismaRagApplication"
-        ] = await self.db.ragapps.find_first_or_raise(where={"name": app_name})
+        rag_app: Optional["PrismaRagApplication"] = (
+            await self.db.ragapps.find_first_or_raise(where={"name": app_name})
+        )

return RagApplication.model_validate(rag_app.model_dump())

@@ -403,9 +412,9 @@ async def alist_rag_apps(self) -> List[str]:

async def adelete_rag_app(self, app_name: str):
"""Delete a RAG application from the metadata store"""
-        deleted_rag_app: Optional[
-            "PrismaRagApplication"
-        ] = await self.db.ragapps.delete(where={"name": app_name})
+        deleted_rag_app: Optional["PrismaRagApplication"] = (
+            await self.db.ragapps.delete(where={"name": app_name})
+        )
if not deleted_rag_app:
raise HTTPException(
status_code=404,
46 changes: 44 additions & 2 deletions backend/modules/vector_db/base.py
@@ -1,13 +1,13 @@
from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema.vectorstore import VectorStore

from backend.constants import DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE
from backend.logger import logger
-from backend.types import DataPointVector
+from backend.types import DataPointVector, QuantizationConfig


class BaseVectorDB(ABC):
@@ -18,6 +18,48 @@ def create_collection(self, collection_name: str, embeddings: Embeddings):
"""
raise NotImplementedError()

+    def create_collection_with_quantization(
+        self,
+        collection_name: str,
+        embeddings: Embeddings,
+        quantization_config: Optional[QuantizationConfig] = None,
+    ):
+        """
+        Create a collection with quantization support in the vector database.
+        Falls back to regular collection creation if quantization is not supported.
+        """
+        if quantization_config and self.supports_quantization():
+            return self._create_quantized_collection(
+                collection_name, embeddings, quantization_config
+            )
+        else:
+            if quantization_config:
+                logger.warning(
+                    f"Quantization not supported by {self.__class__.__name__}, creating regular collection"
+                )
+            return self.create_collection(collection_name, embeddings)
+
+    def supports_quantization(self) -> bool:
+        """
+        Check if vector DB supports quantization.
+        Default implementation returns False, override in subclasses that support it.
+        """
+        return False
+
+    def _create_quantized_collection(
+        self,
+        collection_name: str,
+        embeddings: Embeddings,
+        quantization_config: QuantizationConfig,
+    ):
+        """
+        Internal method to create quantized collection.
+        Should be implemented by subclasses that support quantization.
+        """
+        raise NotImplementedError(
+            f"Quantization not implemented for {self.__class__.__name__}"
+        )
+
@abstractmethod
def upsert_documents(
self,
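As a sketch of how a backend would opt in to the new hooks (a hypothetical subclass; in this PR only the Qdrant backend overrides these methods, and the imports are the same ones base.py already uses):

    class MyVectorDB(BaseVectorDB):
        def supports_quantization(self) -> bool:
            # Advertise support so create_collection_with_quantization
            # routes to _create_quantized_collection instead of falling back.
            return True

        def _create_quantized_collection(
            self,
            collection_name: str,
            embeddings: Embeddings,
            quantization_config: QuantizationConfig,
        ):
            # Backend-specific creation using quantization_config goes here.
            ...

Callers always go through the public create_collection_with_quantization entry point; the base class falls back to plain create_collection (logging a warning) when the backend does not support quantization or no config is given.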
8 changes: 5 additions & 3 deletions backend/modules/vector_db/mongo.py
@@ -289,9 +289,11 @@ def list_documents_in_collection(
while stop is not True:
batch_cursor = (
collection.find(
{f"metadata.{DATA_POINT_FQN_METADATA_KEY}": base_document_id}
if base_document_id
else {},
(
{f"metadata.{DATA_POINT_FQN_METADATA_KEY}": base_document_id}
if base_document_id
else {}
),
{f"metadata.{DATA_POINT_FQN_METADATA_KEY}": 1},
)
.skip(offset if offset else 0)
81 changes: 76 additions & 5 deletions backend/modules/vector_db/qdrant.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
from urllib.parse import urlparse

from langchain.embeddings.base import Embeddings
@@ -9,7 +9,13 @@
from backend.constants import DATA_POINT_FQN_METADATA_KEY, DATA_POINT_HASH_METADATA_KEY
from backend.logger import logger
from backend.modules.vector_db.base import BaseVectorDB
-from backend.types import DataPointVector, QdrantClientConfig, VectorDBConfig
+from backend.types import (
+    DataPointVector,
+    QdrantClientConfig,
+    VectorDBConfig,
+    QuantizationConfig,
+    QuantizationType,
+)

MAX_SCROLL_LIMIT = int(1e6)
BATCH_SIZE = 1000
@@ -40,13 +46,15 @@ def __init__(self, config: VectorDBConfig):
url=url, api_key=api_key, **qdrant_kwargs.model_dump()
)

+    def supports_quantization(self) -> bool:
+        """Qdrant supports quantization"""
+        return True
+
def create_collection(self, collection_name: str, embeddings: Embeddings):
logger.debug(f"[Qdrant] Creating new collection {collection_name}")

# Calculate embedding size
-        partial_embeddings = embeddings.embed_documents(["Initial document"])
-        vector_size = len(partial_embeddings[0])
-        logger.debug(f"Vector size: {vector_size}")
+        vector_size = self.get_embedding_dimensions(embeddings)

self.qdrant_client.create_collection(
collection_name=collection_name,
@@ -64,6 +72,69 @@ def create_collection(self, collection_name: str, embeddings: Embeddings):
)
logger.debug(f"[Qdrant] Created new collection {collection_name}")

+    def _create_quantized_collection(
+        self,
+        collection_name: str,
+        embeddings: Embeddings,
+        quantization_config: QuantizationConfig,
+    ):
+        logger.debug(
+            f"[Qdrant] Creating quantized collection {collection_name} with {quantization_config.type} quantization"
+        )
+
+        # Calculate embedding size
+        vector_size = self.get_embedding_dimensions(embeddings)
+
+        # Configure quantization based on type
+        quantization = None
+        if quantization_config.type == QuantizationType.SCALAR:
+            quantization = models.ScalarQuantization(
+                scalar=models.ScalarQuantizationConfig(
+                    type=models.ScalarType.INT8,
+                    quantile=0.99,
+                    always_ram=quantization_config.always_ram,
+                )
+            )
+            logger.debug(
+                f"[Qdrant] Using scalar quantization (INT8) for collection {collection_name}"
+            )
+        elif quantization_config.type == QuantizationType.BINARY:
+            quantization = models.BinaryQuantization(
+                binary=models.BinaryQuantizationConfig(
+                    always_ram=quantization_config.always_ram,
+                )
+            )
+            logger.debug(
+                f"[Qdrant] Using binary quantization for collection {collection_name}"
+            )
+        else:
+            logger.warning(
+                f"[Qdrant] Unsupported quantization type: {quantization_config.type}, falling back to no quantization"
+            )
+
+        # Create collection with quantization
+        self.qdrant_client.create_collection(
+            collection_name=collection_name,
+            vectors_config=VectorParams(
+                size=vector_size,
+                distance=Distance.COSINE,
+                on_disk=True,
+                quantization_config=quantization,
+            ),
+            replication_factor=3,
+        )
+
+        # Create payload index for metadata
+        self.qdrant_client.create_payload_index(
+            collection_name=collection_name,
+            field_name=f"metadata.{DATA_POINT_FQN_METADATA_KEY}",
+            field_schema=models.PayloadSchemaType.KEYWORD,
+        )
+
+        logger.info(
+            f"[Qdrant] Successfully created quantized collection {collection_name} with {quantization_config.type} quantization"
+        )
+
def _get_records_to_be_upserted(
self, collection_name: str, data_point_fqns: List[str], incremental: bool
):
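A minimal usage sketch for the new Qdrant path (assumes QuantizationConfig and QuantizationType are constructed with the type and always_ram fields the code above reads; vector_db and embeddings are placeholder objects):

    from backend.types import QuantizationConfig, QuantizationType

    cfg = QuantizationConfig(type=QuantizationType.SCALAR, always_ram=True)

    # supports_quantization() returns True for Qdrant, so the base class
    # dispatches to QdrantVectorDB._create_quantized_collection here.
    vector_db.create_collection_with_quantization(
        collection_name="my_collection",
        embeddings=embeddings,
        quantization_config=cfg,
    )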
27 changes: 22 additions & 5 deletions backend/server/routers/collection.py
@@ -62,15 +62,32 @@ async def create_collection(
name=collection.name,
description=collection.description,
embedder_config=collection.embedder_config,
+            quantization_config=collection.quantization_config,
)
)
logger.info(f"Creating collection {collection.name} on vector db...")
-    VECTOR_STORE_CLIENT.create_collection(
-        collection_name=collection.name,
-        embeddings=model_gateway.get_embedder_from_model_config(
-            model_name=collection.embedder_config.name
-        ),

+    # Get embeddings model
+    embeddings = model_gateway.get_embedder_from_model_config(
+        model_name=collection.embedder_config.name
+    )
+
+    # Create collection with quantization support if specified
+    if collection.quantization_config:
+        logger.info(
+            f"Creating collection {collection.name} with {collection.quantization_config.type} quantization..."
+        )
+        VECTOR_STORE_CLIENT.create_collection_with_quantization(
+            collection_name=collection.name,
+            embeddings=embeddings,
+            quantization_config=collection.quantization_config,
+        )
+    else:
+        VECTOR_STORE_CLIENT.create_collection(
+            collection_name=collection.name,
+            embeddings=embeddings,
+        )

logger.info(f"Created collection... {created_collection}")

if collection.associated_data_sources:
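For illustration, a create-collection request body with quantization enabled might look like the following (a hypothetical payload; the embedder_config shape and model name depend on the deployment's model gateway):

    {
      "name": "my-collection",
      "description": "Documents with scalar-quantized vectors",
      "embedder_config": {"name": "my-provider/my-embedding-model"},
      "quantization_config": {"type": "scalar", "always_ram": true}
    }

Omitting quantization_config (or sending null) preserves the old behavior: the exclude_unset/None handling in the metadata store drops the field, and the router calls plain create_collection.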