Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions plugins/postgres/plugin_test/test_pgvector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy
import pytest

from superduper.components.table import Table
from superduper import superduper, CFG
from superduper_postgres.vector_search import PGVectorSearcher
from superduper.backends.base.vector_search import measures


@pytest.fixture
def db():
    """Provide a superduper datalayer for the test and drop it afterwards.

    The datalayer is created with a local cluster engine and the postgres
    vector-search engine, without initializing the cluster.
    """
    datalayer = superduper(
        cluster_engine='local',
        vector_search_engine='postgres',
        initialize_cluster=False,
    )
    yield datalayer
    # Teardown: runs once the test using this fixture has finished.
    datalayer.drop(True, True)


def test_pgvector(db):
    """Check that PGVectorSearcher returns the manually computed best match first.

    Ten random 3-dimensional vectors are inserted into ``vector_table``. A
    random query vector is then searched with ``PGVectorSearcher`` and the
    first id in its result is compared with the id that has the highest score
    according to ``measures['cosine']``.

    :param db: Datalayer provided by the ``db`` fixture.
    """
    vector_table = Table(
        'vector_table',
        fields={
            'id': 'str',
            'vector': 'vector[float:3]',
        },
        primary_id='id',
    )

    db.apply(vector_table, force=True)

    # NOTE(review): rows are inserted without an 'id'; presumably the backend
    # generates the primary id, since r['id'] is read back below -- confirm.
    db['vector_table'].insert(
        [{'vector': numpy.random.randn(3)} for _ in range(10)]
    )
    retrieved_vectors = db['vector_table'].execute()
    assert isinstance(retrieved_vectors[0]['vector'], numpy.ndarray)

    vector_searcher = PGVectorSearcher(
        table='vector_table',
        vector_column='vector',
        primary_id='id',
        dimensions=3,
        measure='cosine',
        uri=CFG.data_backend,
    )

    vector_searcher.initialize()

    # Everything between initialize() and drop() is guarded so the searcher is
    # dropped even when the search or the manual scoring raises.
    try:
        h = numpy.random.randn(3)

        result = vector_searcher.find_nearest_from_array(h=h, n=10)

        # NOTE(review): this pause runs after the search call has already
        # returned -- confirm whether it is still needed or was meant to
        # precede the search.
        import time
        time.sleep(2)

        # Score every stored vector against the query with the reference
        # cosine measure; inputs are reshaped to 2-D rows via [None, :].
        scores_manual = {
            r['id']: measures['cosine'](r['vector'][None, :], h[None, :]).item()
            for r in retrieved_vectors
        }

        best_id = max(scores_manual, key=scores_manual.get)
    finally:
        vector_searcher.drop()

    assert best_id == result[0][0]
88 changes: 88 additions & 0 deletions plugins/postgres/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "superduper_postgres"
readme = "README.md"
description = """
Superduper postgres is a plugin for the superduper framework that allows you to use PostgreSQL as a data backend and, via pgvector, as a vector-search engine.

This plugin cannot be used independently; it must be used together with the core `superduper` framework.

Superduper supports PostgreSQL via this plugin, with additional support for complex data-types and vector-searches through pgvector.
"""
license = {file = "LICENSE"}
maintainers = [{name = "superduper.io, Inc.", email = "[email protected]"}]
keywords = [
"databases",
"postgres",
"data-science",
"machine-learning",
"mlops",
"vector-database",
"ai",
]
requires-python = ">=3.10"
dynamic = ["version"]
dependencies = [
"psycopg2",
"pgvector",
]

[project.optional-dependencies]
test = [
# Annotation plugin dependencies will be installed in CI
]

[project.urls]
homepage = "https://superduper.io"
documentation = "https://docs.superduper.io/docs/intro"
source = "https://github.com/superduper-io/superduper"

[tool.setuptools.packages.find]
include = ["superduper_postgres*"]

[tool.setuptools.dynamic]
version = {attr = "superduper_postgres.__version__"}

[tool.black]
skip-string-normalization = true
# NOTE(review): "py38" is older than the `requires-python = ">=3.10"` declared
# under [project] -- confirm whether this should be "py310".
target-version = ["py38"]

[tool.mypy]
ignore_missing_imports = true
no_implicit_optional = true
warn_unused_ignores = true
disable_error_code = ["has-type", "attr-defined", "assignment", "misc", "override", "call-arg"]

[tool.pytest.ini_options]
addopts = "-W ignore"

[tool.ruff.lint]
extend-select = [
"I", # isort: import sorting (auto-fixable)
"F", # Pyflakes
#"W", # pycodestyle warnings
"E", # pycodestyle errors
#"N", # pep8-naming
"D", # pydocstyle
]
ignore = [
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D107", # Missing docstring in __init__
"D105", # Missing docstring in magic method
"D203", # 1 blank line required before class docstring
"D212", # Multi-line docstring summary should start at the first line
"D213", # Multi-line docstring summary should start at the second line
"D401", # First line of docstring should be in imperative mood
"E402", # Module level import not at top of file
]

[tool.ruff.lint.isort]
combine-as-imports = true

[tool.ruff.lint.per-file-ignores]
"test/**" = ["D"]
"plugin_test/**" = ["D"]
6 changes: 6 additions & 0 deletions plugins/postgres/superduper_postgres/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Re-export the plugin's implementations under the generic names
# `VectorSearcher` and `DataBackend`.
# NOTE(review): presumably superduper looks these names up when it loads the
# plugin -- confirm against the plugin loader.
from .vector_search import PGVectorSearcher as VectorSearcher
from .data_backend import PostgresDataBackend as DataBackend

# Read by setuptools as the package version
# (`[tool.setuptools.dynamic]` in pyproject.toml points at this attribute).
__version__ = "0.8.0"

# Public API of the package.
__all__ = ["VectorSearcher", "DataBackend"]
Loading
Loading