Skip to content

Commit ed1f762

Browse files
committed
feat: Added QdrantStore
1 parent f4b2472 commit ed1f762

File tree

1 file changed

+130
-0
lines changed

1 file changed

+130
-0
lines changed

datapipe/store/qdrant.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import pandas as pd
2+
import hashlib
3+
import uuid
4+
5+
from typing import Dict, List, Optional
6+
from qdrant_client import QdrantClient
7+
from qdrant_client.conversions import common_types as types
8+
from qdrant_client.http import models as rest
9+
from qdrant_client.http.exceptions import UnexpectedResponse
10+
11+
from datapipe.types import DataSchema, MetaSchema, IndexDF, DataDF, data_to_index
12+
from datapipe.store.table_store import TableStore
13+
14+
15+
class CollectionParams(rest.CreateCollection):
16+
pass
17+
18+
19+
class QdrantStore(TableStore):
20+
def __init__(
21+
self,
22+
name: str,
23+
host: str,
24+
port: int,
25+
schema: DataSchema,
26+
pk_field: str,
27+
embedding_field: str,
28+
collection_params: CollectionParams
29+
):
30+
super().__init__()
31+
self.name = name
32+
self.host = host
33+
self.port = port
34+
self.schema = schema
35+
self.pk_field = pk_field
36+
self.embedding_field = embedding_field
37+
self.collection_params = collection_params
38+
self.inited = False
39+
self.client = QdrantClient(host=self.host, port=self.port)
40+
41+
pk_columns = [column for column in self.schema if column.primary_key]
42+
43+
if len(pk_columns) != 1 and pk_columns[0].name != pk_field:
44+
raise ValueError("Incorrect prymary key columns in schema")
45+
46+
self.paylods_filelds = [column.name for column in self.schema if column.name != self.embedding_field]
47+
48+
def __init(self):
49+
try:
50+
self.client.get_collection(self.name)
51+
except UnexpectedResponse as e:
52+
if e.status_code == 404:
53+
self.client.http.collections_api.create_collection(
54+
collection_name=self.name,
55+
create_collection=self.collection_params
56+
)
57+
58+
def __check_init(self):
59+
if not self.inited:
60+
self.__init()
61+
self.inited = True
62+
63+
def __get_ids(self, df):
64+
return df[self.pk_field].apply(
65+
lambda x: str(uuid.UUID(bytes=hashlib.md5(str(x).encode('utf-8')).digest()))
66+
).to_list()
67+
68+
def get_primary_schema(self) -> DataSchema:
69+
return [column for column in self.schema if column.primary_key]
70+
71+
def get_meta_schema(self) -> MetaSchema:
72+
return []
73+
74+
def insert_rows(self, df: DataDF) -> None:
75+
self.__check_init()
76+
77+
if len(df) == 0:
78+
return
79+
80+
self.client.upsert(
81+
self.name,
82+
rest.Batch(
83+
ids=self.__get_ids(df),
84+
vectors=df[self.embedding_field].apply(list).to_list(),
85+
payloads=df[self.paylods_filelds].to_dict(orient='records')
86+
),
87+
wait=True,
88+
)
89+
90+
def update_rows(self, df: DataDF) -> None:
91+
self.insert_rows(df)
92+
93+
def delete_rows(self, idx: IndexDF) -> None:
94+
self.__check_init()
95+
96+
if len(idx) == 0:
97+
return
98+
99+
self.client.delete(
100+
self.name,
101+
rest.PointIdsList(
102+
points=self.__get_ids(idx)
103+
),
104+
wait=True,
105+
)
106+
107+
def read_rows(self, idx: Optional[IndexDF] = None) -> DataDF:
108+
self.__check_init()
109+
110+
if not idx:
111+
raise Exception("Qrand doesn't support full store reading")
112+
113+
response = self.client.http.points_api.get_points(
114+
self.name,
115+
point_request=rest.PointRequest(
116+
ids=self.__get_ids(idx),
117+
with_payload=True,
118+
with_vector=True
119+
)
120+
)
121+
122+
records = []
123+
124+
for point in response.result:
125+
record = point.payload
126+
record[self.embedding_field] = point.vector
127+
128+
records.append(record)
129+
130+
return pd.DataFrame.from_records(records)[[column.name for column in self.schema]]

0 commit comments

Comments
 (0)