diff --git a/src/packages/dumbo/src/core/connections/connection.ts b/src/packages/dumbo/src/core/connections/connection.ts index c8e6cc85..dae9e618 100644 --- a/src/packages/dumbo/src/core/connections/connection.ts +++ b/src/packages/dumbo/src/core/connections/connection.ts @@ -45,7 +45,10 @@ export type CreateConnectionOptions< close: (client: DbClient) => Promise; initTransaction: ( connection: () => ConnectionType, - ) => (client: Promise) => DatabaseTransaction; + ) => ( + client: Promise, + options?: { close: (client: DbClient, error?: unknown) => Promise }, + ) => DatabaseTransaction; executor: () => Executor; }; diff --git a/src/packages/dumbo/src/core/connections/transaction.ts b/src/packages/dumbo/src/core/connections/transaction.ts index e6b6683e..7fefba1d 100644 --- a/src/packages/dumbo/src/core/connections/transaction.ts +++ b/src/packages/dumbo/src/core/connections/transaction.ts @@ -70,12 +70,27 @@ export const transactionFactoryWithDbClient = < connect: () => Promise, initTransaction: ( client: Promise, + options?: { close: (client: DbClient, error?: unknown) => Promise }, ) => DatabaseTransaction, -): DatabaseTransactionFactory => ({ - transaction: () => initTransaction(connect()), - withTransaction: (handle) => - executeInTransaction(initTransaction(connect()), handle), -}); +): DatabaseTransactionFactory => { + let currentTransaction: DatabaseTransaction | undefined = + undefined; + + const getOrInitCurrentTransaction = () => + currentTransaction ?? + (currentTransaction = initTransaction(connect(), { + close: () => { + currentTransaction = undefined; + return Promise.resolve(); + }, + })); + + return { + transaction: getOrInitCurrentTransaction, + withTransaction: (handle) => + executeInTransaction(getOrInitCurrentTransaction(), handle), + }; +}; const wrapInConnectionClosure = async < ConnectionType extends Connection = Connection, diff --git a/src/packages/dumbo/src/storage/postgresql/pg/connections/transaction.ts b/src/packages/dumbo/src/storage/postgresql/pg/connections/transaction.ts index 2755413b..9dfc16ce 100644 --- a/src/packages/dumbo/src/storage/postgresql/pg/connections/transaction.ts +++ b/src/packages/dumbo/src/storage/postgresql/pg/connections/transaction.ts @@ -30,15 +30,19 @@ export const nodePostgresTransaction = commit: async () => { const client = await getClient; - await client.query('COMMIT'); - - if (options?.close) await options?.close(client); + try { + await client.query('COMMIT'); + } finally { + if (options?.close) await options?.close(client); + } }, rollback: async (error?: unknown) => { const client = await getClient; - await client.query('ROLLBACK'); - - if (options?.close) await options?.close(client, error); + try { + await client.query('ROLLBACK'); + } finally { + if (options?.close) await options?.close(client, error); + } }, execute: sqlExecutor(nodePostgresSQLExecutor(), { connect: () => getClient, diff --git a/src/packages/dumbo/src/storage/sqlite/core/connections/index.ts b/src/packages/dumbo/src/storage/sqlite/core/connections/index.ts index 733dcfe0..70151e17 100644 --- a/src/packages/dumbo/src/storage/sqlite/core/connections/index.ts +++ b/src/packages/dumbo/src/storage/sqlite/core/connections/index.ts @@ -49,6 +49,7 @@ export type SQLitePoolConnectionOptions< type: 'PoolClient'; connect: Promise; close: (client: SQLitePoolClient) => Promise; + allowNestedTransactions: boolean; }; export type SQLiteClientConnectionOptions< @@ -58,6 +59,7 @@ export type SQLiteClientConnectionOptions< type: 'Client'; connect: Promise; close: (client: SQLiteClient) => Promise; + allowNestedTransactions: boolean; }; export type SQLiteClientConnection< @@ -79,31 +81,65 @@ export const sqliteClientConnection = < >( options: SQLiteClientConnectionOptions, ): SQLiteClientConnection => { - const { connect, close } = options; + const { connect, close, allowNestedTransactions } = options; return createConnection({ connector: options.connector, connect, close, initTransaction: (connection) => - sqliteTransaction(options.connector, connection), + sqliteTransaction(options.connector, connection, allowNestedTransactions), executor: () => sqliteSQLExecutor(options.connector), }); }; +export type TransactionNestingCounter = { + increment: () => void; + decrement: () => void; + reset: () => void; + level: number; +}; + +export const transactionNestingCounter = (): TransactionNestingCounter => { + let transactionLevel = 0; + + return { + reset: () => { + transactionLevel = 0; + }, + increment: () => { + transactionLevel++; + }, + decrement: () => { + transactionLevel--; + + if (transactionLevel < 0) { + throw new Error('Transaction level is out of bounds'); + } + }, + get level() { + return transactionLevel; + }, + }; +}; + export const sqlitePoolClientConnection = < ConnectorType extends SQLiteConnectorType = SQLiteConnectorType, >( options: SQLitePoolConnectionOptions, ): SQLitePoolClientConnection => { - const { connect, close } = options; + const { connect, close, allowNestedTransactions } = options; return createConnection({ connector: options.connector, connect, close, initTransaction: (connection) => - sqliteTransaction(options.connector, connection), + sqliteTransaction( + options.connector, + connection, + allowNestedTransactions ?? false, + ), executor: () => sqliteSQLExecutor(options.connector), }); }; @@ -113,11 +149,13 @@ export function sqliteConnection< >( options: SQLitePoolConnectionOptions, ): SQLitePoolClientConnection; + export function sqliteConnection< ConnectorType extends SQLiteConnectorType = SQLiteConnectorType, >( options: SQLiteClientConnectionOptions, ): SQLiteClientConnection; + export function sqliteConnection< ConnectorType extends SQLiteConnectorType = SQLiteConnectorType, >( diff --git a/src/packages/dumbo/src/storage/sqlite/core/pool/pool.ts b/src/packages/dumbo/src/storage/sqlite/core/pool/pool.ts index 20d8ee75..a3fb76aa 100644 --- a/src/packages/dumbo/src/storage/sqlite/core/pool/pool.ts +++ b/src/packages/dumbo/src/storage/sqlite/core/pool/pool.ts @@ -58,6 +58,7 @@ export const sqliteSingletonClientPool = < options: { connector: ConnectorType; database?: string | undefined; + allowNestedTransactions?: boolean; } & SQLiteFileNameOrConnectionString, ): SQLiteAmbientClientPool => { const { connector } = options; @@ -80,6 +81,7 @@ export const sqliteSingletonClientPool = < connector, type: 'Client', connect, + allowNestedTransactions: options.allowNestedTransactions ?? false, close: () => Promise.resolve(), })); }; @@ -103,9 +105,10 @@ export const sqliteAlwaysNewClientPool = < options: { connector: ConnectorType; database?: string | undefined; + allowNestedTransactions?: boolean; } & SQLiteFileNameOrConnectionString, ): SQLiteAmbientClientPool => { - const { connector } = options; + const { connector, allowNestedTransactions } = options; return createConnectionPool({ connector: connector, @@ -124,6 +127,7 @@ export const sqliteAlwaysNewClientPool = < connector, type: 'Client', connect, + allowNestedTransactions: allowNestedTransactions ?? false, close: (client) => client.close(), }); }, @@ -135,8 +139,9 @@ export const sqliteAmbientClientPool = < >(options: { connector: ConnectorType; client: SQLiteClient; + allowNestedTransactions: boolean; }): SQLiteAmbientClientPool => { - const { client, connector } = options; + const { client, connector, allowNestedTransactions } = options; const getConnection = () => { const connect = Promise.resolve(client); @@ -145,6 +150,7 @@ export const sqliteAmbientClientPool = < connector, type: 'Client', connect, + allowNestedTransactions, close: () => Promise.resolve(), }); }; @@ -178,6 +184,7 @@ export type SQLitePoolPooledOptions< connector: ConnectorType; pooled?: true; singleton?: boolean; + allowNestedTransactions?: boolean; }; export type SQLitePoolNotPooledOptions< @@ -188,11 +195,13 @@ export type SQLitePoolNotPooledOptions< pooled?: false; client: SQLiteClient; singleton?: true; + allowNestedTransactions?: boolean; } | { connector: ConnectorType; pooled?: boolean; singleton?: boolean; + allowNestedTransactions?: boolean; } | { connector: ConnectorType; @@ -201,6 +210,7 @@ export type SQLitePoolNotPooledOptions< | SQLiteClientConnection; pooled?: false; singleton?: true; + allowNestedTransactions?: boolean; }; export type SQLitePoolOptions< @@ -218,6 +228,7 @@ export function sqlitePool< options: SQLitePoolNotPooledOptions & SQLiteFileNameOrConnectionString, ): SQLiteAmbientClientPool; + export function sqlitePool< ConnectorType extends SQLiteConnectorType = SQLiteConnectorType, >( @@ -231,7 +242,11 @@ export function sqlitePool< // setSQLiteTypeParser(serializer ?? JSONSerializer); if ('client' in options && options.client) - return sqliteAmbientClientPool({ connector, client: options.client }); + return sqliteAmbientClientPool({ + connector, + client: options.client, + allowNestedTransactions: options.allowNestedTransactions ?? false, + }); if ('connection' in options && options.connection) return sqliteAmbientConnectionPool({ diff --git a/src/packages/dumbo/src/storage/sqlite/core/transactions/index.ts b/src/packages/dumbo/src/storage/sqlite/core/transactions/index.ts index 54c533d0..4a38d7ee 100644 --- a/src/packages/dumbo/src/storage/sqlite/core/transactions/index.ts +++ b/src/packages/dumbo/src/storage/sqlite/core/transactions/index.ts @@ -5,7 +5,10 @@ import { type DatabaseTransaction, } from '../../../../core'; import { sqliteSQLExecutor } from '../../core/execute'; -import type { SQLiteClientOrPoolClient } from '../connections'; +import { + transactionNestingCounter, + type SQLiteClientOrPoolClient, +} from '../connections'; export type SQLiteTransaction< ConnectorType extends SQLiteConnectorType = SQLiteConnectorType, @@ -18,31 +21,71 @@ export const sqliteTransaction = >( connector: ConnectorType, connection: () => Connection, + allowNestedTransactions: boolean, ) => ( getClient: Promise, options?: { close: (client: DbClient, error?: unknown) => Promise }, - ): DatabaseTransaction => ({ - connection: connection(), - connector, - begin: async () => { - const client = await getClient; - await client.query('BEGIN TRANSACTION'); - }, - commit: async () => { - const client = await getClient; - - await client.query('COMMIT'); - - if (options?.close) await options?.close(client); - }, - rollback: async (error?: unknown) => { - const client = await getClient; - await client.query('ROLLBACK'); - - if (options?.close) await options?.close(client, error); - }, - execute: sqlExecutor(sqliteSQLExecutor(connector), { - connect: () => getClient, - }), - }); + ): DatabaseTransaction => { + const transactionCounter = transactionNestingCounter(); + return { + connection: connection(), + connector, + begin: async function () { + const client = await getClient; + + if (allowNestedTransactions) { + if (transactionCounter.level >= 1) { + transactionCounter.increment(); + await client.query( + `SAVEPOINT transaction${transactionCounter.level}`, + ); + return; + } + + transactionCounter.increment(); + } + + await client.query('BEGIN TRANSACTION'); + }, + commit: async function () { + const client = await getClient; + + try { + if (allowNestedTransactions) { + if (transactionCounter.level > 1) { + await client.query( + `RELEASE transaction${transactionCounter.level}`, + ); + transactionCounter.decrement(); + + return; + } + + transactionCounter.reset(); + } + await client.query('COMMIT'); + } finally { + if (options?.close) await options?.close(client); + } + }, + rollback: async function (error?: unknown) { + const client = await getClient; + try { + if (allowNestedTransactions) { + if (transactionCounter.level > 1) { + transactionCounter.decrement(); + return; + } + } + + await client.query('ROLLBACK'); + } finally { + if (options?.close) await options?.close(client, error); + } + }, + execute: sqlExecutor(sqliteSQLExecutor(connector), { + connect: () => getClient, + }), + }; + }; diff --git a/src/packages/dumbo/src/storage/sqlite/core/transactions/transactions.int.spec.ts b/src/packages/dumbo/src/storage/sqlite/core/transactions/transactions.int.spec.ts new file mode 100644 index 00000000..d9542b3c --- /dev/null +++ b/src/packages/dumbo/src/storage/sqlite/core/transactions/transactions.int.spec.ts @@ -0,0 +1,378 @@ +import assert from 'assert'; +import fs from 'fs'; +import { afterEach, describe, it } from 'node:test'; +import path from 'path'; +import { fileURLToPath } from 'url'; +import { InMemorySQLiteDatabase, sqlitePool } from '..'; +import { rawSql } from '../../../../core'; + +void describe('SQLite Transactions', () => { + const inMemoryfileName: string = InMemorySQLiteDatabase; + + const testDatabasePath = path.resolve( + path.dirname(fileURLToPath(import.meta.url)), + ); + const fileName = path.resolve(testDatabasePath, 'test-transactions.db'); + + const testCases = [ + { testName: 'in-memory', fileName: inMemoryfileName }, + { testName: 'file', fileName: fileName }, + ]; + + afterEach(() => { + if (!fs.existsSync(fileName)) { + return; + } + try { + fs.unlinkSync(fileName); + } catch (error) { + console.log('Error deleting file:', error); + } + }); + + for (const { testName, fileName } of testCases) { + void describe(`transactions with ${testName} database`, () => { + void it('commits a nested transaction with pool', async () => { + const pool = sqlitePool({ + connector: 'SQLite:sqlite3', + fileName, + allowNestedTransactions: true, + }); + const connection = await pool.connection(); + + try { + await connection.execute.query( + rawSql('CREATE TABLE test_table (id INTEGER, value TEXT)'), + ); + + const result = await connection.withTransaction(async () => { + await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (2, "test") RETURNING id', + ), + ); + + const result = await connection.withTransaction( + async () => { + const result = await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (1, "test") RETURNING id', + ), + ); + return (result.rows[0]?.id as number) ?? null; + }, + ); + + return result; + }); + + assert.strictEqual(result, 1); + + const rows = await connection.execute.query( + rawSql('SELECT COUNT(*) as count FROM test_table'), + ); + + assert.strictEqual(rows.rows[0]?.count, 2); + } finally { + await connection.close(); + await pool.close(); + } + }); + void it('should fail with an error if transaction nested is false', async () => { + const pool = sqlitePool({ + connector: 'SQLite:sqlite3', + fileName, + allowNestedTransactions: false, + }); + const connection = await pool.connection(); + + try { + await connection.execute.query( + rawSql('CREATE TABLE test_table (id INTEGER, value TEXT)'), + ); + + await connection.withTransaction(async () => { + await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (2, "test") RETURNING id', + ), + ); + + const result = await connection.withTransaction( + async () => { + const result = await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (1, "test") RETURNING id', + ), + ); + return (result.rows[0]?.id as number) ?? null; + }, + ); + + return result; + }); + } catch (error) { + assert.strictEqual( + (error as Error).message, + 'SQLITE_ERROR: cannot start a transaction within a transaction', + ); + } finally { + await connection.close(); + await pool.close(); + } + }); + + void it('should try catch and roll back everything when the inner transaction errors for a pooled connection', async () => { + const pool = sqlitePool({ + connector: 'SQLite:sqlite3', + fileName, + allowNestedTransactions: true, + }); + const connection = await pool.connection(); + const connection2 = await pool.connection(); + + try { + await connection.execute.query( + rawSql('CREATE TABLE test_table (id INTEGER, value TEXT)'), + ); + + try { + await connection.withTransaction(async () => { + await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (2, "test") RETURNING id', + ), + ); + + await connection2.withTransaction(() => { + throw new Error('Intentionally throwing'); + }); + }); + } catch (error) { + assert.strictEqual( + (error as Error).message, + 'Intentionally throwing', + ); + } + const rows = await connection.execute.query( + rawSql('SELECT COUNT(*) as count FROM test_table'), + ); + + assert.strictEqual(rows.rows[0]?.count, 0); + } finally { + await connection.close(); + await pool.close(); + } + }); + void it('should try catch and roll back everything when the outer transactions errors for a pooled connection', async () => { + const pool = sqlitePool({ + connector: 'SQLite:sqlite3', + fileName, + allowNestedTransactions: true, + }); + const connection = await pool.connection(); + const connection2 = await pool.connection(); + + try { + await connection.execute.query( + rawSql('CREATE TABLE test_table (id INTEGER, value TEXT)'), + ); + + try { + await connection.withTransaction<{ + id: null | string; + }>(async () => { + await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (1, "test") RETURNING id', + ), + ); + + await connection2.withTransaction(async () => { + const result = await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (2, "test") RETURNING id', + ), + ); + return (result.rows[0]?.id as number) ?? null; + }); + + throw new Error('Intentionally throwing'); + }); + } catch (error) { + // make sure the rror is the correct one. catch but let it continue so it doesnt trigger + // the outer errors + assert.strictEqual( + (error as Error).message, + 'Intentionally throwing', + ); + } + const rows = await connection.execute.query( + rawSql('SELECT COUNT(*) as count FROM test_table'), + ); + + assert.strictEqual(rows.rows[0]?.count, 0); + } finally { + await connection.close(); + await pool.close(); + } + }); + + void it('commits a nested transaction with singleton pool', async () => { + const pool = sqlitePool({ + connector: 'SQLite:sqlite3', + fileName, + singleton: true, + allowNestedTransactions: true, + }); + const connection = await pool.connection(); + const connection2 = await pool.connection(); + + try { + await connection.execute.query( + rawSql('CREATE TABLE test_table (id INTEGER, value TEXT)'), + ); + + const result = await connection.withTransaction( + async () => { + await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (2, "test") RETURNING id', + ), + ); + + const result = await connection2.withTransaction( + async () => { + const result = await connection2.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (1, "test") RETURNING id', + ), + ); + return (result.rows[0]?.id as number) ?? null; + }, + ); + + return result; + }, + ); + + assert.strictEqual(result, 1); + + const rows = await connection.execute.query<{ count: number }>( + rawSql('SELECT COUNT(*) as count FROM test_table'), + ); + + assert.strictEqual(rows.rows[0]?.count, 2); + } finally { + await connection.close(); + await pool.close(); + } + }); + + void it('transactions errors inside the nested inner transaction for a singleton should try catch and roll back everything', async () => { + const pool = sqlitePool({ + connector: 'SQLite:sqlite3', + fileName, + singleton: true, + allowNestedTransactions: true, + }); + const connection = await pool.connection(); + const connection2 = await pool.connection(); + + try { + await connection.execute.query( + rawSql('CREATE TABLE test_table (id INTEGER, value TEXT)'), + ); + + try { + await connection.withTransaction<{ + id: null | string; + }>(async () => { + await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (2, "test") RETURNING id', + ), + ); + + const result = await connection2.withTransaction<{ + id: null | string; + }>(() => { + throw new Error('Intentionally throwing'); + }); + + return { success: true, result: result }; + }); + } catch (error) { + assert.strictEqual( + (error as Error).message, + 'Intentionally throwing', + ); + } + + const rows = await connection.execute.query( + rawSql('SELECT COUNT(*) as count FROM test_table'), + ); + + assert.strictEqual(rows.rows[0]?.count, 0); + } finally { + await connection.close(); + await pool.close(); + } + }); + void it('transactions errors inside the outer transaction for a singleton should try catch and roll back everything', async () => { + const pool = sqlitePool({ + connector: 'SQLite:sqlite3', + fileName, + singleton: true, + allowNestedTransactions: true, + }); + const connection = await pool.connection(); + const connection2 = await pool.connection(); + + try { + await connection.execute.query( + rawSql('CREATE TABLE test_table (id INTEGER, value TEXT)'), + ); + + try { + await connection.withTransaction<{ + id: null | string; + }>(async () => { + await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (1, "test") RETURNING id', + ), + ); + + await connection2.withTransaction(async () => { + const result = await connection.execute.query( + rawSql( + 'INSERT INTO test_table (id, value) VALUES (2, "test") RETURNING id', + ), + ); + return (result.rows[0]?.id as number) ?? null; + }); + + throw new Error('Intentionally throwing'); + }); + } catch (error) { + // make sure the rror is the correct one. catch but let it continue so it doesnt trigger + // the outer errors + assert.strictEqual( + (error as Error).message, + 'Intentionally throwing', + ); + } + const rows = await connection.execute.query( + rawSql('SELECT COUNT(*) as count FROM test_table'), + ); + + assert.strictEqual(rows.rows[0]?.count, 0); + } finally { + await connection.close(); + await pool.close(); + } + }); + }); + } +});