Skip to content

Commit c015fd8

Browse files
committed
Fix sqlite3 Connection initialization check
Add proper __init__ validation for sqlite3.Connection to ensure base class __init__ is called before using connection methods. This fixes the test_connection_constructor_call_check test case. Changes: - Modified Connection.py_new to detect subclassing - For base Connection class, initialization happens immediately in py_new - For subclassed Connection, db is initialized as None - Added __init__ method that performs actual database initialization - Updated _db_lock error message to match CPython: 'Base Connection.__init__ not called.' This ensures CPython compatibility where attempting to use a Connection subclass instance without calling the base __init__ raises ProgrammingError.
1 parent 3b48dcc commit c015fd8

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

Lib/test/test_sqlite3/test_regression.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,6 @@ def test_str_subclass(self):
221221
class MyStr(str): pass
222222
self.con.execute("select ?", (MyStr("abc"),))
223223

224-
# TODO: RUSTPYTHON
225-
@unittest.expectedFailure
226224
def test_connection_constructor_call_check(self):
227225
"""
228226
Verifies that connection methods check whether base class __init__ was

stdlib/src/sqlite.rs

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,38 @@ mod _sqlite {
851851
type Args = ConnectArgs;
852852

853853
fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult {
854-
Ok(Self::new(args, vm)?.into_ref_with_type(vm, cls)?.into())
854+
let text_factory = PyStr::class(&vm.ctx).to_owned().into_object();
855+
856+
// For non-subclassed Connection, initialize in __new__
857+
// For subclassed Connection, leave db as None and require __init__ to be called
858+
let is_base_class = cls.is(Connection::class(&vm.ctx).as_object());
859+
860+
let db = if is_base_class {
861+
// Initialize immediately for base class
862+
let path = args.database.to_cstring(vm)?;
863+
let db = Sqlite::from(SqliteRaw::open(path.as_ptr(), args.uri, vm)?);
864+
let timeout = (args.timeout * 1000.0) as c_int;
865+
db.busy_timeout(timeout);
866+
if let Some(isolation_level) = &args.isolation_level {
867+
begin_statement_ptr_from_isolation_level(isolation_level, vm)?;
868+
}
869+
Some(db)
870+
} else {
871+
// For subclasses, require __init__ to be called
872+
None
873+
};
874+
875+
let conn = Self {
876+
db: PyMutex::new(db),
877+
detect_types: args.detect_types,
878+
isolation_level: PyAtomicRef::from(args.isolation_level),
879+
check_same_thread: args.check_same_thread,
880+
thread_ident: std::thread::current().id(),
881+
row_factory: PyAtomicRef::from(None),
882+
text_factory: PyAtomicRef::from(text_factory),
883+
};
884+
885+
Ok(conn.into_ref_with_type(vm, cls)?.into())
855886
}
856887
}
857888

@@ -873,25 +904,24 @@ mod _sqlite {
873904

874905
#[pyclass(with(Constructor, Callable), flags(BASETYPE))]
875906
impl Connection {
876-
fn new(args: ConnectArgs, vm: &VirtualMachine) -> PyResult<Self> {
907+
#[pymethod]
908+
fn __init__(&self, args: ConnectArgs, vm: &VirtualMachine) -> PyResult<()> {
909+
let mut guard = self.db.lock();
910+
if guard.is_some() {
911+
// Already initialized
912+
return Ok(());
913+
}
914+
877915
let path = args.database.to_cstring(vm)?;
878916
let db = Sqlite::from(SqliteRaw::open(path.as_ptr(), args.uri, vm)?);
879917
let timeout = (args.timeout * 1000.0) as c_int;
880918
db.busy_timeout(timeout);
881919
if let Some(isolation_level) = &args.isolation_level {
882920
begin_statement_ptr_from_isolation_level(isolation_level, vm)?;
883921
}
884-
let text_factory = PyStr::class(&vm.ctx).to_owned().into_object();
885922

886-
Ok(Self {
887-
db: PyMutex::new(Some(db)),
888-
detect_types: args.detect_types,
889-
isolation_level: PyAtomicRef::from(args.isolation_level),
890-
check_same_thread: args.check_same_thread,
891-
thread_ident: std::thread::current().id(),
892-
row_factory: PyAtomicRef::from(None),
893-
text_factory: PyAtomicRef::from(text_factory),
894-
})
923+
*guard = Some(db);
924+
Ok(())
895925
}
896926

897927
fn db_lock(&self, vm: &VirtualMachine) -> PyResult<PyMappedMutexGuard<'_, Sqlite>> {
@@ -908,7 +938,7 @@ mod _sqlite {
908938
} else {
909939
Err(new_programming_error(
910940
vm,
911-
"Cannot operate on a closed database.".to_owned(),
941+
"Base Connection.__init__ not called.".to_owned(),
912942
))
913943
}
914944
}

0 commit comments

Comments
 (0)