Skip to content

Commit e349fe9

Browse files
raghothamjwm4
authored andcommitted
feat: Allow :memory: for kvstore (llamastack#3696)
## Test Plan added unit tests
1 parent 9c1364e commit e349fe9

File tree

2 files changed

+133
-16
lines changed

2 files changed

+133
-16
lines changed

llama_stack/providers/utils/kvstore/sqlite/sqlite.py

Lines changed: 103 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,27 @@ class SqliteKVStoreImpl(KVStore):
2121
def __init__(self, config: SqliteKVStoreConfig):
2222
self.db_path = config.db_path
2323
self.table_name = "kvstore"
24+
self._conn: aiosqlite.Connection | None = None
2425

2526
def __str__(self):
2627
return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})"
2728

29+
def _is_memory_db(self) -> bool:
30+
"""Check if this is an in-memory database."""
31+
return self.db_path == ":memory:" or "mode=memory" in self.db_path
32+
2833
async def initialize(self):
29-
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
30-
async with aiosqlite.connect(self.db_path) as db:
31-
await db.execute(
34+
# Skip directory creation for in-memory databases and file: URIs
35+
if not self._is_memory_db() and not self.db_path.startswith("file:"):
36+
db_dir = os.path.dirname(self.db_path)
37+
if db_dir: # Only create if there's a directory component
38+
os.makedirs(db_dir, exist_ok=True)
39+
40+
# Only use persistent connection for in-memory databases
41+
# File-based databases use connection-per-operation to avoid hangs
42+
if self._is_memory_db():
43+
self._conn = await aiosqlite.connect(self.db_path)
44+
await self._conn.execute(
3245
f"""
3346
CREATE TABLE IF NOT EXISTS {self.table_name} (
3447
key TEXT PRIMARY KEY,
@@ -37,19 +50,50 @@ async def initialize(self):
3750
)
3851
"""
3952
)
40-
await db.commit()
53+
await self._conn.commit()
54+
else:
55+
# For file-based databases, just create the table
56+
async with aiosqlite.connect(self.db_path) as db:
57+
await db.execute(
58+
f"""
59+
CREATE TABLE IF NOT EXISTS {self.table_name} (
60+
key TEXT PRIMARY KEY,
61+
value TEXT,
62+
expiration TIMESTAMP
63+
)
64+
"""
65+
)
66+
await db.commit()
67+
68+
async def shutdown(self):
69+
"""Close the persistent connection (only for in-memory databases)."""
70+
if self._conn:
71+
await self._conn.close()
72+
self._conn = None
4173

4274
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
43-
async with aiosqlite.connect(self.db_path) as db:
44-
await db.execute(
75+
if self._conn:
76+
# In-memory database with persistent connection
77+
await self._conn.execute(
4578
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
4679
(key, value, expiration),
4780
)
48-
await db.commit()
81+
await self._conn.commit()
82+
else:
83+
# File-based database with connection per operation
84+
async with aiosqlite.connect(self.db_path) as db:
85+
await db.execute(
86+
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
87+
(key, value, expiration),
88+
)
89+
await db.commit()
4990

5091
async def get(self, key: str) -> str | None:
51-
async with aiosqlite.connect(self.db_path) as db:
52-
async with db.execute(f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)) as cursor:
92+
if self._conn:
93+
# In-memory database with persistent connection
94+
async with self._conn.execute(
95+
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
96+
) as cursor:
5397
row = await cursor.fetchone()
5498
if row is None:
5599
return None
@@ -58,15 +102,36 @@ async def get(self, key: str) -> str | None:
58102
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
59103
return None
60104
return value
105+
else:
106+
# File-based database with connection per operation
107+
async with aiosqlite.connect(self.db_path) as db:
108+
async with db.execute(
109+
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
110+
) as cursor:
111+
row = await cursor.fetchone()
112+
if row is None:
113+
return None
114+
value, expiration = row
115+
if not isinstance(value, str):
116+
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
117+
return None
118+
return value
61119

62120
async def delete(self, key: str) -> None:
63-
async with aiosqlite.connect(self.db_path) as db:
64-
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
65-
await db.commit()
121+
if self._conn:
122+
# In-memory database with persistent connection
123+
await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
124+
await self._conn.commit()
125+
else:
126+
# File-based database with connection per operation
127+
async with aiosqlite.connect(self.db_path) as db:
128+
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
129+
await db.commit()
66130

67131
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
68-
async with aiosqlite.connect(self.db_path) as db:
69-
async with db.execute(
132+
if self._conn:
133+
# In-memory database with persistent connection
134+
async with self._conn.execute(
70135
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
71136
(start_key, end_key),
72137
) as cursor:
@@ -75,13 +140,35 @@ async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
75140
_, value, _ = row
76141
result.append(value)
77142
return result
143+
else:
144+
# File-based database with connection per operation
145+
async with aiosqlite.connect(self.db_path) as db:
146+
async with db.execute(
147+
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
148+
(start_key, end_key),
149+
) as cursor:
150+
result = []
151+
async for row in cursor:
152+
_, value, _ = row
153+
result.append(value)
154+
return result
78155

79156
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
80157
"""Get all keys in the given range."""
81-
async with aiosqlite.connect(self.db_path) as db:
82-
cursor = await db.execute(
158+
if self._conn:
159+
# In-memory database with persistent connection
160+
cursor = await self._conn.execute(
83161
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
84162
(start_key, end_key),
85163
)
86164
rows = await cursor.fetchall()
87165
return [row[0] for row in rows]
166+
else:
167+
# File-based database with connection per operation
168+
async with aiosqlite.connect(self.db_path) as db:
169+
cursor = await db.execute(
170+
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
171+
(start_key, end_key),
172+
)
173+
rows = await cursor.fetchall()
174+
return [row[0] for row in rows]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
8+
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
9+
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
10+
11+
12+
async def test_memory_kvstore_persistence_behavior():
13+
"""Test that :memory: database doesn't persist across instances."""
14+
config = SqliteKVStoreConfig(db_path=":memory:")
15+
16+
# First instance
17+
store1 = SqliteKVStoreImpl(config)
18+
await store1.initialize()
19+
await store1.set("persist_test", "should_not_persist")
20+
await store1.shutdown()
21+
22+
# Second instance with same config
23+
store2 = SqliteKVStoreImpl(config)
24+
await store2.initialize()
25+
26+
# Data should not be present
27+
result = await store2.get("persist_test")
28+
assert result is None
29+
30+
await store2.shutdown()

0 commit comments

Comments
 (0)