@@ -21,14 +21,27 @@ class SqliteKVStoreImpl(KVStore):
2121 def __init__ (self , config : SqliteKVStoreConfig ):
2222 self .db_path = config .db_path
2323 self .table_name = "kvstore"
24+ self ._conn : aiosqlite .Connection | None = None
2425
2526 def __str__ (self ):
2627 return f"SqliteKVStoreImpl(db_path={ self .db_path } , table_name={ self .table_name } )"
2728
29+ def _is_memory_db (self ) -> bool :
30+ """Check if this is an in-memory database."""
31+ return self .db_path == ":memory:" or "mode=memory" in self .db_path
32+
2833 async def initialize (self ):
29- os .makedirs (os .path .dirname (self .db_path ), exist_ok = True )
30- async with aiosqlite .connect (self .db_path ) as db :
31- await db .execute (
34+ # Skip directory creation for in-memory databases and file: URIs
35+ if not self ._is_memory_db () and not self .db_path .startswith ("file:" ):
36+ db_dir = os .path .dirname (self .db_path )
37+ if db_dir : # Only create if there's a directory component
38+ os .makedirs (db_dir , exist_ok = True )
39+
40+ # Only use persistent connection for in-memory databases
41+ # File-based databases use connection-per-operation to avoid hangs
42+ if self ._is_memory_db ():
43+ self ._conn = await aiosqlite .connect (self .db_path )
44+ await self ._conn .execute (
3245 f"""
3346 CREATE TABLE IF NOT EXISTS { self .table_name } (
3447 key TEXT PRIMARY KEY,
@@ -37,19 +50,50 @@ async def initialize(self):
3750 )
3851 """
3952 )
40- await db .commit ()
53+ await self ._conn .commit ()
54+ else :
55+ # For file-based databases, just create the table
56+ async with aiosqlite .connect (self .db_path ) as db :
57+ await db .execute (
58+ f"""
59+ CREATE TABLE IF NOT EXISTS { self .table_name } (
60+ key TEXT PRIMARY KEY,
61+ value TEXT,
62+ expiration TIMESTAMP
63+ )
64+ """
65+ )
66+ await db .commit ()
67+
68+ async def shutdown (self ):
69+ """Close the persistent connection (only for in-memory databases)."""
70+ if self ._conn :
71+ await self ._conn .close ()
72+ self ._conn = None
4173
4274 async def set (self , key : str , value : str , expiration : datetime | None = None ) -> None :
43- async with aiosqlite .connect (self .db_path ) as db :
44- await db .execute (
75+ if self ._conn :
76+ # In-memory database with persistent connection
77+ await self ._conn .execute (
4578 f"INSERT OR REPLACE INTO { self .table_name } (key, value, expiration) VALUES (?, ?, ?)" ,
4679 (key , value , expiration ),
4780 )
48- await db .commit ()
81+ await self ._conn .commit ()
82+ else :
83+ # File-based database with connection per operation
84+ async with aiosqlite .connect (self .db_path ) as db :
85+ await db .execute (
86+ f"INSERT OR REPLACE INTO { self .table_name } (key, value, expiration) VALUES (?, ?, ?)" ,
87+ (key , value , expiration ),
88+ )
89+ await db .commit ()
4990
5091 async def get (self , key : str ) -> str | None :
51- async with aiosqlite .connect (self .db_path ) as db :
52- async with db .execute (f"SELECT value, expiration FROM { self .table_name } WHERE key = ?" , (key ,)) as cursor :
92+ if self ._conn :
93+ # In-memory database with persistent connection
94+ async with self ._conn .execute (
95+ f"SELECT value, expiration FROM { self .table_name } WHERE key = ?" , (key ,)
96+ ) as cursor :
5397 row = await cursor .fetchone ()
5498 if row is None :
5599 return None
@@ -58,15 +102,36 @@ async def get(self, key: str) -> str | None:
58102 logger .warning (f"Expected string value for key { key } , got { type (value )} , returning None" )
59103 return None
60104 return value
105+ else :
106+ # File-based database with connection per operation
107+ async with aiosqlite .connect (self .db_path ) as db :
108+ async with db .execute (
109+ f"SELECT value, expiration FROM { self .table_name } WHERE key = ?" , (key ,)
110+ ) as cursor :
111+ row = await cursor .fetchone ()
112+ if row is None :
113+ return None
114+ value , expiration = row
115+ if not isinstance (value , str ):
116+ logger .warning (f"Expected string value for key { key } , got { type (value )} , returning None" )
117+ return None
118+ return value
61119
62120 async def delete (self , key : str ) -> None :
63- async with aiosqlite .connect (self .db_path ) as db :
64- await db .execute (f"DELETE FROM { self .table_name } WHERE key = ?" , (key ,))
65- await db .commit ()
121+ if self ._conn :
122+ # In-memory database with persistent connection
123+ await self ._conn .execute (f"DELETE FROM { self .table_name } WHERE key = ?" , (key ,))
124+ await self ._conn .commit ()
125+ else :
126+ # File-based database with connection per operation
127+ async with aiosqlite .connect (self .db_path ) as db :
128+ await db .execute (f"DELETE FROM { self .table_name } WHERE key = ?" , (key ,))
129+ await db .commit ()
66130
67131 async def values_in_range (self , start_key : str , end_key : str ) -> list [str ]:
68- async with aiosqlite .connect (self .db_path ) as db :
69- async with db .execute (
132+ if self ._conn :
133+ # In-memory database with persistent connection
134+ async with self ._conn .execute (
70135 f"SELECT key, value, expiration FROM { self .table_name } WHERE key >= ? AND key <= ?" ,
71136 (start_key , end_key ),
72137 ) as cursor :
@@ -75,13 +140,35 @@ async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
75140 _ , value , _ = row
76141 result .append (value )
77142 return result
143+ else :
144+ # File-based database with connection per operation
145+ async with aiosqlite .connect (self .db_path ) as db :
146+ async with db .execute (
147+ f"SELECT key, value, expiration FROM { self .table_name } WHERE key >= ? AND key <= ?" ,
148+ (start_key , end_key ),
149+ ) as cursor :
150+ result = []
151+ async for row in cursor :
152+ _ , value , _ = row
153+ result .append (value )
154+ return result
78155
79156 async def keys_in_range (self , start_key : str , end_key : str ) -> list [str ]:
80157 """Get all keys in the given range."""
81- async with aiosqlite .connect (self .db_path ) as db :
82- cursor = await db .execute (
158+ if self ._conn :
159+ # In-memory database with persistent connection
160+ cursor = await self ._conn .execute (
83161 f"SELECT key FROM { self .table_name } WHERE key >= ? AND key <= ?" ,
84162 (start_key , end_key ),
85163 )
86164 rows = await cursor .fetchall ()
87165 return [row [0 ] for row in rows ]
166+ else :
167+ # File-based database with connection per operation
168+ async with aiosqlite .connect (self .db_path ) as db :
169+ cursor = await db .execute (
170+ f"SELECT key FROM { self .table_name } WHERE key >= ? AND key <= ?" ,
171+ (start_key , end_key ),
172+ )
173+ rows = await cursor .fetchall ()
174+ return [row [0 ] for row in rows ]
0 commit comments