diff --git a/plugins/postgres/plugin_test/test_pgvector.py b/plugins/postgres/plugin_test/test_pgvector.py
new file mode 100644
index 0000000000..6f92ddc451
--- /dev/null
+++ b/plugins/postgres/plugin_test/test_pgvector.py
@@ -0,0 +1,68 @@
+import numpy
+import pytest
+
+from superduper.components.table import Table
+from superduper import superduper, CFG
+from superduper_postgres.vector_search import PGVectorSearcher
+from superduper.backends.base.vector_search import measures
+
+
+@pytest.fixture
+def db():
+    _db = superduper(cluster_engine='local', vector_search_engine='postgres', initialize_cluster=False)
+    yield _db
+    _db.drop(True, True)
+
+
+def test_pgvector(db):
+
+    vector_table = Table(
+        'vector_table',
+        fields={
+            'id': 'str',
+            'vector': 'vector[float:3]',
+        },
+        primary_id='id',
+    )
+
+    db.apply(vector_table, force=True)
+
+    db['vector_table'].insert(
+        [
+            {'vector': numpy.random.randn(3)}
+            for _ in range(10)
+        ]
+    )
+    retrieved_vectors = db['vector_table'].execute()
+    assert isinstance(retrieved_vectors[0]['vector'], numpy.ndarray)
+
+    vector_searcher = PGVectorSearcher(
+        table='vector_table',
+        vector_column='vector',
+        primary_id='id',
+        dimensions=3,
+        measure='cosine',
+        uri=CFG.data_backend
+    )
+
+    vector_searcher.initialize()
+
+    h = numpy.random.randn(3)
+
+    result = vector_searcher.find_nearest_from_array(
+        h=h,
+        n=10
+    )
+
+    import time
+    time.sleep(2)
+
+    scores_manual = {}
+    for r in retrieved_vectors:
+        scores_manual[r['id']] = measures['cosine'](r['vector'][None, :], h[None, :]).item()
+
+    best_id = max(scores_manual, key=scores_manual.get)
+
+    vector_searcher.drop()
+
+    assert best_id == result[0][0]
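
Note: the final assertion in this test leans on the score convention of `PGVectorSearcher.find_nearest_from_array`, which turns pgvector's cosine *distance* back into a similarity via `1 - distance`. A minimal, self-contained sketch of that equivalence (pure numpy, no database required):

    import numpy

    def cosine_similarity(x: numpy.ndarray, y: numpy.ndarray) -> float:
        # Same computation as measures['cosine'] applied to a single pair.
        return float(numpy.dot(x, y) / (numpy.linalg.norm(x) * numpy.linalg.norm(y)))

    h = numpy.random.randn(3)
    v = numpy.random.randn(3)

    similarity = cosine_similarity(v, h)
    distance = 1.0 - similarity  # what pgvector's <=> operator reports
    assert abs((1.0 - distance) - similarity) < 1e-9
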
diff --git a/plugins/postgres/pyproject.toml b/plugins/postgres/pyproject.toml
new file mode 100644
index 0000000000..c7cf626951
--- /dev/null
+++ b/plugins/postgres/pyproject.toml
@@ -0,0 +1,88 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "superduper_postgres"
+readme = "README.md"
+description = """
+superduper_postgres is a plugin that lets Superduper use PostgreSQL as a data backend and, via pgvector, as a vector-search backend.
+
+With this plugin, queries built against the Superduper query API are mapped to native Postgres SQL, with additional support for complex data-types and vector-searches.
+"""
+license = {file = "LICENSE"}
+maintainers = [{name = "superduper.io, Inc.", email = "opensource@superduper.io"}]
+keywords = [
+    "databases",
+    "postgresql",
+    "data-science",
+    "machine-learning",
+    "mlops",
+    "vector-database",
+    "ai",
+]
+requires-python = ">=3.10"
+dynamic = ["version"]
+dependencies = [
+    "psycopg2",
+    "pgvector",
+]
+
+[project.optional-dependencies]
+test = [
+    # Plugin test dependencies are installed in CI
+]
+
+[project.urls]
+homepage = "https://superduper.io"
+documentation = "https://docs.superduper.io/docs/intro"
+source = "https://github.com/superduper-io/superduper"
+
+[tool.setuptools.packages.find]
+include = ["superduper_postgres*"]
+
+[tool.setuptools.dynamic]
+version = {attr = "superduper_postgres.__version__"}
+
+[tool.black]
+skip-string-normalization = true
+target-version = ["py38"]
+
+[tool.mypy]
+ignore_missing_imports = true
+no_implicit_optional = true
+warn_unused_ignores = true
+disable_error_code = ["has-type", "attr-defined", "assignment", "misc", "override", "call-arg"]
+
+[tool.pytest.ini_options]
+addopts = "-W ignore"
+
+[tool.ruff.lint]
+extend-select = [
+    "I", # Missing required import (auto-fixable)
+    "F", # PyFlakes
+    #"W", # PyCode Warning
+    "E", # PyCode Error
+    #"N", # pep8-naming
+    "D", # pydocstyle
+]
+ignore = [
+    "D100", # Missing docstring in public module
+    "D104", # Missing docstring in public package
+    "D107", # Missing docstring in __init__
+    "D105", # Missing docstring in magic method
+    "D203", # 1 blank line required before class docstring
+    "D212", # Multi-line docstring summary should start at the first line
+    "D213", # Multi-line docstring summary should start at the second line
+    "D401",
+    "E402",
+]
+
+[tool.ruff.lint.isort]
+combine-as-imports = true
+
+[tool.ruff.lint.per-file-ignores]
+"test/**" = ["D"]
+"plugin_test/**" = ["D"]
diff --git a/plugins/postgres/superduper_postgres/__init__.py b/plugins/postgres/superduper_postgres/__init__.py
new file mode 100644
index 0000000000..aba64e0cd6
--- /dev/null
+++ b/plugins/postgres/superduper_postgres/__init__.py
@@ -0,0 +1,6 @@
+from .vector_search import PGVectorSearcher as VectorSearcher
+from .data_backend import PostgresDataBackend as DataBackend
+
+__version__ = "0.8.0"
+
+__all__ = ["VectorSearcher", "DataBackend"]
\ No newline at end of file
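
Note: with the `build.py` mapping changed further down in this diff, any `postgresql://` data-backend URI is routed to this plugin; the vector-search engine is selected separately, as in the plugin test above. A minimal usage sketch (the DSN is a placeholder, and the positional-URI form of `superduper(...)` is assumed from the existing builder API):

    from superduper import superduper

    db = superduper(
        'postgresql://user:password@localhost:5432/mydb',  # placeholder DSN
        vector_search_engine='postgres',
        cluster_engine='local',
    )
    db['documents'].insert([{'x': 1}, {'x': 2}])
    print(db['documents'].execute())
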
diff --git a/plugins/postgres/superduper_postgres/data_backend.py b/plugins/postgres/superduper_postgres/data_backend.py
new file mode 100644
index 0000000000..2e9beeadc5
--- /dev/null
+++ b/plugins/postgres/superduper_postgres/data_backend.py
@@ -0,0 +1,447 @@
+import json
+import textwrap
+import time
+import traceback
+import typing as t
+import uuid
+
+import click
+import numpy as np
+import psycopg2
+from psycopg2.extras import RealDictCursor, execute_values
+
+from superduper.base.datatype import Vector
+from superduper_postgres.query import map_superduper_query_to_postgres_query
+from superduper_postgres.schema import superduper_to_postgres_schema
+from superduper import CFG, logging
+from superduper.backends.base.data_backend import BaseDataBackend
+from superduper.base.event import CreateTable
+from superduper.base.query import Query
+from superduper.base.schema import Schema
+
+# SQL statement templates used by the Postgres data backend.
+CREATE_TABLE = """
+CREATE TABLE IF NOT EXISTS "{identifier}" (
+    {schema}
+);
+"""
+
+INSERT = """
+INSERT INTO "{table}" ({columns})
+    VALUES %s
+"""
+
+REPLACE = """
+INSERT INTO "{table}" ({columns})
+VALUES %s
+ON CONFLICT ({primary_id})
+DO UPDATE SET
+    {excluded_columns}
+"""
+
+UPDATE = """
+UPDATE "{table}" SET {updates} WHERE {conditions}
+"""
+
+
+class PostgresDataBackend(BaseDataBackend):
+    """Postgres data backend."""
+
+    tables_ignore = '^POSTGRES_*'
+    json_native = True
+
+    def __init__(self, uri, plugin, flavour):
+        self.uri = uri
+        self.session = psycopg2.connect(uri)
+        self.flavour = flavour
+
+    def get_table(self, identifier: str):
+        raise NotImplementedError
+
+    def reconnect(self):
+        """Reconnect to the data backend."""
+        self.session = psycopg2.connect(self.uri)
+
+    def _run_query(self, query, commit: bool = False):
+        start = time.time()
+        logging.info(f"Executing query: {query}")
+        query = textwrap.dedent(query)
+        try:
+            with self.session.cursor() as cur:
+                cur.execute(query)
+                try:
+                    result = cur.fetchall()
+                except psycopg2.ProgrammingError as e:
+                    if 'no results to fetch' in str(e):
+                        result = []
+                    else:
+                        raise e
+            logging.info(f"Executing query... DONE ({time.time() - start:.2f}s)")
+            if commit:
+                self.session.commit()
+        except Exception as e:
+            self.session.rollback()
+            logging.error(f"Error executing query {query}: {e}")
+            logging.error(traceback.format_exc())
+            raise e
+        return result
+
+    def drop_table(self, table: str):
+        """Drop data from table.
+
+        :param table: The table to drop.
+        """
+        return self._run_query(f'DROP TABLE IF EXISTS "{table}"', commit=True)
+
+    @property
+    def db(self):
+        """Return the datalayer."""
+        return self._db
+
+    @db.setter
+    def db(self, value):
+        """Set the datalayer.
+
+        :param value: The datalayer.
+        """
+        self._db = value
+
+    def create_tables_and_schemas(self, events: t.List[CreateTable]):
+        """Create tables and schemas in the data-backend.
+
+        :param events: The events to create.
+        """
+        for ev in events:
+            self.create_table_and_schema(
+                ev.identifier, Schema.build(**ev.fields), ev.primary_id
+            )
+
+    def create_table_and_schema(self, identifier: str, schema: Schema, primary_id: str):
+        """Create a schema in the data-backend.
+
+        :param identifier: The identifier of the schema.
+        :param schema: The schema to create.
+        :param primary_id: The primary id of the schema.
+        """
+        if identifier in self.list_tables():
+            return
+        native_schema = superduper_to_postgres_schema(schema, primary_id)
+        native_schema_str = ',\n    '.join(
+            f'"{k}" {v}' for k, v in native_schema.items()
+        )
+        q = CREATE_TABLE.format(
+            identifier=identifier, primary_id=primary_id, schema=native_schema_str
+        )
+        self._run_query(q, commit=True)
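+
+    # Note (editorial sketch, not executed at runtime): with the plugin test's
+    # schema ``{'id': 'str', 'vector': 'vector[float:3]'}`` and ``primary_id='id'``,
+    # ``create_table_and_schema`` renders the CREATE_TABLE template above to
+    # roughly:
+    #
+    #   CREATE TABLE IF NOT EXISTS "vector_table" (
+    #       "id" VARCHAR(64) PRIMARY KEY,
+    #       "vector" VECTOR(3)
+    #   );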
+ """ + if not force and not click.confirm( + "Are you sure you want to drop the database?", default=False + ): + return + for table in self.list_tables(): + logging.info(f"Dropping table {table}") + self.drop_table(table) + logging.info(f"Dropping table {table}... DONE") + + def list_tables(self): + """List all tables or collections in the database.""" + sql = """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema NOT IN ('pg_catalog', 'information_schema') + AND table_type IN ('BASE TABLE', 'VIEW') + ORDER BY table_name; + """ + with self.session.cursor() as cur: + cur.execute(sql) + results = cur.fetchall() + return [r[0] for r in results] + + ######################################################## + # Abstract methods/ optional methods to be implemented # + ######################################################## + + def random_id(self): + """Generate a random id.""" + return str(uuid.uuid4().hex)[:16] + + def _fill_primary_id(self, raw_documents, primary_id): + ids = [] + for r in raw_documents: + if primary_id not in r: + r[primary_id] = self.random_id() + ids.append(r[primary_id]) + return ids + + def _get_columns(self, table_name: str): + """Get the columns of a table. + + :param table_name: The name of the table. + """ + return [r[0] for r in self._run_query(f""" + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = '{table_name}'; + """)] + + def _handle_json_columns(self, raw_documents, schema): + json_fields = [ + f for f in schema.fields if getattr(schema[f], 'dtype', None) == 'json' + ] + for f in json_fields: + for r in raw_documents: + if f not in r: + continue + r[f] = json.dumps(r[f]) + + def _handle_vector_columns(self, raw_documents, schema): + vector_fields = [ + f for f in schema.fields if isinstance(schema[f], Vector) + ] + assert len(vector_fields) <= 1, "Only one vector field is supported per table." + field = vector_fields[0] if vector_fields else None + if field is None: + return + + for r in raw_documents: + if field not in r: + continue + if not isinstance(r[field], list): + assert isinstance(r[field], np.ndarray) + assert len(r[field].shape) == 1, "Vector field must be a 1D array." + r[field] = r[field].tolist() + r[field] = f'{r[field]}' + + def insert(self, table_name, raw_documents, primary_id: str | None = None): + """Insert data into the database. + + :param table: The table to insert into. + :param raw_documents: The (encoded) documents to insert. + """ + if primary_id is None: + primary_id = self.db.metadata.get_primary_id(table_name) + + if len(raw_documents) == 0: + return [] + + ids = self._fill_primary_id(raw_documents, primary_id) + cols = self._get_columns(table_name) + def get_row(row): + return [row[col] if col in row else None for col in cols] + s = self.db.metadata.get_schema(table_name) + self._handle_vector_columns(raw_documents, schema=s) + self._handle_json_columns(raw_documents, schema=s) + values = [get_row(r) for r in raw_documents] + cols = ', '.join(f'"{c}"' for c in cols) + with self.session.cursor() as cur: + sql = INSERT.format(table=table_name, columns=cols) + try: + execute_values( + cur, sql, values + ) + except Exception as e: + logging.error(f"Error executing query: {sql} {e}") + logging.error(traceback.format_exc()) + self.session.rollback() + raise e + self.session.commit() + return ids + + def replace(self, table: str, condition: t.Dict, r: t.Dict) -> t.List[str]: + """Replace data. + + :param table: The table to insert into. 
+
+    def replace(self, table: str, condition: t.Dict, r: t.Dict) -> t.List[str]:
+        """Replace data.
+
+        :param table: The table to replace rows in.
+        :param condition: The condition to match.
+        :param r: The document to replace.
+        """
+        sql = """
+        INSERT INTO "{table}" ({primary_id}, {columns})
+        VALUES ({placeholders})
+        ON CONFLICT ({primary_id})
+        DO UPDATE SET
+            {excluded_columns}
+        """
+
+        cols = self.db.metadata.get_columns(table)
+        primary_id = self.db.metadata.get_primary_id(table)
+        cols = [c for c in cols if c != primary_id]
+        excluded_columns = ',\n'.join([
+            f'    "{c}" = EXCLUDED."{c}"'
+            for c in cols
+        ])
+        sql = sql.format(
+            table=table,
+            columns=', '.join(f'"{c}"' for c in cols),
+            primary_id=primary_id,
+            placeholders=', '.join(['%s'] * (len(cols) + 1)),
+            excluded_columns=excluded_columns,
+        )
+
+        if condition:
+            sql += " WHERE "
+            clause = []
+            for k, v in condition.items():
+                if isinstance(v, str):
+                    v = f"'{v}'"
+                elif isinstance(v, (list, tuple)):
+                    raise NotImplementedError(
+                        "List or tuple conditions are not supported."
+                    )
+                clause.append(f'"{k}" = {v}')
+            sql += " AND ".join(clause)
+
+        try:
+            with self.session.cursor() as cur:
+                cur.execute(sql, (r[primary_id], *[r[c] for c in cols]))
+        except Exception as e:
+            logging.error(f"Error executing query: {e}")
+            logging.error(traceback.format_exc())
+            self.session.rollback()
+            raise e
+
+        self.session.commit()
+
+    def update(self, table: str, condition: t.Dict, key: str, value: t.Any):
+        """Update data in the database.
+
+        :param table: The table to update.
+        :param condition: The condition to update.
+        :param key: The key to update.
+        :param value: The value to update.
+        """
+        conditions = []
+        for k, v in condition.items():
+            if isinstance(v, str):
+                v = f"'{v}'"
+            elif isinstance(v, (list, tuple)):
+                raise NotImplementedError("List or tuple conditions are not supported.")
+            conditions.append(f'"{k}" = {v}')
+
+        conditions = " AND ".join(conditions)
+        sql = f"UPDATE \"{table}\" SET \"{key}\" = %s WHERE {conditions}"
+        schema = self.db.metadata.get_schema(table)
+        with self.session.cursor() as cur:
+            logging.info(f"Executing query: {sql} with value: {value}")
+            if getattr(schema[key], 'dtype', None) == 'json':
+                value = json.dumps(value)
+            try:
+                cur.execute(sql, (value,))
+            except Exception as e:
+                logging.error(f"Error executing query: {e}")
+                logging.error(traceback.format_exc())
+                self.session.rollback()
+                raise e
+        self.session.commit()
+
+    def delete(self, table: str, condition: t.Dict):
+        """Delete data from the database.
+
+        :param table: The table to delete from.
+        :param condition: The condition to delete by.
+        """
+        terms = []
+        for k, v in condition.items():
+            if isinstance(v, str):
+                v = f"'{v}'"
+            terms.append(f'"{k}" = {v}')
+        condition = " AND ".join(terms)
+        q = f'DELETE FROM "{table}" WHERE {condition}'
+        logging.info(f"Executing query: {q}")
+        self._run_query(q, commit=True)
+
+    def missing_outputs(self, query, predict_id):
+        """Get missing outputs.
+
+        :param query: The query to get the missing outputs of.
+        :param predict_id: The identifier of the output destination.
+        """
+        pid = self.primary_id(query.table)
+
+        sql = f"""
+        SELECT {query.table}.{pid}
+        FROM {query.table}
+        LEFT JOIN "{CFG.output_prefix}{predict_id}" ON {query.table}.{pid} = {CFG.output_prefix}{predict_id}._source
+        WHERE {CFG.output_prefix}{predict_id}._source IS NULL
+        """
+
+        filter = query.decomposition.filter
+        if filter:
+            raise NotImplementedError
+
+        return [r[0] for r in self._run_query(sql)]
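+
+    # Note (editorial sketch of the statements emitted by the methods above,
+    # with placeholder table/column names):
+    #
+    #   update('documents', {'id': 'abc'}, 'status', 'done')
+    #     -> UPDATE "documents" SET "status" = %s WHERE "id" = 'abc'   (value bound via %s)
+    #
+    #   delete('documents', {'id': 'abc'})
+    #     -> DELETE FROM "documents" WHERE "id" = 'abc'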
+ """ + return self.get_table(query.table).schema + + def select(self, query: Query, primary_id: str | None = None) -> t.List[t.Dict]: + """Select data from the database. + + :param query: The query to perform. + """ + q = map_superduper_query_to_postgres_query(query) + cols = list(query.decomposition.columns) + for i, c in enumerate(cols): + if isinstance(c, Query): + cols[i] = c.execute() + try: + results = self._run_query(q) + except Exception as e: + print(traceback.format_exc()) + raise e + + results = [{col: v for col, v in zip(cols, result)} for result in results] + output_schema = query.decomposition.schema + vector_datatype = next( + (k for k in output_schema.fields.keys() if isinstance(output_schema[k], Vector)), None + ) + if vector_datatype: + for r in results: + if vector_datatype in r: + r[vector_datatype] = eval(r[vector_datatype]) + return results + + def execute_native(self, query: str): + """Execute a native query. + + :param query: The query to execute. + """ + results = self._run_query(query) + out = [] + for r in results: + out.append(r.as_dict()) + return out + diff --git a/plugins/postgres/superduper_postgres/query.py b/plugins/postgres/superduper_postgres/query.py new file mode 100644 index 0000000000..a78502b741 --- /dev/null +++ b/plugins/postgres/superduper_postgres/query.py @@ -0,0 +1,85 @@ +from superduper.base.query import Query +from superduper import logging +from superduper import CFG + + +OP_LOOKUP = { + '==': '=', + '!=': '<>', + '<': '<', + '<=': '<=', + '>': '>', + '>=': '>=', + 'in': 'IN', +} + + +def map_superduper_query_to_postgres_query(query: Query): + logging.info(f'Mapping SuperDuper query {query} to Postgres query') + + assert query.type == 'select' + + d = query.decomposition + + cols = list(d.select.args if d.select else query.decomposition.columns) + for i, c in enumerate(cols): + if isinstance(c, Query): + cols[i] = c.execute() + else: + cols[i] = f'"{c}"' + + pid = query.primary_id.execute() + + if d.outputs: + predict_ids = d.outputs.args + for i, col in enumerate(cols): + try: + next(predict_id for predict_id in predict_ids if predict_id in col) + cols[i] = f'{col}.{col}' + except StopIteration: + cols[i] = f'{query.table}.{col}' + + cols = ', '.join(cols) + + output = f'SELECT {cols} FROM "{d.table}"' + if d.outputs: + for predict_id in d.outputs.args: + output_t = f'{CFG.output_prefix}{predict_id}' + output += f' INNER JOIN {output_t} ON {d.table}.{pid} = {output_t}._source \n' + + filter_str = '' + if d.filter: + filters = [f.parts for f in d.filter.args] + filter_parts = [] + for col, f in filters: + if f.symbol in OP_LOOKUP: + value = f.args[0] + if isinstance(value, str): + value = f"'{value}'" + + if isinstance(value, Query): + value = f'"{value.execute()}"' + + if f.symbol == 'in': + if col == 'primary_id': + col = query.primary_id.execute() + assert all(isinstance(v, str) for v in value), "All values in 'in' operator must be strings" + value = '(' + ','.join([f"'{v}'" for v in value]) + ')' # Assuming value is a list for 'in' operator + + filter_parts.append( + f"\"{col}\" {OP_LOOKUP[f.symbol]} {value}" + ) + + filter_str = ' WHERE ' + ' AND '.join(filter_parts) + + output += filter_str + + if d.limit: + output += f' LIMIT {d.limit.args[0]}' + + if d.limit and 'offset' in d.limit.kwargs: + output += f' OFFSET {d.limit.kwargs["offset"]}' + + logging.info(f'Mapped SuperDuper to Postgres query: {output}') + + return output \ No newline at end of file diff --git a/plugins/postgres/superduper_postgres/schema.py 
diff --git a/plugins/postgres/superduper_postgres/schema.py b/plugins/postgres/superduper_postgres/schema.py
new file mode 100644
index 0000000000..bd2cf29cd8
--- /dev/null
+++ b/plugins/postgres/superduper_postgres/schema.py
@@ -0,0 +1,34 @@
+from superduper.base.datatype import Vector
+from superduper.base.schema import Schema
+
+
+def superduper_to_postgres_schema(schema: Schema, primary_id: str = 'id') -> dict:
+    """Convert a SuperDuper schema to a Postgres schema.
+
+    :param schema: The SuperDuper schema.
+    :param primary_id: The name of the primary-id column.
+    """
+    out = {}
+    out[primary_id] = 'VARCHAR(64) PRIMARY KEY'
+    for f in schema.fields:
+        if f == primary_id:
+            continue
+        if str(schema[f]).lower() == 'str':
+            out[f] = 'TEXT'
+        elif str(schema[f]).lower() == 'int':
+            out[f] = 'INT'
+        elif str(schema[f]).lower() == 'float':
+            out[f] = 'FLOAT'
+        elif str(schema[f]).lower() == 'json':
+            out[f] = 'JSONB'
+        elif str(schema[f]).lower() == 'bool':
+            out[f] = 'BOOLEAN'
+        elif isinstance(schema[f], Vector):
+            out[f] = f'VECTOR({schema[f].shape})'
+        else:
+            if schema[f].dtype == 'str':
+                out[f] = 'TEXT'
+            elif schema[f].dtype == 'json':
+                out[f] = 'JSONB'
+            else:
+                raise ValueError(f"Unsupported field type: {schema[f].dtype} for field {f}")
+    return out
\ No newline at end of file
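
Note: an illustrative input/output pair for the conversion above (field names are invented for the example):

    # For a superduper schema declaring
    #   {'txt': 'str', 'count': 'int', 'meta': 'json', 'vector': 'vector[float:3]'}
    # with primary_id='id', the expected result is roughly:
    expected = {
        'id': 'VARCHAR(64) PRIMARY KEY',
        'txt': 'TEXT',
        'count': 'INT',
        'meta': 'JSONB',
        'vector': 'VECTOR(3)',
    }
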
diff --git a/plugins/postgres/superduper_postgres/vector_search.py b/plugins/postgres/superduper_postgres/vector_search.py
new file mode 100644
index 0000000000..ff886a3c4e
--- /dev/null
+++ b/plugins/postgres/superduper_postgres/vector_search.py
@@ -0,0 +1,155 @@
+import numpy
+import typing as t
+
+import psycopg2
+from pgvector.psycopg2 import register_vector
+from pgvector import Vector
+
+from superduper.backends.base.vector_search import BaseVectorSearcher, VectorItem
+from superduper import VectorIndex
+
+
+class PGVectorSearcher(BaseVectorSearcher):
+    """Vector searcher based on the Postgres ``pgvector`` extension.
+
+    :param table: The table containing the vectors.
+    :param vector_column: The column containing the vectors.
+    :param primary_id: The primary-id column of the table.
+    :param dimensions: The dimensionality of the vectors.
+    :param measure: The measure to use for comparing vectors.
+    :param uri: The Postgres connection URI.
+    """
+
+    def __init__(
+        self,
+        table: str,
+        vector_column: str,
+        primary_id: str,
+        dimensions: int,
+        measure: str,
+        uri: str,
+    ):
+        self.conn = psycopg2.connect(uri)
+        register_vector(self.conn)
+        self.table = table
+        self.vector_column = vector_column
+        self.dimensions = dimensions
+        self.measure = measure
+        self.primary_id = primary_id
+
+    def drop(self):
+        """Drop the vector index."""
+        self.conn.close()
+
+    def initialize(self):
+        """Initialize the vector-searcher."""
+        cur = self.conn.cursor()
+        # check that this is a vector table
+        try:
+            cur.execute(f"""
+                SELECT *
+                FROM information_schema.columns
+                WHERE table_name = '{self.table}' AND column_name = '{self.vector_column}'
+            """)
+        except Exception as e:
+            self.conn.rollback()
+            raise e
+        self.conn.commit()
+        if not cur.fetchone():
+            raise ValueError(f"Table {self.table} is not a vector table")
+
+    def add(self, items: t.Sequence['VectorItem']) -> None:
+        """
+        Add items to the index.
+
+        :param items: t.Sequence of VectorItems
+        """
+        # Vectors live in the underlying table, so there is nothing to add here.
+        return
+
+    def delete(self, ids: t.Sequence[str]) -> None:
+        """Remove items from the index.
+
+        :param ids: t.Sequence of ids of vectors.
+        """
+        # Deletion is handled by the data backend; the index needs no bookkeeping.
+        return
+
+    def find_nearest_from_array(
+        self,
+        h: numpy.typing.ArrayLike,
+        n: int = 100,
+        within_ids: t.Sequence[str] = (),
+    ) -> t.Tuple[t.List[str], t.List[float]]:
+        """
+        Find the nearest vectors to the given vector.
+
+        :param h: vector
+        :param n: number of nearest vectors to return
+        :param within_ids: list of ids to search within
+        """
+        # use pg_vector to find nearest vectors
+        operator = {
+            'l2': '<->',
+            'css': '<=>',
+            'cosine': '<=>',
+            'dot': '<#>',
+        }[self.measure]
+
+        if within_ids:
+            query = f"""
+                SELECT {self.primary_id}, {self.vector_column} {operator} %s AS score
+                FROM {self.table}
+                WHERE {self.primary_id} = ANY(%s)
+                ORDER BY score
+                LIMIT %s
+            """
+        else:
+            query = f"""
+                SELECT {self.primary_id}, {self.vector_column} {operator} %s AS score
+                FROM {self.table}
+                ORDER BY score
+                LIMIT %s
+            """
+        with self.conn.cursor() as cur:
+            if isinstance(h, numpy.ndarray):
+                h = h.tolist()
+            try:
+                if within_ids:
+                    cur.execute(query, (Vector(h), list(within_ids), n))
+                else:
+                    cur.execute(query, (Vector(h), n))
+                results = cur.fetchall()
+            except Exception as e:
+                self.conn.rollback()
+                raise e
+        self.conn.commit()
+        # pgvector returns distances; convert them to similarity-style scores.
+        return [r[0] for r in results], [1 - r[1] for r in results]
+
+    def find_nearest_from_id(
+        self,
+        id: str,
+        n: int = 100,
+        within_ids: t.Sequence[str] = (),
+    ) -> t.Tuple[t.List[str], t.List[float]]:
+        """
+        Find the nearest vectors to the vector stored under the given id.
+
+        :param id: id of the vector to search with
+        :param n: number of nearest vectors to return
+        :param within_ids: list of ids to search within
+        """
+        with self.conn.cursor() as cur:
+            cur.execute(
+                f'SELECT {self.vector_column} FROM {self.table} '
+                f'WHERE {self.primary_id} = %s',
+                (id,),
+            )
+            row = cur.fetchone()
+        if row is None:
+            raise ValueError(f'No vector found for id {id}')
+        # ``register_vector`` makes the vector column come back as a numpy array.
+        return self.find_nearest_from_array(row[0], n=n, within_ids=within_ids)
+
+    def __len__(self):
+        with self.conn.cursor() as cur:
+            cur.execute(f"SELECT COUNT(*) FROM {self.table}")
+            return cur.fetchone()[0]
+
+    @classmethod
+    def from_component(cls, vi: VectorIndex):
+        """Create a PGVectorSearcher from component and vector index."""
+        output_table = vi.db.load(vi.indexing_listener.outputs)
+        pid = output_table.primary_id
+        from superduper import CFG
+        return cls(
+            table=vi.indexing_listener.outputs,
+            vector_column=vi.indexing_listener.key,
+            primary_id=pid,
+            dimensions=vi.dimensions,
+            measure=vi.measure,
+            uri=CFG.data_backend,
+        )
\ No newline at end of file
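
Note: a standalone usage sketch of the searcher above; it assumes a reachable Postgres instance with the pgvector extension and a populated `vector_table` like the one in the plugin test (the DSN is a placeholder):

    import numpy
    from superduper_postgres.vector_search import PGVectorSearcher

    searcher = PGVectorSearcher(
        table='vector_table',
        vector_column='vector',
        primary_id='id',
        dimensions=3,
        measure='cosine',
        uri='postgresql://user:password@localhost:5432/mydb',  # placeholder DSN
    )
    searcher.initialize()
    ids, scores = searcher.find_nearest_from_array(numpy.random.randn(3), n=5)
    searcher.drop()
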
diff --git a/superduper/backends/base/vector_search.py b/superduper/backends/base/vector_search.py
index 1ceca12e91..dbce80a659 100644
--- a/superduper/backends/base/vector_search.py
+++ b/superduper/backends/base/vector_search.py
@@ -310,6 +310,7 @@ def cosine(x, y):
     x = x / numpy.linalg.norm(x, axis=1)[:, None]
     # y which implies all vectors in vectordatabase
     # has normalized vectors.
+    y = y / numpy.linalg.norm(y, axis=1)[:, None]
     return dot(x, y)
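
Note: the one-line change above matters because a plain dot product is only a cosine similarity when both operands are unit-normalized; a small self-contained check:

    import numpy

    x = numpy.array([[3.0, 0.0, 0.0]])
    y = numpy.array([[0.0, 4.0, 0.0], [5.0, 0.0, 0.0]])

    x = x / numpy.linalg.norm(x, axis=1)[:, None]
    y = y / numpy.linalg.norm(y, axis=1)[:, None]

    # [[0., 1.]]: orthogonal vectors score 0, parallel vectors score 1,
    # regardless of their original magnitudes.
    print(x @ y.T)
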
diff --git a/superduper/base/apply.py b/superduper/base/apply.py
index c7f273264d..f7aca82fa7 100644
--- a/superduper/base/apply.py
+++ b/superduper/base/apply.py
@@ -361,15 +361,14 @@ def wrapper(child):
             jobs=list(job_events.values()),
             context=context,
         )
-        for service in object.services:
-            put_events[f'{object.huuid}/{service}'] = PutComponent(
-                component=object.component,
-                identifier=object.identifier,
-                uuid=object.uuid,
-                context=context,
-                version=object.version,
-                service=service,
-            )
+        put_events[object.huuid] = PutComponent(
+            component=object.component,
+            identifier=object.identifier,
+            uuid=object.uuid,
+            context=context,
+            version=object.version,
+            services=object.services,
+        )
 
     elif apply_status == 'breaking':
         metadata_event = Create(
@@ -386,15 +385,14 @@ def wrapper(child):
             jobs=list(job_events.values()),
             context=context,
         )
-        for service in object.services:
-            put_events[f'{object.huuid}/{service}'] = PutComponent(
-                component=object.component,
-                identifier=object.identifier,
-                uuid=object.uuid,
-                context=context,
-                version=object.version,
-                service=service,
-            )
+        put_events[object.huuid] = PutComponent(
+            component=object.component,
+            identifier=object.identifier,
+            uuid=object.uuid,
+            context=context,
+            version=object.version,
+            services=object.services,
+        )
 
         d = db['Deployment']
         assert deprecated_context is not None
diff --git a/superduper/base/build.py b/superduper/base/build.py
index 7c003d8a0f..40af08af43 100644
--- a/superduper/base/build.py
+++ b/superduper/base/build.py
@@ -56,7 +56,7 @@ class _DataBackendLoader(_Loader):
         r'^mongodb\+srv:\/\/': ('mongodb', 'atlas'),
         r'^mongomock:\/\/': ('mongodb', 'mongomock'),
         r'^sqlite://': ('sql', 'base'),
-        r'^postgresql://': ('sql', 'base'),
+        r'^postgresql://': ('postgres', 'base'),
         r'^snowflake:\/\/': ('snowflake', 'base'),
         r'^duckdb://': ('sql', 'base'),
         r'^mssql://': ('sql', 'base'),
diff --git a/superduper/base/datalayer.py b/superduper/base/datalayer.py
index db0152e7cc..2382a0cb1c 100644
--- a/superduper/base/datalayer.py
+++ b/superduper/base/datalayer.py
@@ -651,12 +651,14 @@ def load(
         else:
             if component_cache and (component, identifier) in self._component_cache:
                 assert isinstance(identifier, str)
-                if self._component_cache[
+                cache_uuid = self._component_cache[
                     (component, identifier)
-                ].uuid == self.metadata.get_latest_uuid(
+                ].uuid
+                latest_uuid = self.metadata.get_latest_uuid(
                     component=component,
                     identifier=identifier,
-                ):
+                )
+                if cache_uuid == latest_uuid:
                     logging.debug(f'Found {component, identifier} in cache...')
                     return self._component_cache[(component, identifier)]
                 else:
diff --git a/superduper/base/event.py b/superduper/base/event.py
index 78356e8152..6220b49dbb 100644
--- a/superduper/base/event.py
+++ b/superduper/base/event.py
@@ -239,14 +239,12 @@ def execute(self, db: 'Datalayer'):
             )
 
         except Exception as e:
-            db.metadata.set_component_status(
+            db.metadata.set_component_failed(
                 component=self.component,
                 uuid=self.data['uuid'],
-                details_update={
-                    'phase': STATUS_FAILED,
-                    'reason': f'Failed to create: {str(e)}',
-                    'message': format_exc(),
-                },
+                reason=f'Failed to create: {str(e)}',
+                message=str(format_exc()),
+                context=self.context,
             )
             raise e
 
@@ -265,7 +263,7 @@ class PutComponent(Event):
     :param identifier: the identifier of the component to be created
     :param version: the version of the component to be created
    :param uuid: the uuid of the component to be created
-    :param service: the service to put the component on
+    :param services: the services to put the component on
     """
 
     queue: t.ClassVar[str] = '_apply'
@@ -276,7 +274,7 @@ class PutComponent(Event):
     identifier: str
     version: int
     uuid: str
-    service: str
+    services: t.List[str]
 
     @property
     def cls(self):
@@ -285,7 +283,7 @@ def cls(self):
 
     @property
     def huuid(self):
-        return f'{self.component}:{self.identifier}:{self.uuid}/{self.service}'
+        return f'{self.component}:{self.identifier}:{self.uuid}'
 
     @classmethod
     def batch_execute(
@@ -363,19 +361,20 @@ def execute(self, db: 'Datalayer'):
             context=self.context,
         )
         db.metadata.create_deployment(deployment)
-        logging.info(
-            f'Putting {self.component}:'
-            f'{self.identifier}:{self.uuid} on {self.service}'
-        )
-        getattr(db.cluster, self.service).put_component(
-            component=self.component,
-            uuid=self.uuid,
-        )
-        logging.info(
-            f'Putting {self.component}:'
-            f'{self.identifier}:{self.uuid} on {self.service}'
-            '... DONE'
-        )
+        for service in self.services:
+            logging.info(
+                f'Putting {self.component}:'
+                f'{self.identifier}:{self.uuid} on {service}'
+            )
+            getattr(db.cluster, service).put_component(
+                component=self.component,
+                uuid=self.uuid,
+            )
+            logging.info(
+                f'Putting {self.component}:'
+                f'{self.identifier}:{self.uuid} on {service}'
+                '... DONE'
+            )
 
 
 class Teardown(Event):
diff --git a/superduper/base/metadata.py b/superduper/base/metadata.py
index 4cad2927f6..f102bddee3 100644
--- a/superduper/base/metadata.py
+++ b/superduper/base/metadata.py
@@ -418,7 +418,7 @@ class ArtifactRelations(Base):
     :param artifact_id: UUID of component version
     """
 
-    primary_id: t.ClassVar[str] = 'relation'
+    primary_id: t.ClassVar[str] = 'relation_id'
     relation_id: str
     component: str
     identifier: str
@@ -449,7 +449,7 @@ def __init__(self, db: 'Datalayer', parent_db: 'Datalayer'):
         self.parent_db = parent_db
         self.primary_ids = {
             "Table": "uuid",
-            "ParentChildAssociations": "uuid",
+            "ParentChildAssociations": "id",
             "ArtifactRelations": "relation_id",
             "Job": "job_id",
         }
@@ -478,7 +478,6 @@ def init(self):
                 path='superduper.components.table.Table',
             ).encode()
             r['version'] = 0
-
             self.db.databackend.do_insert('Table', [r], raw=True)
 
             r = self.get_component('Table', 'Table')
@@ -528,6 +527,20 @@ def check_table_in_metadata(self, table: str):
 
         return table in self.db.databackend.list_tables()
 
+    def get_columns(self, table: str):
+        """Get the columns of a table.
+
+        :param table: table name.
+        """
+        if table in metaclasses:
+            out = list(metaclasses[table].class_schema.fields.keys())
+        else:
+            r = self.get_component('Table', table)
+            out = list(r['fields'].keys())
+        pid = self.get_primary_id(table)
+        out = [pid] + out
+        return out
+
     def get_primary_id(self, table: str):
         """Get the primary id of a table.
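
Note: a rough sketch of how the new `get_columns` helper behaves (the table and fields are illustrative; the exact field ordering follows the stored `Table` component):

    from superduper import superduper

    db = superduper('postgresql://user:password@localhost:5432/mydb')  # placeholder DSN

    # For a table registered as
    #   Table('documents', fields={'x': 'int', 'y': 'str'}, primary_id='id')
    # the helper returns the primary id first, then the declared fields:
    cols = db.metadata.get_columns('documents')
    print(cols)  # expected, roughly: ['id', 'x', 'y']
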
diff --git a/superduper/base/query.py b/superduper/base/query.py
index 4f34bc5002..337c559c05 100644
--- a/superduper/base/query.py
+++ b/superduper/base/query.py
@@ -22,6 +22,7 @@
 from superduper.base.constant import KEY_BLOBS, KEY_BUILDS, KEY_FILES, KEY_PATH
 from superduper.base.datatype import BaseDataType
 from superduper.base.document import Document, _unpack
+from superduper.base.schema import Schema
 
 if t.TYPE_CHECKING:
     from superduper.base.datalayer import Datalayer
@@ -86,12 +87,45 @@ class Decomposition:
     outputs: QueryPart | None = None
     op: Op | None = None
 
+    def __post_init__(self):
+        """Post-initialization."""
+        self._columns = None
+        self._schema = None
+
     @property
     def predict_ids(self):
         if self.outputs:
             return self.outputs.args
         return []
 
+    @property
+    def schema(self):
+        """Get the schema of the query."""
+        base_schema = self.db.metadata.get_schema(self.table)
+        if self.outputs:
+            for predict_id in self.outputs.args:
+                base_schema += self.db.metadata.get_schema(CFG.output_prefix + predict_id)
+        cols = self.columns
+        new_schema = {}
+        for c in cols:
+            if str(c) not in base_schema.fields:
+                continue
+            new_schema[c] = base_schema[c]
+        return Schema(new_schema)
+
+    @property
+    def columns(self):
+        if self._columns is not None:
+            return self._columns
+        if self.select:
+            cols = self.select.args
+        else:
+            cols = self.db.metadata.get_columns(self.table)
+        if self.outputs:
+            cols.extend([CFG.output_prefix + x for x in self.outputs.args])
+        self._columns = cols
+        return cols
+
     def to_query(self):
         """Convert decomposition back to a ``Query``."""
         if self.db is None:
diff --git a/test/integration/modules/test_query.py b/test/integration/modules/test_query.py
index f9b0e33385..5d52edcd56 100644
--- a/test/integration/modules/test_query.py
+++ b/test/integration/modules/test_query.py
@@ -81,6 +81,27 @@ def test(db):
     for r in results:
         assert r[list.outputs] == r['number'] + 1
 
+    vector_table = Table(
+        'vector_table',
+        fields={
+            'id': 'str',
+            'vector': 'vector[float:3]',
+        },
+        primary_id='id',
+    )
+
+    db.apply(vector_table)
+
+    import numpy
+    db['vector_table'].insert(
+        [
+            {'vector': numpy.random.rand(3)}
+            for _ in range(10)
+        ]
+    )
+    retrieved_vectors = db['vector_table'].execute()
+    assert isinstance(retrieved_vectors[0]['vector'], numpy.ndarray)
+
     db.databackend.drop_table('documents')
     assert 'documents' not in db.databackend.list_tables()
 
@@ -88,3 +109,4 @@ def test(db):
     db.databackend.drop(force=True)
 
     assert not db.databackend.list_tables()
+