@@ -91,6 +91,23 @@ def create_vector_index(self, name: str, column: str):
9191 )
9292 )
9393
94+ def create_multivec_index (self , name : str , column : str ):
95+ config = "build.internal.lists = []"
96+ with self .transaction ():
97+ cursor = self .get_cursor ()
98+ cursor .execute (
99+ sql .SQL (
100+ "CREATE INDEX IF NOT EXISTS {index} ON "
101+ "{table} USING vchordrq ({column} vector_maxsim_ops) WITH "
102+ "(options = $${config}$$);"
103+ ).format (
104+ table = sql .Identifier (f"{ self .ns } _{ name } " ),
105+ index = sql .Identifier (f"{ self .ns } _{ name } _{ column } _multivec_idx" ),
106+ column = sql .Identifier (column ),
107+ config = sql .SQL (config ),
108+ )
109+ )
110+
94111 def _keyword_index_name (self , name : str , column : str ):
95112 return f"{ self .ns } _{ name } _{ column } _bm25_idx"
96113
@@ -114,6 +131,7 @@ def select(
114131 raw_columns : Sequence [str ],
115132 kvs : Optional [dict [str , Any ]] = None ,
116133 from_buffer : bool = False ,
134+ limit : Optional [int ] = None ,
117135 ):
118136 """Select from db table with optional key-value condition or from un-committed
119137 transaction buffer.
@@ -129,12 +147,18 @@ def select(
129147 )
130148 if kvs :
131149 condition = sql .SQL (" AND " ).join (
132- sql .SQL ("{} = {}" ).format (sql .Identifier (col ), sql .Placeholder (col ))
133- for col in kvs
150+ sql .SQL ("{} IS NULL" ).format (sql .Identifier (col ))
151+ if val is None
152+ else sql .SQL ("{} = {}" ).format (
153+ sql .Identifier (col ), sql .Placeholder (col )
154+ )
155+ for col , val in kvs .items ()
134156 )
135157 query += sql .SQL (" WHERE {condition}" ).format (condition = condition )
136158 elif from_buffer :
137159 query += sql .SQL (" WHERE xmin = pg_current_xact_id()::xid;" )
160+ if limit :
161+ query += sql .SQL (" LIMIT {}" ).format (sql .Literal (limit ))
138162 cursor .execute (query , kvs )
139163 return [row for row in cursor .fetchall ()]
140164
@@ -199,6 +223,36 @@ def query_vec(
199223 )
200224 return [row for row in cursor .fetchall ()]
201225
226+ def query_multivec ( # noqa: PLR0913
227+ self ,
228+ name : str ,
229+ multivec_col : str ,
230+ vec : np .ndarray ,
231+ max_maxsim_tuples : int ,
232+ return_fields : list [str ],
233+ topk : int = 10 ,
234+ ):
235+ columns = sql .SQL (", " ).join (map (sql .Identifier , return_fields ))
236+ with self .transaction ():
237+ cursor = self .get_cursor ()
238+ cursor .execute ("SET vchordrq.probes = '';" )
239+ cursor .execute (
240+ sql .SQL ("SET vchordrq.max_maxsim_tuples = {};" ).format (
241+ sql .Literal (max_maxsim_tuples )
242+ )
243+ )
244+ cursor .execute (
245+ sql .SQL (
246+ "SELECT {columns} FROM {table} ORDER BY {multivec_col} @# %s LIMIT %s;"
247+ ).format (
248+ table = sql .Identifier (f"{ self .ns } _{ name } " ),
249+ columns = columns ,
250+ multivec_col = sql .Identifier (multivec_col ),
251+ ),
252+ (vec , topk ),
253+ )
254+ return [row for row in cursor .fetchall ()]
255+
202256 def query_keyword ( # noqa: PLR0913
203257 self ,
204258 name : str ,
0 commit comments