diff --git a/pgvector/bit.py b/pgvector/bit.py index 26a9d8d..cecd180 100644 --- a/pgvector/bit.py +++ b/pgvector/bit.py @@ -62,9 +62,12 @@ def from_binary(cls, value): @classmethod def _to_db(cls, value): + if value is None: + return value + if not isinstance(value, cls): - raise ValueError('expected bit') - + value = cls(value) + return value.to_text() @classmethod @@ -73,3 +76,9 @@ def _to_db_binary(cls, value): raise ValueError('expected bit') return value.to_binary() + + @classmethod + def _from_db(cls, value): + if value is None or isinstance(value, cls): + return value + return cls.from_text(value) \ No newline at end of file diff --git a/pgvector/sqlalchemy/bit.py b/pgvector/sqlalchemy/bit.py index 0f83f3c..338e1b6 100644 --- a/pgvector/sqlalchemy/bit.py +++ b/pgvector/sqlalchemy/bit.py @@ -1,6 +1,8 @@ +import asyncpg +from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg from sqlalchemy.dialects.postgresql.base import ischema_names from sqlalchemy.types import UserDefinedType, Float - +from .. import Bit class BIT(UserDefinedType): cache_ok = True @@ -14,6 +16,23 @@ def get_col_spec(self, **kw): return 'BIT' return 'BIT(%d)' % self.length + def bind_processor(self, dialect): + def process(value): + value = Bit._to_db(value) + if value and isinstance(dialect, PGDialect_asyncpg): + return asyncpg.BitString(value) + return value + return process + + def result_processor(self, dialect, coltype): + def process(value): + if value is None: return None + else: + if isinstance(dialect, PGDialect_asyncpg): + return value.as_string() + return Bit._from_db(value).to_text() + return process + class comparator_factory(UserDefinedType.Comparator): def hamming_distance(self, other): return self.op('<~>', return_type=Float)(other) diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index 5aec977..286bf66 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -591,7 +591,7 @@ async def test_bit(self, engine): async with async_session() as session: async with session.begin(): - embedding = asyncpg.BitString('101') if engine == asyncpg_engine else '101' + embedding = '101' session.add(Item(id=1, binary_embedding=embedding)) item = await session.get(Item, 1) assert item.binary_embedding == embedding