Skip to content
7 changes: 5 additions & 2 deletions pgvector/bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ def from_binary(cls, value):

@classmethod
def _to_db(cls, value):
if value is None:
return value

if not isinstance(value, cls):
raise ValueError('expected bit')

value = cls(value)
return value.to_text()

@classmethod
Expand Down
6 changes: 4 additions & 2 deletions pgvector/sqlalchemy/bit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncpg
from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
from sqlalchemy.dialects.postgresql.base import ischema_names
from sqlalchemy.types import UserDefinedType, Float

from .. import Bit

class BIT(UserDefinedType):
cache_ok = True
Expand All @@ -24,7 +26,7 @@ def process(value):
return value
return process
else:
return super().bind_processor(dialect)
return lambda value: Bit._to_db(value)

class comparator_factory(UserDefinedType.Comparator):
def hamming_distance(self, other):
Expand Down
8 changes: 7 additions & 1 deletion tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,13 @@ def test_bit(self, engine):
item = session.get(Item, 1)
assert item.binary_embedding == '101'

def test_boolean_list_bit(self, engine):
with Session(engine) as session:
session.add(Item(id=1, binary_embedding=[True, False, True]))
session.commit()
item = session.get(Item, 1)
assert item.binary_embedding == '101'

def test_bit_hamming_distance(self, engine):
create_items()
with Session(engine) as session:
Expand Down Expand Up @@ -567,7 +574,6 @@ def test_halfvec_array(self, engine):
item = session.get(Item, 1)
assert item.half_embeddings == [HalfVector([1, 2, 3]), HalfVector([4, 5, 6])]


@pytest.mark.parametrize('engine', async_engines)
class TestSqlalchemyAsync:
def setup_method(self):
Expand Down