Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
name: Release to PyPI

on:
release:
types: [published]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Document(Table, kw_only=True):

class Chunk(Table, kw_only=True):
uid: Optional[PrimaryKeyAutoIncrease] = None
doc_id: Annotated[int, ForeignKey[Document.uid]] # reference to `Document.uid`
doc_id: Annotated[int, ForeignKey[Document.uid]] # reference to `Document.uid` on DELETE CASCADE
vector: DenseVector # this comes with a default vector index
text: str
```
Expand Down
23 changes: 22 additions & 1 deletion vechord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import psycopg
from pgvector.psycopg import register_vector
from psycopg import sql
from psycopg.pq import TransactionStatus

from vechord.spec import IndexColumn, Keyword

Expand Down Expand Up @@ -52,7 +53,11 @@ def __init__(self, namespace: str, url: str):

@contextlib.contextmanager
def transaction(self):
"""Create a transaction context manager."""
"""Create a transaction context manager (when there is no transaction)."""
if self.conn.info.transaction_status != TransactionStatus.IDLE:
yield None
return

with self.conn.transaction():
cursor = self.conn.cursor()
token = active_cursor.set(cursor)
Expand Down Expand Up @@ -164,6 +169,7 @@ def select(

@staticmethod
def _to_placeholder(kv: tuple[str, Any]):
"""Process the `Keyword` type"""
key, value = kv
if isinstance(value, Keyword):
return sql.SQL("tokenize({}, {})").format(
Expand All @@ -183,6 +189,21 @@ def insert(self, name: str, values: dict):
values,
)

def copy_bulk(self, name: str, values: Sequence[dict]):
    """Bulk-load rows into the namespaced table with a binary ``COPY``.

    Args:
        name: table name without the namespace prefix; the target table is
            ``{self.ns}_{name}``.
        values: one dict per row, mapping column name to value. The column
            list is taken from the first dict, so every row must provide
            exactly the same keys (in any order).

    Raises:
        ValueError: if a row has a different number of keys than the first.
        KeyError: if a row lacks one of the columns of the first row.
    """
    if not values:
        # Nothing to copy; also avoids indexing into an empty sequence.
        return

    keys = tuple(values[0])
    statement = sql.SQL(
        "COPY {table} ({columns}) FROM STDIN WITH (FORMAT BINARY)"
    ).format(
        table=sql.Identifier(f"{self.ns}_{name}"),
        columns=sql.SQL(", ").join(map(sql.Identifier, keys)),
    )
    with self.transaction():
        cursor = self.get_cursor()
        with cursor.copy(statement) as copy:
            for value in values:
                if len(value) != len(keys):
                    raise ValueError(
                        f"row has {len(value)} fields, expected {len(keys)}"
                    )
                # Emit values in the declared column order rather than in
                # each dict's own insertion order, so a row built with a
                # different key order cannot land in the wrong columns.
                copy.write_row(tuple(value[key] for key in keys))

def delete(self, name: str, kvs: dict):
if kvs:
condition = sql.SQL(" AND ").join(
Expand Down
42 changes: 37 additions & 5 deletions vechord/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ def select_by(
"""Retrieve the requested fields for the given object stored in the DB.

Args:
obj: the object to be retrieved, this should be a `Table.partial_init()`
instance, which means given values will be used for filtering.
obj: the object to be retrieved, this should be generated from
:meth:`Table.partial_init`, while the given values will be used
for filtering (``=`` or ``is``).
fields: the fields to be retrieved, if not set, all the fields will be
retrieved.
limit: the maximum number of results to be returned, if not set, all
Expand Down Expand Up @@ -287,6 +288,27 @@ def insert(self, obj: Table):
raise ValueError(f"unsupported class {type(obj)}")
self.client.insert(obj.name(), obj.todict())

def copy_bulk(self, objs: list[Table]):
    """Bulk-insert a batch of same-typed objects into the DB.

    Prefer this over repeated `insert` calls when many rows are written
    at once.

    Args:
        objs: the objects to store. They must all be instances of one
            `Table` subclass and have the same fields filled in. An empty
            list is a no-op.

    Raises:
        ValueError: if the first object's class is not a `Table` subclass,
            or if any object is not an instance of that class.
    """
    if not objs:
        return

    first = objs[0]
    cls = first.__class__
    if not issubclass(cls, Table):
        raise ValueError(f"unsupported class {cls}")
    for obj in objs:
        if not isinstance(obj, cls):
            raise ValueError(f"not all the objects are {cls}")

    table_name = first.name()
    rows = [obj.todict() for obj in objs]
    self.client.copy_bulk(table_name, rows)

def inject(
self, input: Optional[type[Table]] = None, output: Optional[type[Table]] = None
):
Expand Down Expand Up @@ -334,11 +356,21 @@ def wrapper(*args, **kwargs):
return [func(*arg, **kwargs) for arg in arguments]

count = 0
use_copy = output.keyword_column() is None
if is_list_of_type(returns):
for arg in arguments:
for ret in func(*arg, **kwargs):
self.insert(ret)
count += 1
if use_copy:
rets = list(func(*arg, **kwargs))
self.copy_bulk(rets)
count += len(rets)
else:
for ret in func(*arg, **kwargs):
self.insert(ret)
count += 1
elif use_copy:
rets = list(func(*args, **kwargs) for args in arguments)
self.copy_bulk(rets)
count += len(rets)
else:
for arg in arguments:
ret = func(*arg, **kwargs)
Expand Down