Skip to content
Merged
5 changes: 4 additions & 1 deletion src/packages/dumbo/src/core/connections/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ export type CreateConnectionOptions<
close: (client: DbClient) => Promise<void>;
initTransaction: (
connection: () => ConnectionType,
) => (client: Promise<DbClient>) => DatabaseTransaction<Connector, DbClient>;
) => (
client: Promise<DbClient>,
options?: { close: (client: DbClient, error?: unknown) => Promise<void> },
) => DatabaseTransaction<Connector, DbClient>;
executor: () => Executor;
};

Expand Down
25 changes: 20 additions & 5 deletions src/packages/dumbo/src/core/connections/transaction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,27 @@ export const transactionFactoryWithDbClient = <
connect: () => Promise<DbClient>,
initTransaction: (
client: Promise<DbClient>,
options?: { close: (client: DbClient, error?: unknown) => Promise<void> },
) => DatabaseTransaction<Connector, DbClient>,
): DatabaseTransactionFactory<Connector, DbClient> => ({
transaction: () => initTransaction(connect()),
withTransaction: (handle) =>
executeInTransaction(initTransaction(connect()), handle),
});
): DatabaseTransactionFactory<Connector, DbClient> => {
let currentTransaction: DatabaseTransaction<Connector, DbClient> | undefined =
undefined;

const getOrInitCurrentTransaction = () =>
currentTransaction ??
(currentTransaction = initTransaction(connect(), {
close: () => {
currentTransaction = undefined;
return Promise.resolve();
},
}));

return {
transaction: getOrInitCurrentTransaction,
withTransaction: (handle) =>
executeInTransaction(getOrInitCurrentTransaction(), handle),
};
};

const wrapInConnectionClosure = async <
ConnectionType extends Connection = Connection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@ export const nodePostgresTransaction =
commit: async () => {
const client = await getClient;

await client.query('COMMIT');

if (options?.close) await options?.close(client);
try {
await client.query('COMMIT');
} finally {
if (options?.close) await options?.close(client);
}
},
rollback: async (error?: unknown) => {
const client = await getClient;
await client.query('ROLLBACK');

if (options?.close) await options?.close(client, error);
try {
await client.query('ROLLBACK');
} finally {
if (options?.close) await options?.close(client, error);
}
},
execute: sqlExecutor(nodePostgresSQLExecutor(), {
connect: () => getClient,
Expand Down
46 changes: 42 additions & 4 deletions src/packages/dumbo/src/storage/sqlite/core/connections/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export type SQLitePoolConnectionOptions<
type: 'PoolClient';
connect: Promise<SQLitePoolClient>;
close: (client: SQLitePoolClient) => Promise<void>;
allowNestedTransactions: boolean;
};

export type SQLiteClientConnectionOptions<
Expand All @@ -58,6 +59,7 @@ export type SQLiteClientConnectionOptions<
type: 'Client';
connect: Promise<SQLiteClient>;
close: (client: SQLiteClient) => Promise<void>;
allowNestedTransactions: boolean;
};

export type SQLiteClientConnection<
Expand All @@ -79,31 +81,65 @@ export const sqliteClientConnection = <
>(
options: SQLiteClientConnectionOptions<ConnectorType>,
): SQLiteClientConnection<ConnectorType> => {
const { connect, close } = options;
const { connect, close, allowNestedTransactions } = options;

return createConnection({
connector: options.connector,
connect,
close,
initTransaction: (connection) =>
sqliteTransaction(options.connector, connection),
sqliteTransaction(options.connector, connection, allowNestedTransactions),
executor: () => sqliteSQLExecutor(options.connector),
});
};

export type TransactionNestingCounter = {
increment: () => void;
decrement: () => void;
reset: () => void;
level: number;
};

