diff --git a/chdb/session/state.py b/chdb/session/state.py index be8552b7168..dc3e9fdfbb4 100644 --- a/chdb/session/state.py +++ b/chdb/session/state.py @@ -40,6 +40,7 @@ class Session: """ def __init__(self, path=None): + self._conn = None global g_session, g_session_path if g_session is not None: warnings.warn( diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index 105e98da52d..449f43f0d69 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -157,6 +157,7 @@ void applySettingsOverridesForLocal(ContextMutablePtr context) LocalServer::~LocalServer() { + cleanup(); resetQueryOutputVector(); } diff --git a/programs/local/LocalServer.h b/programs/local/LocalServer.h index ec8e2dd4349..3cc42364a3b 100644 --- a/programs/local/LocalServer.h +++ b/programs/local/LocalServer.h @@ -89,11 +89,6 @@ class LocalServer : public ClientApplicationBase, public Loggers return local_connection->getCHDBProgress().read_bytes; } - void chdbCleanup() - { - cleanup(); - } - private: void cleanStreamingQuery(); }; diff --git a/programs/local/chdb.cpp b/programs/local/chdb.cpp index 3a36e4da2bf..16fe99c9a36 100644 --- a/programs/local/chdb.cpp +++ b/programs/local/chdb.cpp @@ -18,12 +18,12 @@ static std::mutex CHDB_MUTEX; chdb_conn * global_conn_ptr = nullptr; std::string global_db_path; -static DB::LocalServer * bgClickHouseLocal(int argc, char ** argv) +static std::unique_ptr bgClickHouseLocal(int argc, char ** argv) { - DB::LocalServer * app = nullptr; + std::unique_ptr app; try { - app = new DB::LocalServer(); + app = std::make_unique(); app->setBackground(true); app->init(argc, argv); int ret = app->run(); @@ -31,30 +31,24 @@ static DB::LocalServer * bgClickHouseLocal(int argc, char ** argv) { auto err_msg = app->getErrorMsg(); LOG_ERROR(&app->logger(), "Error running bgClickHouseLocal: {}", err_msg); - delete app; - app = nullptr; throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Error running bgClickHouseLocal: {}", err_msg); } return app; } catch (const DB::Exception & e) { - delete app; throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "bgClickHouseLocal {}", DB::getExceptionMessage(e, false)); } catch (const Poco::Exception & e) { - delete app; throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "bgClickHouseLocal {}", e.displayText()); } catch (const std::exception & e) { - delete app; throw std::domain_error(e.what()); } catch (...) { - delete app; throw std::domain_error(DB::getCurrentExceptionMessage(true)); } } @@ -545,10 +539,11 @@ chdb_conn ** connect_chdb(int argc, char ** argv) [&]() { auto * queue = static_cast(conn->queue); + std::unique_ptr server; try { - DB::LocalServer * server = bgClickHouseLocal(argc, argv); - conn->server = server; + server = bgClickHouseLocal(argc, argv); + conn->server = nullptr; conn->connected = true; global_conn_ptr = conn; @@ -570,16 +565,7 @@ chdb_conn ** connect_chdb(int argc, char ** argv) if (queue->shutdown) { - try - { - server->chdbCleanup(); - delete server; - } - catch (...) - { - // Log error but continue shutdown - LOG_ERROR(&Poco::Logger::get("LocalServer"), "Error during server cleanup"); - } + server.reset(); queue->cleanup_done = true; queue->query_cv.notify_all(); break; @@ -588,7 +574,7 @@ chdb_conn ** connect_chdb(int argc, char ** argv) } CHDB::QueryRequestBase & req = *(queue->current_query); - auto result = createQueryResult(server, req); + auto result = createQueryResult(server.get(), req); bool is_end = result.second; { diff --git a/tests/test_open_session_after_failure.py b/tests/test_open_session_after_failure.py new file mode 100644 index 00000000000..4935585b668 --- /dev/null +++ b/tests/test_open_session_after_failure.py @@ -0,0 +1,36 @@ +#!python3 + +import unittest +import shutil +from chdb import session + + +test_dir1 = ".test_open_session_after_failure" +test_dir2 = "/usr/bin" + + +class TestStateful(unittest.TestCase): + def setUp(self) -> None: + shutil.rmtree(test_dir1, ignore_errors=True) + return super().setUp() + + def tearDown(self) -> None: + shutil.rmtree(test_dir1, ignore_errors=True) + return super().tearDown() + + def test_path(self): + # Test that creating session with invalid path (read-only directory) raises exception + with self.assertRaises(Exception): + sess = session.Session(test_dir2) + + # Test that creating session with valid path works after failure + sess = session.Session(test_dir1) + + ret = sess.query("select 'aaaaa'") + self.assertEqual(str(ret), "\"aaaaa\"\n") + + sess.close() + + +if __name__ == '__main__': + unittest.main()