From 106c9b0caa0613060eca35c032fe9aad83c357c6 Mon Sep 17 00:00:00 2001 From: Raghotham Murthy Date: Sat, 4 Oct 2025 13:05:51 -0700 Subject: [PATCH 1/8] feat: Allow :memory: for kvstore --- .../providers/utils/kvstore/sqlite/sqlite.py | 113 ++++++++++------- .../unit/utils/kvstore/test_sqlite_memory.py | 116 ++++++++++++++++++ 2 files changed, 182 insertions(+), 47 deletions(-) create mode 100644 tests/unit/utils/kvstore/test_sqlite_memory.py diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index 5b782902e7..e1c0332fee 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -21,67 +21,86 @@ class SqliteKVStoreImpl(KVStore): def __init__(self, config: SqliteKVStoreConfig): self.db_path = config.db_path self.table_name = "kvstore" + self._conn: aiosqlite.Connection | None = None def __str__(self): return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})" + def _is_memory_db(self) -> bool: + """Check if this is an in-memory database.""" + return self.db_path == ":memory:" or "mode=memory" in self.db_path + async def initialize(self): - os.makedirs(os.path.dirname(self.db_path), exist_ok=True) - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - f""" - CREATE TABLE IF NOT EXISTS {self.table_name} ( - key TEXT PRIMARY KEY, - value TEXT, - expiration TIMESTAMP - ) - """ + # Skip directory creation for in-memory databases and file: URIs + if not self._is_memory_db() and not self.db_path.startswith("file:"): + db_dir = os.path.dirname(self.db_path) + if db_dir: # Only create if there's a directory component + os.makedirs(db_dir, exist_ok=True) + + # Create persistent connection for all databases + self._conn = await aiosqlite.connect(self.db_path) + await self._conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + key TEXT PRIMARY KEY, + value TEXT, + expiration TIMESTAMP ) - await db.commit() + """ + ) + await self._conn.commit() + + async def close(self): + """Close the persistent connection.""" + if self._conn: + await self._conn.close() + self._conn = None async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", - (key, value, expiration), - ) - await db.commit() + assert self._conn is not None, "Connection not initialized. Call initialize() first." + await self._conn.execute( + f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", + (key, value, expiration), + ) + await self._conn.commit() async def get(self, key: str) -> str | None: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute(f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)) as cursor: - row = await cursor.fetchone() - if row is None: - return None - value, expiration = row - if not isinstance(value, str): - logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") - return None - return value + assert self._conn is not None, "Connection not initialized. Call initialize() first." + async with self._conn.execute( + f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) + ) as cursor: + row = await cursor.fetchone() + if row is None: + return None + value, expiration = row + if not isinstance(value, str): + logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") + return None + return value async def delete(self, key: str) -> None: - async with aiosqlite.connect(self.db_path) as db: - await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) - await db.commit() + assert self._conn is not None, "Connection not initialized. Call initialize() first." + await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) + await self._conn.commit() async def values_in_range(self, start_key: str, end_key: str) -> list[str]: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute( - f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", - (start_key, end_key), - ) as cursor: - result = [] - async for row in cursor: - _, value, _ = row - result.append(value) - return result + assert self._conn is not None, "Connection not initialized. Call initialize() first." + async with self._conn.execute( + f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) as cursor: + result = [] + async for row in cursor: + _, value, _ = row + result.append(value) + return result async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: """Get all keys in the given range.""" - async with aiosqlite.connect(self.db_path) as db: - cursor = await db.execute( - f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", - (start_key, end_key), - ) - rows = await cursor.fetchall() - return [row[0] for row in rows] + assert self._conn is not None, "Connection not initialized. Call initialize() first." + cursor = await self._conn.execute( + f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) + rows = await cursor.fetchall() + return [row[0] for row in rows] diff --git a/tests/unit/utils/kvstore/test_sqlite_memory.py b/tests/unit/utils/kvstore/test_sqlite_memory.py new file mode 100644 index 0000000000..5c9d4c4ffb --- /dev/null +++ b/tests/unit/utils/kvstore/test_sqlite_memory.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.providers.utils.kvstore.sqlite.config import SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl + + +async def test_memory_kvstore_basic_operations(): + """Test basic CRUD operations with :memory: database.""" + config = SqliteKVStoreConfig(db_path=":memory:") + store = SqliteKVStoreImpl(config) + await store.initialize() + + # Test set and get + await store.set("key1", "value1") + result = await store.get("key1") + assert result == "value1" + + # Test get non-existent key + result = await store.get("nonexistent") + assert result is None + + # Test update + await store.set("key1", "updated_value") + result = await store.get("key1") + assert result == "updated_value" + + # Test delete + await store.delete("key1") + result = await store.get("key1") + assert result is None + + await store.close() + + +async def test_memory_kvstore_range_operations(): + """Test range query operations with :memory: database.""" + config = SqliteKVStoreConfig(db_path=":memory:") + store = SqliteKVStoreImpl(config) + await store.initialize() + + # Set up test data + await store.set("key_a", "value_a") + await store.set("key_b", "value_b") + await store.set("key_c", "value_c") + await store.set("key_d", "value_d") + + # Test values_in_range + values = await store.values_in_range("key_b", "key_c") + assert len(values) == 2 + assert "value_b" in values + assert "value_c" in values + + # Test keys_in_range + keys = await store.keys_in_range("key_a", "key_c") + assert len(keys) == 3 + assert "key_a" in keys + assert "key_b" in keys + assert "key_c" in keys + + await store.close() + + +async def test_memory_kvstore_multiple_instances(): + """Test that multiple :memory: instances are independent.""" + config1 = SqliteKVStoreConfig(db_path=":memory:") + config2 = SqliteKVStoreConfig(db_path=":memory:") + + store1 = SqliteKVStoreImpl(config1) + store2 = SqliteKVStoreImpl(config2) + + await store1.initialize() + await store2.initialize() + + # Set data in store1 + await store1.set("shared_key", "value_from_store1") + + # Verify store2 doesn't see store1's data + result = await store2.get("shared_key") + assert result is None + + # Set different value in store2 + await store2.set("shared_key", "value_from_store2") + + # Verify both stores have independent data + assert await store1.get("shared_key") == "value_from_store1" + assert await store2.get("shared_key") == "value_from_store2" + + await store1.close() + await store2.close() + + +async def test_memory_kvstore_persistence_behavior(): + """Test that :memory: database doesn't persist across instances.""" + config = SqliteKVStoreConfig(db_path=":memory:") + + # First instance + store1 = SqliteKVStoreImpl(config) + await store1.initialize() + await store1.set("persist_test", "should_not_persist") + await store1.close() + + # Second instance with same config + store2 = SqliteKVStoreImpl(config) + await store2.initialize() + + # Data should not be present + result = await store2.get("persist_test") + assert result is None + + await store2.close() From 5227448d6ffe8fdda9b9e561c7b5db25b89d3c68 Mon Sep 17 00:00:00 2001 From: Raghotham Murthy Date: Sat, 4 Oct 2025 18:20:18 -0700 Subject: [PATCH 2/8] fix pre-commit error --- tests/unit/utils/kvstore/test_sqlite_memory.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/utils/kvstore/test_sqlite_memory.py b/tests/unit/utils/kvstore/test_sqlite_memory.py index 5c9d4c4ffb..38326437ee 100644 --- a/tests/unit/utils/kvstore/test_sqlite_memory.py +++ b/tests/unit/utils/kvstore/test_sqlite_memory.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pytest from llama_stack.providers.utils.kvstore.sqlite.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl From 22077e7f3212852d16b1428ef915f451e6dc9036 Mon Sep 17 00:00:00 2001 From: Raghotham Murthy Date: Mon, 6 Oct 2025 14:29:27 -0700 Subject: [PATCH 3/8] fix import --- tests/unit/utils/kvstore/test_sqlite_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/utils/kvstore/test_sqlite_memory.py b/tests/unit/utils/kvstore/test_sqlite_memory.py index 38326437ee..16d9cc1012 100644 --- a/tests/unit/utils/kvstore/test_sqlite_memory.py +++ b/tests/unit/utils/kvstore/test_sqlite_memory.py @@ -5,7 +5,7 @@ # the root directory of this source tree. -from llama_stack.providers.utils.kvstore.sqlite.config import SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl From fa100c77fd815c8a5a887a342c05a19bc59969ab Mon Sep 17 00:00:00 2001 From: Raghotham Murthy Date: Tue, 7 Oct 2025 08:29:06 -0700 Subject: [PATCH 4/8] close connections --- llama_stack/providers/inline/agents/meta_reference/agents.py | 2 +- llama_stack/providers/inline/batches/reference/batches.py | 1 + llama_stack/providers/inline/eval/meta_reference/eval.py | 3 ++- llama_stack/providers/utils/kvstore/api.py | 4 ++++ llama_stack/providers/utils/kvstore/kvstore.py | 4 ++++ tests/unit/fixtures.py | 1 + .../vector_io/test_vector_io_openai_vector_stores.py | 3 --- 7 files changed, 13 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 27d3a94cc9..0c37b05bcc 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -314,7 +314,7 @@ async def list_agent_sessions( return paginate_records(session_dicts, start_index, limit) async def shutdown(self) -> None: - pass + await self.persistence_store.close() # OpenAI responses async def get_openai_response( diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index 39f45d7d1a..501b17dc69 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -129,6 +129,7 @@ async def shutdown(self) -> None: # don't cancel tasks - just let them stop naturally on shutdown # cancelling would mark batches as "cancelled" in the database logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks") + await self.kvstore.close() # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions async def create_batch( diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 0dfe23dca4..ff91a8da6d 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -64,7 +64,8 @@ async def initialize(self) -> None: benchmark = Benchmark.model_validate_json(benchmark) self.benchmarks[benchmark.identifier] = benchmark - async def shutdown(self) -> None: ... + async def shutdown(self) -> None: + await self.kvstore.close() async def register_benchmark(self, task_def: Benchmark) -> None: # Store in kvstore diff --git a/llama_stack/providers/utils/kvstore/api.py b/llama_stack/providers/utils/kvstore/api.py index d17dc66e1d..06bce28914 100644 --- a/llama_stack/providers/utils/kvstore/api.py +++ b/llama_stack/providers/utils/kvstore/api.py @@ -19,3 +19,7 @@ async def delete(self, key: str) -> None: ... async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ... async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ... + + async def close(self) -> None: + """Close any persistent connections. Optional method for cleanup.""" + ... diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index 426523d8e0..475d5df96b 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -43,6 +43,10 @@ async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: async def delete(self, key: str) -> None: del self._store[key] + async def close(self) -> None: + """No-op for in-memory store.""" + pass + async def kvstore_impl(config: KVStoreConfig) -> KVStore: if config.type == KVStoreType.redis.value: diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py index 443a1d371d..ff86f564fe 100644 --- a/tests/unit/fixtures.py +++ b/tests/unit/fixtures.py @@ -18,6 +18,7 @@ async def sqlite_kvstore(tmp_path): kvstore = SqliteKVStoreImpl(kvstore_config) await kvstore.initialize() yield kvstore + await kvstore.close() @pytest.fixture(scope="function") diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index ed0934224a..436f32e27d 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -46,12 +46,9 @@ async def test_initialize_index(vector_index): async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings): - vector_index.delete() - vector_index.initialize() await vector_index.add_chunks(sample_chunks, sample_embeddings) resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1) assert resp.chunks[0].content == sample_chunks[0].content - vector_index.delete() async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension): From a424815804a803711bd696fab7aa1bf04828f3cc Mon Sep 17 00:00:00 2001 From: Raghotham Murthy Date: Tue, 7 Oct 2025 12:28:31 -0700 Subject: [PATCH 5/8] One more attempt with Claude's help to close connections --- llama_stack/core/server/quota.py | 5 +++ .../inline/agents/meta_reference/agents.py | 1 + .../providers/inline/files/localfs/files.py | 3 +- .../providers/remote/files/s3/files.py | 3 +- .../utils/inference/inference_store.py | 28 ++++++++-------- .../utils/responses/responses_store.py | 28 ++++++++-------- llama_stack/providers/utils/sqlstore/api.py | 6 ++++ .../utils/sqlstore/authorized_sqlstore.py | 4 +++ .../utils/sqlstore/sqlalchemy_sqlstore.py | 8 +++++ tests/unit/files/test_files.py | 1 + tests/unit/prompts/prompts/conftest.py | 2 ++ tests/unit/server/test_quota.py | 32 +++++++++++++------ .../utils/inference/test_inference_store.py | 10 ++++++ .../utils/responses/test_responses_store.py | 16 ++++++++++ tests/unit/utils/sqlstore/test_sqlstore.py | 17 ++++++++++ tests/unit/utils/test_authorized_sqlstore.py | 6 ++++ 16 files changed, 132 insertions(+), 38 deletions(-) diff --git a/llama_stack/core/server/quota.py b/llama_stack/core/server/quota.py index 693f224c32..17832246df 100644 --- a/llama_stack/core/server/quota.py +++ b/llama_stack/core/server/quota.py @@ -108,3 +108,8 @@ async def _send_error(self, send: Send, status: int, message: str): ) body = json.dumps({"error": {"message": message}}).encode() await send({"type": "http.response.body", "body": body}) + + async def close(self): + """Close the KV store connection.""" + if self.kv is not None: + await self.kv.close() diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 0c37b05bcc..a67d8ade9b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -315,6 +315,7 @@ async def list_agent_sessions( async def shutdown(self) -> None: await self.persistence_store.close() + await self.responses_store.shutdown() # OpenAI responses async def get_openai_response( diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index a76b982cec..b48975702a 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -62,7 +62,8 @@ async def initialize(self) -> None: ) async def shutdown(self) -> None: - pass + if self.sql_store: + await self.sql_store.close() def _generate_file_id(self) -> str: """Generate a unique file ID for OpenAI API.""" diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index c0e9f81d6a..938f6142dc 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -181,7 +181,8 @@ async def initialize(self) -> None: ) async def shutdown(self) -> None: - pass + if self._sql_store: + await self._sql_store.close() @property def client(self) -> boto3.client: diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index 901f77c679..44ab8c0ce6 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -74,19 +74,21 @@ async def initialize(self): logger.info("Write queue disabled for SQLite to avoid concurrency issues") async def shutdown(self) -> None: - if not self._worker_tasks: - return - if self._queue is not None: - await self._queue.join() - for t in self._worker_tasks: - if not t.done(): - t.cancel() - for t in self._worker_tasks: - try: - await t - except asyncio.CancelledError: - pass - self._worker_tasks.clear() + if self._worker_tasks: + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + if self.sql_store: + await self.sql_store.close() async def flush(self) -> None: """Wait for all queued writes to complete. Useful for testing.""" diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index e610a1ba26..80b17d116e 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -96,19 +96,21 @@ async def initialize(self): logger.info("Write queue disabled for SQLite to avoid concurrency issues") async def shutdown(self) -> None: - if not self._worker_tasks: - return - if self._queue is not None: - await self._queue.join() - for t in self._worker_tasks: - if not t.done(): - t.cancel() - for t in self._worker_tasks: - try: - await t - except asyncio.CancelledError: - pass - self._worker_tasks.clear() + if self._worker_tasks: + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() + + if self.sql_store: + await self.sql_store.close() async def flush(self) -> None: """Wait for all queued writes to complete. Useful for testing.""" diff --git a/llama_stack/providers/utils/sqlstore/api.py b/llama_stack/providers/utils/sqlstore/api.py index a61fd1090e..9061a2eadd 100644 --- a/llama_stack/providers/utils/sqlstore/api.py +++ b/llama_stack/providers/utils/sqlstore/api.py @@ -126,3 +126,9 @@ async def add_column_if_not_exists( :param nullable: Whether the column should be nullable (default: True) """ pass + + async def close(self) -> None: + """ + Close any persistent database connections. + """ + pass diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index e1da4db6e0..373deeb234 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -197,6 +197,10 @@ async def delete(self, table: str, where: Mapping[str, Any]) -> None: """Delete rows with automatic access control filtering.""" await self.sql_store.delete(table, where) + async def close(self) -> None: + """Close the underlying SQL store connection.""" + await self.sql_store.close() + def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str: """Build SQL WHERE clause for access control filtering. diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 23cd6444ec..088b1c5545 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -311,3 +311,11 @@ def check_column_exists(sync_conn): # The table creation will handle adding the column logger.error(f"Error adding column {column_name} to table {table}: {e}") pass + + async def close(self) -> None: + """Close the database engine and all connections.""" + if hasattr(self, "async_session"): + # Get the engine from the session maker + engine = self.async_session.kw.get("bind") + if engine: + await engine.dispose() diff --git a/tests/unit/files/test_files.py b/tests/unit/files/test_files.py index e14e033b95..b227d69dea 100644 --- a/tests/unit/files/test_files.py +++ b/tests/unit/files/test_files.py @@ -43,6 +43,7 @@ async def files_provider(tmp_path): provider = LocalfsFilesImpl(config, default_policy()) await provider.initialize() yield provider + await provider.shutdown() @pytest.fixture diff --git a/tests/unit/prompts/prompts/conftest.py b/tests/unit/prompts/prompts/conftest.py index b2c619e493..94c10f6bc4 100644 --- a/tests/unit/prompts/prompts/conftest.py +++ b/tests/unit/prompts/prompts/conftest.py @@ -28,3 +28,5 @@ async def temp_prompt_store(tmp_path_factory): store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) yield store + + await store.kvstore.close() diff --git a/tests/unit/server/test_quota.py b/tests/unit/server/test_quota.py index 85acbc66aa..d3c569049b 100644 --- a/tests/unit/server/test_quota.py +++ b/tests/unit/server/test_quota.py @@ -52,17 +52,21 @@ async def test_endpoint(): db_path = tmp_path / f"quota_{request.node.name}.db" quota = build_quota_config(db_path) - app = InjectClientIDMiddleware( - QuotaMiddleware( - inner_app, - kv_config=quota.kvstore, - anonymous_max_requests=quota.anonymous_max_requests, - authenticated_max_requests=quota.authenticated_max_requests, - window_seconds=86400, - ), - client_id=f"client_{request.node.name}", + quota_middleware = QuotaMiddleware( + inner_app, + kv_config=quota.kvstore, + anonymous_max_requests=quota.anonymous_max_requests, + authenticated_max_requests=quota.authenticated_max_requests, + window_seconds=86400, ) - return app + app = InjectClientIDMiddleware(quota_middleware, client_id=f"client_{request.node.name}") + + yield app + + # Cleanup + import asyncio + + asyncio.run(quota_middleware.close()) def test_authenticated_quota_allows_up_to_limit(auth_app): @@ -81,6 +85,8 @@ def test_authenticated_quota_blocks_after_limit(auth_app): def test_anonymous_quota_allows_up_to_limit(tmp_path, request): + import asyncio + inner_app = FastAPI() @inner_app.get("/test") @@ -101,8 +107,12 @@ async def test_endpoint(): client = TestClient(app) assert client.get("/test").status_code == 200 + asyncio.run(app.close()) + def test_anonymous_quota_blocks_after_limit(tmp_path, request): + import asyncio + inner_app = FastAPI() @inner_app.get("/test") @@ -125,3 +135,5 @@ async def test_endpoint(): resp = client.get("/test") assert resp.status_code == 429 assert resp.json()["error"]["message"] == "Quota exceeded" + + asyncio.run(app.close()) diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py index f6d63490ab..4bea03b880 100644 --- a/tests/unit/utils/inference/test_inference_store.py +++ b/tests/unit/utils/inference/test_inference_store.py @@ -89,6 +89,8 @@ async def test_inference_store_pagination_basic(): assert result3.data[0].id == "zebra-task" assert result3.has_more is False + await store.sql_store.close() + async def test_inference_store_pagination_ascending(): """Test pagination with ascending order.""" @@ -126,6 +128,8 @@ async def test_inference_store_pagination_ascending(): assert result2.data[0].id == "charlie-task" assert result2.has_more is True + await store.sql_store.close() + async def test_inference_store_pagination_with_model_filter(): """Test pagination combined with model filtering.""" @@ -166,6 +170,8 @@ async def test_inference_store_pagination_with_model_filter(): assert result2.data[0].model == "model-a" assert result2.has_more is False + await store.sql_store.close() + async def test_inference_store_pagination_invalid_after(): """Test error handling for invalid 'after' parameter.""" @@ -178,6 +184,8 @@ async def test_inference_store_pagination_invalid_after(): with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"): await store.list_chat_completions(after="non-existent", limit=2) + await store.sql_store.close() + async def test_inference_store_pagination_no_limit(): """Test pagination behavior when no limit is specified.""" @@ -208,3 +216,5 @@ async def test_inference_store_pagination_no_limit(): assert result.data[0].id == "beta-second" # Most recent first assert result.data[1].id == "omega-first" assert result.has_more is False + + await store.sql_store.close() diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index c27b5a8e5f..aa5c1a7e84 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -98,6 +98,8 @@ async def test_responses_store_pagination_basic(): assert result3.data[0].id == "zebra-resp" assert result3.has_more is False + await store.sql_store.close() + async def test_responses_store_pagination_ascending(): """Test pagination with ascending order.""" @@ -136,6 +138,8 @@ async def test_responses_store_pagination_ascending(): assert result2.data[0].id == "charlie-resp" assert result2.has_more is True + await store.sql_store.close() + async def test_responses_store_pagination_with_model_filter(): """Test pagination combined with model filtering.""" @@ -177,6 +181,8 @@ async def test_responses_store_pagination_with_model_filter(): assert result2.data[0].model == "model-a" assert result2.has_more is False + await store.sql_store.close() + async def test_responses_store_pagination_invalid_after(): """Test error handling for invalid 'after' parameter.""" @@ -189,6 +195,8 @@ async def test_responses_store_pagination_invalid_after(): with pytest.raises(ValueError, match="Record with id.*'non-existent' not found in table 'openai_responses'"): await store.list_responses(after="non-existent", limit=2) + await store.sql_store.close() + async def test_responses_store_pagination_no_limit(): """Test pagination behavior when no limit is specified.""" @@ -221,6 +229,8 @@ async def test_responses_store_pagination_no_limit(): assert result.data[1].id == "omega-resp" assert result.has_more is False + await store.sql_store.close() + async def test_responses_store_get_response_object(): """Test retrieving a single response object.""" @@ -249,6 +259,8 @@ async def test_responses_store_get_response_object(): with pytest.raises(ValueError, match="Response with id non-existent not found"): await store.get_response_object("non-existent") + await store.sql_store.close() + async def test_responses_store_input_items_pagination(): """Test pagination functionality for input items.""" @@ -330,6 +342,8 @@ async def test_responses_store_input_items_pagination(): with pytest.raises(ValueError, match="Cannot specify both 'before' and 'after' parameters"): await store.list_response_input_items("test-resp", before="some-id", after="other-id") + await store.sql_store.close() + async def test_responses_store_input_items_before_pagination(): """Test before pagination functionality for input items.""" @@ -390,3 +404,5 @@ async def test_responses_store_input_items_before_pagination(): ValueError, match="Input item with id 'non-existent' not found for response 'test-resp-before'" ): await store.list_response_input_items("test-resp-before", before="non-existent") + + await store.sql_store.close() diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py index 00669b698c..a68a5b6819 100644 --- a/tests/unit/utils/sqlstore/test_sqlstore.py +++ b/tests/unit/utils/sqlstore/test_sqlstore.py @@ -64,6 +64,9 @@ async def test_sqlite_sqlstore(): assert result.data == [{"id": 12, "name": "test12"}] assert result.has_more is False + # cleanup + await sqlstore.close() + async def test_sqlstore_pagination_basic(): """Test basic pagination functionality at the SQL store level.""" @@ -128,6 +131,8 @@ async def test_sqlstore_pagination_basic(): assert result3.data[0]["id"] == "zebra" assert result3.has_more is False + await store.close() + async def test_sqlstore_pagination_with_filter(): """Test pagination with WHERE conditions.""" @@ -180,6 +185,8 @@ async def test_sqlstore_pagination_with_filter(): assert result2.data[0]["id"] == "xyz" assert result2.has_more is False + await store.close() + async def test_sqlstore_pagination_ascending_order(): """Test pagination with ascending order.""" @@ -228,6 +235,8 @@ async def test_sqlstore_pagination_ascending_order(): assert result2.data[0]["id"] == "alpha" assert result2.has_more is True + await store.close() + async def test_sqlstore_pagination_multi_column_ordering_error(): """Test that multi-column ordering raises an error when using cursor pagination.""" @@ -265,6 +274,8 @@ async def test_sqlstore_pagination_multi_column_ordering_error(): assert len(result.data) == 1 assert result.data[0]["id"] == "task1" + await store.close() + async def test_sqlstore_pagination_cursor_requires_order_by(): """Test that cursor pagination requires order_by parameter.""" @@ -282,6 +293,8 @@ async def test_sqlstore_pagination_cursor_requires_order_by(): cursor=("id", "task1"), ) + await store.close() + async def test_sqlstore_pagination_error_handling(): """Test error handling for invalid columns and cursor IDs.""" @@ -414,6 +427,8 @@ async def test_where_operator_edge_cases(): with pytest.raises(ValueError, match="Unsupported operator"): await store.fetch_all("events", where={"ts": {"!=": base}}) + await store.close() + async def test_sqlstore_pagination_custom_key_column(): """Test pagination with custom primary key column (not 'id').""" @@ -463,3 +478,5 @@ async def test_sqlstore_pagination_custom_key_column(): assert len(result2.data) == 1 assert result2.data[0]["uuid"] == "uuid-alpha" assert result2.has_more is False + + await store.close() diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index d85e784a9b..8f395e6164 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -77,6 +77,8 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic assert row is not None assert row["title"] == "User Document" + await base_sqlstore.close() + @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_sql_policy_consistency(mock_get_authenticated_user): @@ -163,6 +165,8 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): f"Difference: SQL only: {sql_ids - policy_ids}, Policy only: {policy_ids - sql_ids}" ) + await base_sqlstore.close() + @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user): @@ -211,3 +215,5 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us # Third item should have null attributes (no authenticated user) assert result.data[2]["id"] == "item3" assert result.data[2]["access_attributes"] is None + + await base_sqlstore.close() From ed78090b8e6841dceb8d24b89924105a2fab3bdd Mon Sep 17 00:00:00 2001 From: Raghotham Murthy Date: Tue, 7 Oct 2025 14:19:29 -0700 Subject: [PATCH 6/8] address comments --- .../providers/utils/kvstore/sqlite/sqlite.py | 26 +++--- .../unit/utils/kvstore/test_sqlite_memory.py | 85 ------------------- 2 files changed, 14 insertions(+), 97 deletions(-) diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index e1c0332fee..4759bfa881 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -56,17 +56,22 @@ async def close(self): await self._conn.close() self._conn = None + @property + def conn(self) -> aiosqlite.Connection: + """Get the connection, raising an error if not initialized.""" + if self._conn is None: + raise RuntimeError("Connection not initialized. Call initialize() first.") + return self._conn + async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: - assert self._conn is not None, "Connection not initialized. Call initialize() first." - await self._conn.execute( + await self.conn.execute( f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", (key, value, expiration), ) - await self._conn.commit() + await self.conn.commit() async def get(self, key: str) -> str | None: - assert self._conn is not None, "Connection not initialized. Call initialize() first." - async with self._conn.execute( + async with self.conn.execute( f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) ) as cursor: row = await cursor.fetchone() @@ -79,13 +84,11 @@ async def get(self, key: str) -> str | None: return value async def delete(self, key: str) -> None: - assert self._conn is not None, "Connection not initialized. Call initialize() first." - await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) - await self._conn.commit() + await self.conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) + await self.conn.commit() async def values_in_range(self, start_key: str, end_key: str) -> list[str]: - assert self._conn is not None, "Connection not initialized. Call initialize() first." - async with self._conn.execute( + async with self.conn.execute( f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", (start_key, end_key), ) as cursor: @@ -97,8 +100,7 @@ async def values_in_range(self, start_key: str, end_key: str) -> list[str]: async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: """Get all keys in the given range.""" - assert self._conn is not None, "Connection not initialized. Call initialize() first." - cursor = await self._conn.execute( + cursor = await self.conn.execute( f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", (start_key, end_key), ) diff --git a/tests/unit/utils/kvstore/test_sqlite_memory.py b/tests/unit/utils/kvstore/test_sqlite_memory.py index 16d9cc1012..942ad10875 100644 --- a/tests/unit/utils/kvstore/test_sqlite_memory.py +++ b/tests/unit/utils/kvstore/test_sqlite_memory.py @@ -9,91 +9,6 @@ from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl -async def test_memory_kvstore_basic_operations(): - """Test basic CRUD operations with :memory: database.""" - config = SqliteKVStoreConfig(db_path=":memory:") - store = SqliteKVStoreImpl(config) - await store.initialize() - - # Test set and get - await store.set("key1", "value1") - result = await store.get("key1") - assert result == "value1" - - # Test get non-existent key - result = await store.get("nonexistent") - assert result is None - - # Test update - await store.set("key1", "updated_value") - result = await store.get("key1") - assert result == "updated_value" - - # Test delete - await store.delete("key1") - result = await store.get("key1") - assert result is None - - await store.close() - - -async def test_memory_kvstore_range_operations(): - """Test range query operations with :memory: database.""" - config = SqliteKVStoreConfig(db_path=":memory:") - store = SqliteKVStoreImpl(config) - await store.initialize() - - # Set up test data - await store.set("key_a", "value_a") - await store.set("key_b", "value_b") - await store.set("key_c", "value_c") - await store.set("key_d", "value_d") - - # Test values_in_range - values = await store.values_in_range("key_b", "key_c") - assert len(values) == 2 - assert "value_b" in values - assert "value_c" in values - - # Test keys_in_range - keys = await store.keys_in_range("key_a", "key_c") - assert len(keys) == 3 - assert "key_a" in keys - assert "key_b" in keys - assert "key_c" in keys - - await store.close() - - -async def test_memory_kvstore_multiple_instances(): - """Test that multiple :memory: instances are independent.""" - config1 = SqliteKVStoreConfig(db_path=":memory:") - config2 = SqliteKVStoreConfig(db_path=":memory:") - - store1 = SqliteKVStoreImpl(config1) - store2 = SqliteKVStoreImpl(config2) - - await store1.initialize() - await store2.initialize() - - # Set data in store1 - await store1.set("shared_key", "value_from_store1") - - # Verify store2 doesn't see store1's data - result = await store2.get("shared_key") - assert result is None - - # Set different value in store2 - await store2.set("shared_key", "value_from_store2") - - # Verify both stores have independent data - assert await store1.get("shared_key") == "value_from_store1" - assert await store2.get("shared_key") == "value_from_store2" - - await store1.close() - await store2.close() - - async def test_memory_kvstore_persistence_behavior(): """Test that :memory: database doesn't persist across instances.""" config = SqliteKVStoreConfig(db_path=":memory:") From 06d02bf3de43844e56efe019b94923ddfee342b2 Mon Sep 17 00:00:00 2001 From: Raghotham Murthy Date: Sat, 11 Oct 2025 01:32:38 -0700 Subject: [PATCH 7/8] minimize change --- llama_stack/core/server/quota.py | 5 - .../inline/agents/meta_reference/agents.py | 3 +- .../inline/batches/reference/batches.py | 1 - .../inline/eval/meta_reference/eval.py | 3 +- .../providers/inline/files/localfs/files.py | 3 +- .../providers/remote/files/s3/files.py | 3 +- .../utils/inference/inference_store.py | 28 ++- llama_stack/providers/utils/kvstore/api.py | 4 - .../providers/utils/kvstore/kvstore.py | 4 - .../providers/utils/kvstore/sqlite/sqlite.py | 170 ++++++++++++------ .../utils/responses/responses_store.py | 28 ++- llama_stack/providers/utils/sqlstore/api.py | 6 - .../utils/sqlstore/authorized_sqlstore.py | 4 - .../utils/sqlstore/sqlalchemy_sqlstore.py | 8 - tests/unit/files/test_files.py | 1 - tests/unit/fixtures.py | 1 - tests/unit/prompts/prompts/conftest.py | 2 - .../test_vector_io_openai_vector_stores.py | 3 + tests/unit/server/test_quota.py | 32 ++-- .../utils/inference/test_inference_store.py | 10 -- .../utils/responses/test_responses_store.py | 16 -- tests/unit/utils/sqlstore/test_sqlstore.py | 17 -- tests/unit/utils/test_authorized_sqlstore.py | 6 - 23 files changed, 161 insertions(+), 197 deletions(-) diff --git a/llama_stack/core/server/quota.py b/llama_stack/core/server/quota.py index 17832246df..693f224c32 100644 --- a/llama_stack/core/server/quota.py +++ b/llama_stack/core/server/quota.py @@ -108,8 +108,3 @@ async def _send_error(self, send: Send, status: int, message: str): ) body = json.dumps({"error": {"message": message}}).encode() await send({"type": "http.response.body", "body": body}) - - async def close(self): - """Close the KV store connection.""" - if self.kv is not None: - await self.kv.close() diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index a67d8ade9b..27d3a94cc9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -314,8 +314,7 @@ async def list_agent_sessions( return paginate_records(session_dicts, start_index, limit) async def shutdown(self) -> None: - await self.persistence_store.close() - await self.responses_store.shutdown() + pass # OpenAI responses async def get_openai_response( diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index 501b17dc69..39f45d7d1a 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -129,7 +129,6 @@ async def shutdown(self) -> None: # don't cancel tasks - just let them stop naturally on shutdown # cancelling would mark batches as "cancelled" in the database logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks") - await self.kvstore.close() # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions async def create_batch( diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index ff91a8da6d..0dfe23dca4 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -64,8 +64,7 @@ async def initialize(self) -> None: benchmark = Benchmark.model_validate_json(benchmark) self.benchmarks[benchmark.identifier] = benchmark - async def shutdown(self) -> None: - await self.kvstore.close() + async def shutdown(self) -> None: ... async def register_benchmark(self, task_def: Benchmark) -> None: # Store in kvstore diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index b48975702a..a76b982cec 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -62,8 +62,7 @@ async def initialize(self) -> None: ) async def shutdown(self) -> None: - if self.sql_store: - await self.sql_store.close() + pass def _generate_file_id(self) -> str: """Generate a unique file ID for OpenAI API.""" diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index 938f6142dc..c0e9f81d6a 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -181,8 +181,7 @@ async def initialize(self) -> None: ) async def shutdown(self) -> None: - if self._sql_store: - await self._sql_store.close() + pass @property def client(self) -> boto3.client: diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index 44ab8c0ce6..901f77c679 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -74,21 +74,19 @@ async def initialize(self): logger.info("Write queue disabled for SQLite to avoid concurrency issues") async def shutdown(self) -> None: - if self._worker_tasks: - if self._queue is not None: - await self._queue.join() - for t in self._worker_tasks: - if not t.done(): - t.cancel() - for t in self._worker_tasks: - try: - await t - except asyncio.CancelledError: - pass - self._worker_tasks.clear() - - if self.sql_store: - await self.sql_store.close() + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() async def flush(self) -> None: """Wait for all queued writes to complete. Useful for testing.""" diff --git a/llama_stack/providers/utils/kvstore/api.py b/llama_stack/providers/utils/kvstore/api.py index 06bce28914..d17dc66e1d 100644 --- a/llama_stack/providers/utils/kvstore/api.py +++ b/llama_stack/providers/utils/kvstore/api.py @@ -19,7 +19,3 @@ async def delete(self, key: str) -> None: ... async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ... async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ... - - async def close(self) -> None: - """Close any persistent connections. Optional method for cleanup.""" - ... diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index 475d5df96b..426523d8e0 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -43,10 +43,6 @@ async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: async def delete(self, key: str) -> None: del self._store[key] - async def close(self) -> None: - """No-op for in-memory store.""" - pass - async def kvstore_impl(config: KVStoreConfig) -> KVStore: if config.type == KVStoreType.redis.value: diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index 4759bfa881..7c816dc48b 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -37,72 +37,138 @@ async def initialize(self): if db_dir: # Only create if there's a directory component os.makedirs(db_dir, exist_ok=True) - # Create persistent connection for all databases - self._conn = await aiosqlite.connect(self.db_path) - await self._conn.execute( - f""" - CREATE TABLE IF NOT EXISTS {self.table_name} ( - key TEXT PRIMARY KEY, - value TEXT, - expiration TIMESTAMP + # Only use persistent connection for in-memory databases + # File-based databases use connection-per-operation to avoid hangs + if self._is_memory_db(): + self._conn = await aiosqlite.connect(self.db_path) + await self._conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + key TEXT PRIMARY KEY, + value TEXT, + expiration TIMESTAMP + ) + """ ) - """ - ) - await self._conn.commit() + await self._conn.commit() + else: + # For file-based databases, just create the table + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + key TEXT PRIMARY KEY, + value TEXT, + expiration TIMESTAMP + ) + """ + ) + await db.commit() async def close(self): - """Close the persistent connection.""" + """Close the persistent connection (only for in-memory databases).""" if self._conn: await self._conn.close() self._conn = None - @property - def conn(self) -> aiosqlite.Connection: - """Get the connection, raising an error if not initialized.""" - if self._conn is None: - raise RuntimeError("Connection not initialized. Call initialize() first.") - return self._conn - async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: - await self.conn.execute( - f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", - (key, value, expiration), - ) - await self.conn.commit() + if self._conn: + # In-memory database with persistent connection + await self._conn.execute( + f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", + (key, value, expiration), + ) + await self._conn.commit() + else: + # File-based database with connection per operation + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", + (key, value, expiration), + ) + await db.commit() async def get(self, key: str) -> str | None: - async with self.conn.execute( - f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) - ) as cursor: - row = await cursor.fetchone() - if row is None: - return None - value, expiration = row - if not isinstance(value, str): - logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") - return None - return value + if self._conn: + # In-memory database with persistent connection + async with self._conn.execute( + f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) + ) as cursor: + row = await cursor.fetchone() + if row is None: + return None + value, expiration = row + if not isinstance(value, str): + logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") + return None + return value + else: + # File-based database with connection per operation + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) + ) as cursor: + row = await cursor.fetchone() + if row is None: + return None + value, expiration = row + if not isinstance(value, str): + logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") + return None + return value async def delete(self, key: str) -> None: - await self.conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) - await self.conn.commit() + if self._conn: + # In-memory database with persistent connection + await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) + await self._conn.commit() + else: + # File-based database with connection per operation + async with aiosqlite.connect(self.db_path) as db: + await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) + await db.commit() async def values_in_range(self, start_key: str, end_key: str) -> list[str]: - async with self.conn.execute( - f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", - (start_key, end_key), - ) as cursor: - result = [] - async for row in cursor: - _, value, _ = row - result.append(value) - return result + if self._conn: + # In-memory database with persistent connection + async with self._conn.execute( + f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) as cursor: + result = [] + async for row in cursor: + _, value, _ = row + result.append(value) + return result + else: + # File-based database with connection per operation + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) as cursor: + result = [] + async for row in cursor: + _, value, _ = row + result.append(value) + return result async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: """Get all keys in the given range.""" - cursor = await self.conn.execute( - f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", - (start_key, end_key), - ) - rows = await cursor.fetchall() - return [row[0] for row in rows] + if self._conn: + # In-memory database with persistent connection + cursor = await self._conn.execute( + f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) + rows = await cursor.fetchall() + return [row[0] for row in rows] + else: + # File-based database with connection per operation + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute( + f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) + rows = await cursor.fetchall() + return [row[0] for row in rows] diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 80b17d116e..e610a1ba26 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -96,21 +96,19 @@ async def initialize(self): logger.info("Write queue disabled for SQLite to avoid concurrency issues") async def shutdown(self) -> None: - if self._worker_tasks: - if self._queue is not None: - await self._queue.join() - for t in self._worker_tasks: - if not t.done(): - t.cancel() - for t in self._worker_tasks: - try: - await t - except asyncio.CancelledError: - pass - self._worker_tasks.clear() - - if self.sql_store: - await self.sql_store.close() + if not self._worker_tasks: + return + if self._queue is not None: + await self._queue.join() + for t in self._worker_tasks: + if not t.done(): + t.cancel() + for t in self._worker_tasks: + try: + await t + except asyncio.CancelledError: + pass + self._worker_tasks.clear() async def flush(self) -> None: """Wait for all queued writes to complete. Useful for testing.""" diff --git a/llama_stack/providers/utils/sqlstore/api.py b/llama_stack/providers/utils/sqlstore/api.py index 9061a2eadd..a61fd1090e 100644 --- a/llama_stack/providers/utils/sqlstore/api.py +++ b/llama_stack/providers/utils/sqlstore/api.py @@ -126,9 +126,3 @@ async def add_column_if_not_exists( :param nullable: Whether the column should be nullable (default: True) """ pass - - async def close(self) -> None: - """ - Close any persistent database connections. - """ - pass diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index 373deeb234..e1da4db6e0 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -197,10 +197,6 @@ async def delete(self, table: str, where: Mapping[str, Any]) -> None: """Delete rows with automatic access control filtering.""" await self.sql_store.delete(table, where) - async def close(self) -> None: - """Close the underlying SQL store connection.""" - await self.sql_store.close() - def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str: """Build SQL WHERE clause for access control filtering. diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 088b1c5545..23cd6444ec 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -311,11 +311,3 @@ def check_column_exists(sync_conn): # The table creation will handle adding the column logger.error(f"Error adding column {column_name} to table {table}: {e}") pass - - async def close(self) -> None: - """Close the database engine and all connections.""" - if hasattr(self, "async_session"): - # Get the engine from the session maker - engine = self.async_session.kw.get("bind") - if engine: - await engine.dispose() diff --git a/tests/unit/files/test_files.py b/tests/unit/files/test_files.py index b227d69dea..e14e033b95 100644 --- a/tests/unit/files/test_files.py +++ b/tests/unit/files/test_files.py @@ -43,7 +43,6 @@ async def files_provider(tmp_path): provider = LocalfsFilesImpl(config, default_policy()) await provider.initialize() yield provider - await provider.shutdown() @pytest.fixture diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py index ff86f564fe..443a1d371d 100644 --- a/tests/unit/fixtures.py +++ b/tests/unit/fixtures.py @@ -18,7 +18,6 @@ async def sqlite_kvstore(tmp_path): kvstore = SqliteKVStoreImpl(kvstore_config) await kvstore.initialize() yield kvstore - await kvstore.close() @pytest.fixture(scope="function") diff --git a/tests/unit/prompts/prompts/conftest.py b/tests/unit/prompts/prompts/conftest.py index 94c10f6bc4..b2c619e493 100644 --- a/tests/unit/prompts/prompts/conftest.py +++ b/tests/unit/prompts/prompts/conftest.py @@ -28,5 +28,3 @@ async def temp_prompt_store(tmp_path_factory): store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) yield store - - await store.kvstore.close() diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 436f32e27d..ed0934224a 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -46,9 +46,12 @@ async def test_initialize_index(vector_index): async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings): + vector_index.delete() + vector_index.initialize() await vector_index.add_chunks(sample_chunks, sample_embeddings) resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1) assert resp.chunks[0].content == sample_chunks[0].content + vector_index.delete() async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension): diff --git a/tests/unit/server/test_quota.py b/tests/unit/server/test_quota.py index d3c569049b..85acbc66aa 100644 --- a/tests/unit/server/test_quota.py +++ b/tests/unit/server/test_quota.py @@ -52,21 +52,17 @@ async def test_endpoint(): db_path = tmp_path / f"quota_{request.node.name}.db" quota = build_quota_config(db_path) - quota_middleware = QuotaMiddleware( - inner_app, - kv_config=quota.kvstore, - anonymous_max_requests=quota.anonymous_max_requests, - authenticated_max_requests=quota.authenticated_max_requests, - window_seconds=86400, + app = InjectClientIDMiddleware( + QuotaMiddleware( + inner_app, + kv_config=quota.kvstore, + anonymous_max_requests=quota.anonymous_max_requests, + authenticated_max_requests=quota.authenticated_max_requests, + window_seconds=86400, + ), + client_id=f"client_{request.node.name}", ) - app = InjectClientIDMiddleware(quota_middleware, client_id=f"client_{request.node.name}") - - yield app - - # Cleanup - import asyncio - - asyncio.run(quota_middleware.close()) + return app def test_authenticated_quota_allows_up_to_limit(auth_app): @@ -85,8 +81,6 @@ def test_authenticated_quota_blocks_after_limit(auth_app): def test_anonymous_quota_allows_up_to_limit(tmp_path, request): - import asyncio - inner_app = FastAPI() @inner_app.get("/test") @@ -107,12 +101,8 @@ async def test_endpoint(): client = TestClient(app) assert client.get("/test").status_code == 200 - asyncio.run(app.close()) - def test_anonymous_quota_blocks_after_limit(tmp_path, request): - import asyncio - inner_app = FastAPI() @inner_app.get("/test") @@ -135,5 +125,3 @@ async def test_endpoint(): resp = client.get("/test") assert resp.status_code == 429 assert resp.json()["error"]["message"] == "Quota exceeded" - - asyncio.run(app.close()) diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py index 4bea03b880..f6d63490ab 100644 --- a/tests/unit/utils/inference/test_inference_store.py +++ b/tests/unit/utils/inference/test_inference_store.py @@ -89,8 +89,6 @@ async def test_inference_store_pagination_basic(): assert result3.data[0].id == "zebra-task" assert result3.has_more is False - await store.sql_store.close() - async def test_inference_store_pagination_ascending(): """Test pagination with ascending order.""" @@ -128,8 +126,6 @@ async def test_inference_store_pagination_ascending(): assert result2.data[0].id == "charlie-task" assert result2.has_more is True - await store.sql_store.close() - async def test_inference_store_pagination_with_model_filter(): """Test pagination combined with model filtering.""" @@ -170,8 +166,6 @@ async def test_inference_store_pagination_with_model_filter(): assert result2.data[0].model == "model-a" assert result2.has_more is False - await store.sql_store.close() - async def test_inference_store_pagination_invalid_after(): """Test error handling for invalid 'after' parameter.""" @@ -184,8 +178,6 @@ async def test_inference_store_pagination_invalid_after(): with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"): await store.list_chat_completions(after="non-existent", limit=2) - await store.sql_store.close() - async def test_inference_store_pagination_no_limit(): """Test pagination behavior when no limit is specified.""" @@ -216,5 +208,3 @@ async def test_inference_store_pagination_no_limit(): assert result.data[0].id == "beta-second" # Most recent first assert result.data[1].id == "omega-first" assert result.has_more is False - - await store.sql_store.close() diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index aa5c1a7e84..c27b5a8e5f 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -98,8 +98,6 @@ async def test_responses_store_pagination_basic(): assert result3.data[0].id == "zebra-resp" assert result3.has_more is False - await store.sql_store.close() - async def test_responses_store_pagination_ascending(): """Test pagination with ascending order.""" @@ -138,8 +136,6 @@ async def test_responses_store_pagination_ascending(): assert result2.data[0].id == "charlie-resp" assert result2.has_more is True - await store.sql_store.close() - async def test_responses_store_pagination_with_model_filter(): """Test pagination combined with model filtering.""" @@ -181,8 +177,6 @@ async def test_responses_store_pagination_with_model_filter(): assert result2.data[0].model == "model-a" assert result2.has_more is False - await store.sql_store.close() - async def test_responses_store_pagination_invalid_after(): """Test error handling for invalid 'after' parameter.""" @@ -195,8 +189,6 @@ async def test_responses_store_pagination_invalid_after(): with pytest.raises(ValueError, match="Record with id.*'non-existent' not found in table 'openai_responses'"): await store.list_responses(after="non-existent", limit=2) - await store.sql_store.close() - async def test_responses_store_pagination_no_limit(): """Test pagination behavior when no limit is specified.""" @@ -229,8 +221,6 @@ async def test_responses_store_pagination_no_limit(): assert result.data[1].id == "omega-resp" assert result.has_more is False - await store.sql_store.close() - async def test_responses_store_get_response_object(): """Test retrieving a single response object.""" @@ -259,8 +249,6 @@ async def test_responses_store_get_response_object(): with pytest.raises(ValueError, match="Response with id non-existent not found"): await store.get_response_object("non-existent") - await store.sql_store.close() - async def test_responses_store_input_items_pagination(): """Test pagination functionality for input items.""" @@ -342,8 +330,6 @@ async def test_responses_store_input_items_pagination(): with pytest.raises(ValueError, match="Cannot specify both 'before' and 'after' parameters"): await store.list_response_input_items("test-resp", before="some-id", after="other-id") - await store.sql_store.close() - async def test_responses_store_input_items_before_pagination(): """Test before pagination functionality for input items.""" @@ -404,5 +390,3 @@ async def test_responses_store_input_items_before_pagination(): ValueError, match="Input item with id 'non-existent' not found for response 'test-resp-before'" ): await store.list_response_input_items("test-resp-before", before="non-existent") - - await store.sql_store.close() diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py index a68a5b6819..00669b698c 100644 --- a/tests/unit/utils/sqlstore/test_sqlstore.py +++ b/tests/unit/utils/sqlstore/test_sqlstore.py @@ -64,9 +64,6 @@ async def test_sqlite_sqlstore(): assert result.data == [{"id": 12, "name": "test12"}] assert result.has_more is False - # cleanup - await sqlstore.close() - async def test_sqlstore_pagination_basic(): """Test basic pagination functionality at the SQL store level.""" @@ -131,8 +128,6 @@ async def test_sqlstore_pagination_basic(): assert result3.data[0]["id"] == "zebra" assert result3.has_more is False - await store.close() - async def test_sqlstore_pagination_with_filter(): """Test pagination with WHERE conditions.""" @@ -185,8 +180,6 @@ async def test_sqlstore_pagination_with_filter(): assert result2.data[0]["id"] == "xyz" assert result2.has_more is False - await store.close() - async def test_sqlstore_pagination_ascending_order(): """Test pagination with ascending order.""" @@ -235,8 +228,6 @@ async def test_sqlstore_pagination_ascending_order(): assert result2.data[0]["id"] == "alpha" assert result2.has_more is True - await store.close() - async def test_sqlstore_pagination_multi_column_ordering_error(): """Test that multi-column ordering raises an error when using cursor pagination.""" @@ -274,8 +265,6 @@ async def test_sqlstore_pagination_multi_column_ordering_error(): assert len(result.data) == 1 assert result.data[0]["id"] == "task1" - await store.close() - async def test_sqlstore_pagination_cursor_requires_order_by(): """Test that cursor pagination requires order_by parameter.""" @@ -293,8 +282,6 @@ async def test_sqlstore_pagination_cursor_requires_order_by(): cursor=("id", "task1"), ) - await store.close() - async def test_sqlstore_pagination_error_handling(): """Test error handling for invalid columns and cursor IDs.""" @@ -427,8 +414,6 @@ async def test_where_operator_edge_cases(): with pytest.raises(ValueError, match="Unsupported operator"): await store.fetch_all("events", where={"ts": {"!=": base}}) - await store.close() - async def test_sqlstore_pagination_custom_key_column(): """Test pagination with custom primary key column (not 'id').""" @@ -478,5 +463,3 @@ async def test_sqlstore_pagination_custom_key_column(): assert len(result2.data) == 1 assert result2.data[0]["uuid"] == "uuid-alpha" assert result2.has_more is False - - await store.close() diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index 8f395e6164..d85e784a9b 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -77,8 +77,6 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic assert row is not None assert row["title"] == "User Document" - await base_sqlstore.close() - @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_sql_policy_consistency(mock_get_authenticated_user): @@ -165,8 +163,6 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): f"Difference: SQL only: {sql_ids - policy_ids}, Policy only: {policy_ids - sql_ids}" ) - await base_sqlstore.close() - @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user): @@ -215,5 +211,3 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us # Third item should have null attributes (no authenticated user) assert result.data[2]["id"] == "item3" assert result.data[2]["access_attributes"] is None - - await base_sqlstore.close() From 97b86226c79123b7718bff1a6022dd395f6c3ff8 Mon Sep 17 00:00:00 2001 From: Raghotham Murthy Date: Sat, 11 Oct 2025 17:58:08 -0700 Subject: [PATCH 8/8] change close to shutdown --- llama_stack/providers/utils/kvstore/sqlite/sqlite.py | 2 +- tests/unit/utils/kvstore/test_sqlite_memory.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index 7c816dc48b..a9a7a13048 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -65,7 +65,7 @@ async def initialize(self): ) await db.commit() - async def close(self): + async def shutdown(self): """Close the persistent connection (only for in-memory databases).""" if self._conn: await self._conn.close() diff --git a/tests/unit/utils/kvstore/test_sqlite_memory.py b/tests/unit/utils/kvstore/test_sqlite_memory.py index 942ad10875..a31377306d 100644 --- a/tests/unit/utils/kvstore/test_sqlite_memory.py +++ b/tests/unit/utils/kvstore/test_sqlite_memory.py @@ -17,7 +17,7 @@ async def test_memory_kvstore_persistence_behavior(): store1 = SqliteKVStoreImpl(config) await store1.initialize() await store1.set("persist_test", "should_not_persist") - await store1.close() + await store1.shutdown() # Second instance with same config store2 = SqliteKVStoreImpl(config) @@ -27,4 +27,4 @@ async def test_memory_kvstore_persistence_behavior(): result = await store2.get("persist_test") assert result is None - await store2.close() + await store2.shutdown()