export const transactionNestingCounter = (): TransactionNestingCounter => {
let transactionLevel = 0;

return {
reset: () => {
transactionLevel = 0;
},
increment: () => {
transactionLevel++;
},
decrement: () => {
transactionLevel--;

if (transactionLevel < 0) {
throw new Error('Transaction level is out of bounds');
}
},
get level() {
return transactionLevel;
},
};
};

export const sqlitePoolClientConnection = <
ConnectorType extends SQLiteConnectorType = SQLiteConnectorType,
>(
options: SQLitePoolConnectionOptions<ConnectorType>,
): SQLitePoolClientConnection<ConnectorType> => {
const { connect, close } = options;
const { connect, close, allowNestedTransactions } = options;

return createConnection({
connector: options.connector,
connect,
close,
initTransaction: (connection) =>
sqliteTransaction(options.connector, connection),
sqliteTransaction(
options.connector,
connection,
allowNestedTransactions ?? false,
),
executor: () => sqliteSQLExecutor(options.connector),
});
};
Expand All @@ -113,11 +149,13 @@ export function sqliteConnection<
>(
options: SQLitePoolConnectionOptions<ConnectorType>,
): SQLitePoolClientConnection;

export function sqliteConnection<
ConnectorType extends SQLiteConnectorType = SQLiteConnectorType,
>(
options: SQLiteClientConnectionOptions<ConnectorType>,
): SQLiteClientConnection;

export function sqliteConnection<
ConnectorType extends SQLiteConnectorType = SQLiteConnectorType,
>(
Expand Down
21 changes: 18 additions & 3 deletions src/packages/dumbo/src/storage/sqlite/core/pool/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ export const sqliteSingletonClientPool = <
options: {
connector: ConnectorType;
database?: string | undefined;
allowNestedTransactions?: boolean;
} & SQLiteFileNameOrConnectionString,
): SQLiteAmbientClientPool<ConnectorType> => {
const { connector } = options;
Expand All @@ -80,6 +81,7 @@ export const sqliteSingletonClientPool = <
connector,
type: 'Client',
connect,
allowNestedTransactions: options.allowNestedTransactions ?? false,
close: () => Promise.resolve(),
}));
};
Expand All @@ -103,9 +105,10 @@ export const sqliteAlwaysNewClientPool = <
options: {
connector: ConnectorType;
database?: string | undefined;
allowNestedTransactions?: boolean;
} & SQLiteFileNameOrConnectionString,
): SQLiteAmbientClientPool<ConnectorType> => {
const { connector } = options;
const { connector, allowNestedTransactions } = options;

return createConnectionPool({
connector: connector,
Expand All @@ -124,6 +127,7 @@ export const sqliteAlwaysNewClientPool = <
connector,
type: 'Client',
connect,
allowNestedTransactions: allowNestedTransactions ?? false,
close: (client) => client.close(),
});
},
Expand All @@ -135,8 +139,9 @@ export const sqliteAmbientClientPool = <
>(options: {
connector: ConnectorType;
client: SQLiteClient;
allowNestedTransactions: boolean;
}): SQLiteAmbientClientPool<ConnectorType> => {
const { client, connector } = options;
const { client, connector, allowNestedTransactions } = options;

const getConnection = () => {
const connect = Promise.resolve(client);
Expand All @@ -145,6 +150,7 @@ export const sqliteAmbientClientPool = <
connector,
type: 'Client',
connect,
allowNestedTransactions,
close: () => Promise.resolve(),
});
};
Expand Down Expand Up @@ -178,6 +184,7 @@ export type SQLitePoolPooledOptions<
connector: ConnectorType;
pooled?: true;
singleton?: boolean;
allowNestedTransactions?: boolean;
};

