Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backend/audio_transcription/makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
test_all: integration unit

format:
uv run ruff format

unit:
uv run pytest test/unit/

Expand Down
5 changes: 5 additions & 0 deletions backend/audio_transcription/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,8 @@ unannotated-parameter = "error"
unannotated-return = "error"
unannotated-attribute = "error"
implicit-any = "error"

[tool.coverage.run]
omit = [
"src/gen/*",
]
4 changes: 2 additions & 2 deletions backend/audio_transcription/src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _collect_process_metrics() -> None:

def handle_shutdown(
server: grpc.Server, # pyrefly: ignore
_sig: int,
_sig: int,
_frame: types.FrameType | None,
) -> None:
server.stop(grace=5)
Expand Down Expand Up @@ -64,7 +64,7 @@ def serve() -> None:
server = grpc.server( # pyrefly: ignore
ThreadPoolExecutor(max_workers=settings.max_workers)
)
db_pool = create_connection_pool() #
db_pool = create_connection_pool() #
transcription_handler = TranscriptionHandler()
cache = Valkey(host=settings.VALKEY_HOST, port=settings.VALKEY_PORT)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Generator
from psycopg_pool import ConnectionPool
from pgvector.psycopg import register_vector
from testcontainers.postgres import PostgresContainer
import pytest
import psycopg
import numpy as np


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -74,26 +75,56 @@ def db_connection(


@pytest.fixture
def db_pool(
postgres_container: PostgresContainer,
) -> Generator[ConnectionPool, None, None]:
"""
Create a connection pool for testing pool-based operations. clean up tables after each test
"""
# Convert SQLAlchemy URL to PostgreSQL URI for psycopg
connection_url = postgres_container.get_connection_url().replace("+psycopg2", "")
pool = ConnectionPool(conninfo=connection_url, min_size=1, max_size=5, open=True)
with pool.connection() as conn:
# Rollback any failed transaction first
if conn.info.transaction_status != psycopg.pq.TransactionStatus.IDLE:
conn.rollback()
def seed_profile_table():
def _seed(db_conn: psycopg.Connection) -> int:
"""Create a new profile and return the id for downstream use"""
with db_conn.cursor() as cursor:
cursor.execute("INSERT INTO profiles DEFAULT VALUES RETURNING id")
result = cursor.fetchone()
return result[0]

return _seed


@pytest.fixture
def seed_visitor_face_embeddings_table():
def _seed(db_conn: psycopg.Connection, profile_id: int, visitor_name: str) -> int:
"""Create a new visitor face record and return the id for downstream use"""
register_vector(db_conn)

with db_conn.cursor() as cursor:
query = """
INSERT INTO visitor_face_embeddings
(profile_id, visitor_name, face_embedding)
VALUES (%(profile_id)s, %(visitor_name)s, %(face_embedding)s)
RETURNING id
"""

face_embedding = np.zeros(128, dtype=np.float32)

with conn.cursor() as cursor:
cursor.execute(
"TRUNCATE TABLE product_sentiment, raw_comments RESTART IDENTITY CASCADE;"
query,
{
"profile_id": profile_id,
"visitor_name": visitor_name,
"face_embedding": face_embedding,
},
)
conn.commit()
result = cursor.fetchone()
return result[0]

return _seed

yield pool

pool.close()
@pytest.fixture
def seed_session_table():
def _seed(db_conn: psycopg.Connection, profile_id: int, session_token: str) -> None:
"""Insert a session row linking profile_id to session_token"""
with db_conn.cursor() as cursor:
cursor.execute(
"INSERT INTO sessions (profile_id, session_token) VALUES (%s, %s)",
(profile_id, session_token),
)
db_conn.commit()

return _seed
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,12 @@ def grpc_client(grpc_server):
channel = grpc.insecure_channel(grpc_server)
yield audio_transcription_pb2_grpc.AudioTranscriptionServiceStub(channel)
channel.close()


@pytest.fixture
def servicer(tmp_path) -> AudioTranscriptionServicer:
db_pool = MagicMock()
handler = MagicMock()
cache = MagicMock()
with patch("src.grpc.servicer._LOGS_DIR", tmp_path):
return AudioTranscriptionServicer(db_pool, handler, cache)
71 changes: 0 additions & 71 deletions backend/audio_transcription/test/fixtures/queries.py

This file was deleted.

5 changes: 2 additions & 3 deletions backend/audio_transcription/test/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# references to fixture files
pytest_plugins = [
"test.fixtures.setup",
"test.fixtures.queries",
"test.fixtures.db",
"test.fixtures.transcription",
"test.fixtures.grpc_server",
"test.fixtures.grpc",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from unittest.mock import patch
from urllib.parse import urlparse
from testcontainers.postgres import PostgresContainer
from psycopg_pool import ConnectionPool
from src.db.connection_pool import create_connection_pool
import pytest


def test_create_connection_pool_succeeds_and_returns_working_pool(
postgres_container: PostgresContainer,
) -> None:
connection_url = postgres_container.get_connection_url().replace("+psycopg2", "")

parsed = urlparse(connection_url)

mock_settings = patch("src.db.connection_pool.settings")
with mock_settings as s:
s.DB_HOST = parsed.hostname
s.DB_PORT = parsed.port
s.DB_NAME = parsed.path.lstrip("/")
s.DB_USER = parsed.username
s.DB_PASS = parsed.password
s.POOL_MIN_SIZE = 1
s.POOL_MAX_SIZE = 5
s.KEEP_ALIVES_COUNT = 1
s.KEEP_ALIVES_IDLE_S = 30
s.KEEP_ALIVES_RETRY_INTERVAL_S = 10
s.KEEP_ALIVES_RETRY_COUNT = 5

pool = create_connection_pool()

try:
with pool.connection() as conn:
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
result = cursor.fetchone()
assert result == (1,)
finally:
pool.close()


def test_create_connection_pool_raises_on_unreachable_db() -> None:
def _fast_pool(*args, **kwargs):
kwargs["reconnect_timeout"] = 1
kwargs["timeout"] = 1
return ConnectionPool(*args, **kwargs)

mock_settings = patch("src.db.connection_pool.settings")
with (
mock_settings as s,
patch("src.db.connection_pool.ConnectionPool", side_effect=_fast_pool),
):
s.DB_HOST = "127.0.0.1"
s.DB_PORT = 1
s.DB_NAME = "nonexistent"
s.DB_USER = "nobody"
s.DB_PASS = "wrong"
s.POOL_MIN_SIZE = 1
s.POOL_MAX_SIZE = 5
s.KEEP_ALIVES_COUNT = 1
s.KEEP_ALIVES_IDLE_S = 30
s.KEEP_ALIVES_RETRY_INTERVAL_S = 10
s.KEEP_ALIVES_RETRY_COUNT = 5

with pytest.raises(Exception):
create_connection_pool()
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from src.db.conversation_queries import save_conversation
import pytest
import psycopg


Expand Down Expand Up @@ -62,3 +62,28 @@ def test_insert_nonexistant_patient_or_visitor_id(
"""non existant profile id or visitor id should raise error"""
with pytest.raises(Exception):
save_conversation(db_connection, 23, "placeholdText", 23)


def test_save_conversation_rolls_back_on_failed_insert(
db_connection: psycopg.Connection,
seed_profile_table,
seed_visitor_face_embeddings_table,
) -> None:
"""A FK violation mid-loop should roll back all inserts leaving no records."""
profileID = seed_profile_table(db_connection)
visitorID = seed_visitor_face_embeddings_table(db_connection, profileID, "visitor")
bad_visitor_id = 99999

with pytest.raises(Exception):
save_conversation(
db_connection, profileID, "hello", [visitorID, bad_visitor_id]
)

with db_connection.cursor() as cursor:
cursor.execute(
"SELECT COUNT(*) FROM conversation_records WHERE profile_id = %s",
(profileID,),
)
count = cursor.fetchone()[0]

assert count == 0
Loading
Loading