export type SQLitePoolNotPooledOptions<
Expand All @@ -188,11 +195,13 @@ export type SQLitePoolNotPooledOptions<
pooled?: false;
client: SQLiteClient;
singleton?: true;
allowNestedTransactions?: boolean;
}
| {
connector: ConnectorType;
pooled?: boolean;
singleton?: boolean;
allowNestedTransactions?: boolean;
}
| {
connector: ConnectorType;
Expand All @@ -201,6 +210,7 @@ export type SQLitePoolNotPooledOptions<
| SQLiteClientConnection<ConnectorType>;
pooled?: false;
singleton?: true;
allowNestedTransactions?: boolean;
};

export type SQLitePoolOptions<
Expand All @@ -218,6 +228,7 @@ export function sqlitePool<
options: SQLitePoolNotPooledOptions<ConnectorType> &
SQLiteFileNameOrConnectionString,
): SQLiteAmbientClientPool<ConnectorType>;

export function sqlitePool<
ConnectorType extends SQLiteConnectorType = SQLiteConnectorType,
>(
Expand All @@ -231,7 +242,11 @@ export function sqlitePool<
// setSQLiteTypeParser(serializer ?? JSONSerializer);

if ('client' in options && options.client)
return sqliteAmbientClientPool({ connector, client: options.client });
return sqliteAmbientClientPool({
connector,
client: options.client,
allowNestedTransactions: options.allowNestedTransactions ?? false,
});

if ('connection' in options && options.connection)
return sqliteAmbientConnectionPool({
Expand Down
93 changes: 68 additions & 25 deletions src/packages/dumbo/src/storage/sqlite/core/transactions/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import {
type DatabaseTransaction,
} from '../../../../core';
import { sqliteSQLExecutor } from '../../core/execute';
import type { SQLiteClientOrPoolClient } from '../connections';
import {
transactionNestingCounter,
type SQLiteClientOrPoolClient,
} from '../connections';

export type SQLiteTransaction<
ConnectorType extends SQLiteConnectorType = SQLiteConnectorType,
Expand All @@ -18,31 +21,71 @@ export const sqliteTransaction =
>(
connector: ConnectorType,
connection: () => Connection<ConnectorType, DbClient>,
allowNestedTransactions: boolean,
) =>
(
getClient: Promise<DbClient>,
options?: { close: (client: DbClient, error?: unknown) => Promise<void> },
): DatabaseTransaction<ConnectorType, DbClient> => ({
connection: connection(),
connector,
begin: async () => {
const client = await getClient;
await client.query('BEGIN TRANSACTION');
},
commit: async () => {
const client = await getClient;

await client.query('COMMIT');

if (options?.close) await options?.close(client);
},
rollback: async (error?: unknown) => {
const client = await getClient;
await client.query('ROLLBACK');

if (options?.close) await options?.close(client, error);
},
execute: sqlExecutor(sqliteSQLExecutor(connector), {
connect: () => getClient,
}),
});
): DatabaseTransaction<ConnectorType, DbClient> => {
const transactionCounter = transactionNestingCounter();
return {
connection: connection(),
connector,
begin: async function () {
const client = await getClient;

if (allowNestedTransactions) {
if (transactionCounter.level >= 1) {
transactionCounter.increment();
await client.query(
`SAVEPOINT transaction${transactionCounter.level}`,
);
return;
}

transactionCounter.increment();
}

await client.query('BEGIN TRANSACTION');
},
commit: async function () {
const client = await getClient;

try {
if (allowNestedTransactions) {
if (transactionCounter.level > 1) {
await client.query(
`RELEASE transaction${transactionCounter.level}`,
);
transactionCounter.decrement();

return;
}

transactionCounter.reset();
}
await client.query('COMMIT');
} finally {
if (options?.close) await options?.close(client);
}
},
rollback: async function (error?: unknown) {
const client = await getClient;
try {
if (allowNestedTransactions) {
if (transactionCounter.level > 1) {
transactionCounter.decrement();
return;
}
}

await client.query('ROLLBACK');
} finally {
if (options?.close) await options?.close(client, error);
}
},
execute: sqlExecutor(sqliteSQLExecutor(connector), {
connect: () => getClient,
}),
};
};
Loading