Skip to content

Commit b0f26d7

Browse files
keeferrourkesvc-squareup-copybara
authored andcommitted
Add rollback hooks to Transacter sessions
This change adds a new callback: `Session.onRollback { e: Throwable -> }` to execute work that should be done if the work done in the transaction failed and resulted in a rollback. Misk supports multiple Transacters and sessions, but the API/impl are the same for each. GitOrigin-RevId: 88e4f96ea23f6eb2b383dca8a3e02aadd148d5f4
1 parent e5235be commit b0f26d7

File tree

11 files changed

+173
-27
lines changed

11 files changed

+173
-27
lines changed

misk-hibernate/src/main/kotlin/misk/hibernate/RealTransacter.kt

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ private constructor(
123123
}
124124
}
125125
}
126-
}
126+
},
127127
)
128128
},
129129
5,
@@ -172,7 +172,9 @@ private constructor(
172172
if (config.type.isVitess) {
173173
hibernateSession.doReturningWork { conn ->
174174
previousTarget = conn.catalog
175-
if (previousTarget != primaryTarget) { conn.catalog = primaryTarget }
175+
if (previousTarget != primaryTarget) {
176+
conn.catalog = primaryTarget
177+
}
176178
}
177179
}
178180

@@ -193,7 +195,9 @@ private constructor(
193195
// Restore to the same destination the lock was acquired on.
194196
hibernateSession.doReturningWork { conn ->
195197
previousTarget = conn.catalog
196-
if (previousTarget != primaryTarget) { conn.catalog = primaryTarget }
198+
if (previousTarget != primaryTarget) {
199+
conn.catalog = primaryTarget
200+
}
197201
}
198202
}
199203
hibernateSession.tryReleaseLock(lockKey)
@@ -278,6 +282,7 @@ private constructor(
278282
if (transaction.isActive) {
279283
try {
280284
transaction.rollback()
285+
session.onSessionClose { session.runRollbackHooks(e) }
281286
} catch (suppressed: Exception) {
282287
rethrow.addSuppressed(suppressed)
283288
}
@@ -332,7 +337,7 @@ private constructor(
332337
"No reader is configured for replica reads, pass in both a writer and reader qualifier " +
333338
"and the full DataSourceClustersConfig into HibernateModule, like this:\n" +
334339
"\tinstall(HibernateModule(AppDb::class, AppReaderDb::class, " +
335-
"config.data_source_clusters[\"name\"]))"
340+
"config.data_source_clusters[\"name\"]))",
336341
)
337342
}
338343

@@ -503,6 +508,8 @@ private constructor(
503508
private val preCommitHooks = mutableListOf<() -> Unit>()
504509
private val postCommitHooks = mutableListOf<() -> Unit>()
505510
private val sessionCloseHooks = mutableListOf<() -> Unit>()
511+
private val rollbackHooks = mutableListOf<(error: Throwable) -> Unit>()
512+
506513
internal var inTransaction = false
507514

508515
init {
@@ -616,7 +623,7 @@ private constructor(
616623

617624
internal fun preCommit() {
618625
preCommitHooks.forEach { preCommitHook ->
619-
// Propagate hook exceptions up to the transacter so that the the transaction is rolled
626+
// Propagate hook exceptions up to the transacter so that the transaction is rolled
620627
// back and the error gets returned to the application.
621628
preCommitHook()
622629
}
@@ -648,6 +655,14 @@ private constructor(
648655
sessionCloseHooks.add(work)
649656
}
650657

658+
override fun onRollback(work: (error: Throwable) -> Unit) {
659+
rollbackHooks.add(work)
660+
}
661+
662+
internal fun runRollbackHooks(error: Throwable) {
663+
rollbackHooks.forEach { rollbackHook -> rollbackHook(error) }
664+
}
665+
651666
override fun <T> withoutChecks(vararg checks: Check, body: () -> T): T {
652667
return CheckDisabler.withoutChecks(*checks) { body() }
653668
}

misk-hibernate/src/test/kotlin/misk/hibernate/TransacterTest.kt

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,39 @@ abstract class TransacterTest {
695695
assertThat(postCommitHooksTriggered).isEmpty()
696696
}
697697

698+
@Test
699+
fun rollbackHooksCalledOnRollbackOnly() {
700+
val rollbackHooksTriggered = mutableListOf<String>()
701+
702+
// Happy path.
703+
transacter.transaction { session ->
704+
session.onRollback { error ->
705+
rollbackHooksTriggered.add("never")
706+
error("this should never have happened")
707+
}
708+
}
709+
710+
assertThat(rollbackHooksTriggered).isEmpty()
711+
712+
// Rollback path.
713+
assertThrows<IllegalStateException> {
714+
transacter.transaction { session ->
715+
session.onRollback { error ->
716+
assertThat(error).hasMessage("bad things happened here")
717+
assertThat(transacter.inTransaction).isFalse
718+
rollbackHooksTriggered.add("first")
719+
}
720+
session.onRollback { error ->
721+
assertThat(error).hasMessage("bad things happened here")
722+
assertThat(transacter.inTransaction).isFalse
723+
rollbackHooksTriggered.add("second")
724+
}
725+
error("bad things happened here")
726+
}
727+
}
728+
assertThat(rollbackHooksTriggered).containsExactly("first", "second")
729+
}
730+
698731
@Test
699732
fun errorInPostCommitHookDoesNotRollback() {
700733
val postCommitHooksTriggered = mutableListOf<String>()

misk-jdbc/api/misk-jdbc.api

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,10 +419,12 @@ public final class misk/jdbc/JDBCSession : misk/jdbc/Session {
419419
public final fun component1 ()Ljava/sql/Connection;
420420
public final fun executePostCommitHooks ()V
421421
public final fun executePreCommitHooks ()V
422+
public final fun executeRollbackHooks (Ljava/lang/Throwable;)V
422423
public final fun executeSessionCloseHooks ()V
423424
public final fun getConnection ()Ljava/sql/Connection;
424425
public fun onPostCommit (Lkotlin/jvm/functions/Function0;)V
425426
public fun onPreCommit (Lkotlin/jvm/functions/Function0;)V
427+
public fun onRollback (Lkotlin/jvm/functions/Function1;)V
426428
public fun onSessionClose (Lkotlin/jvm/functions/Function0;)V
427429
public fun useConnection (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
428430
}
@@ -574,6 +576,7 @@ public final class misk/jdbc/SchemaMigratorService : com/google/common/util/conc
574576
public abstract interface class misk/jdbc/Session {
575577
public abstract fun onPostCommit (Lkotlin/jvm/functions/Function0;)V
576578
public abstract fun onPreCommit (Lkotlin/jvm/functions/Function0;)V
579+
public abstract fun onRollback (Lkotlin/jvm/functions/Function1;)V
577580
public abstract fun onSessionClose (Lkotlin/jvm/functions/Function0;)V
578581
public abstract fun useConnection (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
579582
}

misk-jdbc/src/main/kotlin/misk/jdbc/JDBCSession.kt

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package misk.jdbc
22

33
import java.sql.Connection
44
import java.util.concurrent.ConcurrentHashMap
5+
import java.util.concurrent.ConcurrentLinkedQueue
56
import java.util.concurrent.ConcurrentMap
67

7-
class JDBCSession(val connection: Connection): Session {
8+
class JDBCSession(val connection: Connection) : Session {
89
private val hooks: ConcurrentMap<HookType, List<() -> Unit>> = ConcurrentHashMap()
10+
private val rollbackHooks: ConcurrentLinkedQueue<(error: Throwable) -> Unit> = ConcurrentLinkedQueue()
911

1012
override fun <T> useConnection(work: (Connection) -> T): T {
1113
return work(connection)
@@ -23,6 +25,10 @@ class JDBCSession(val connection: Connection): Session {
2325
hooks.add(HookType.SESSION_CLOSE, work)
2426
}
2527

28+
override fun onRollback(work: (error: Throwable) -> Unit) {
29+
rollbackHooks.add(work)
30+
}
31+
2632
fun executePreCommitHooks() {
2733
hooks[HookType.PRE]?.forEach { it() }
2834
}
@@ -41,6 +47,10 @@ class JDBCSession(val connection: Connection): Session {
4147
hooks[HookType.SESSION_CLOSE]?.forEach { it() }
4248
}
4349

50+
fun executeRollbackHooks(error: Throwable) {
51+
rollbackHooks.forEach { it(error) }
52+
}
53+
4454
fun ConcurrentMap<HookType, List<() -> Unit>>.add(hookType: HookType, work: () -> Unit) {
4555
merge(hookType, listOf(work)) { oldList, newList ->
4656
mutableListOf<() -> Unit>().apply {
@@ -50,7 +60,7 @@ class JDBCSession(val connection: Connection): Session {
5060
}
5161
}
5262

53-
enum class HookType{
63+
enum class HookType {
5464
PRE,
5565
POST,
5666
SESSION_CLOSE

misk-jdbc/src/main/kotlin/misk/jdbc/Session.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,10 @@ interface Session {
2525
* A new transaction can be initiated as part of this hook.
2626
*/
2727
fun onSessionClose(work: () -> Unit)
28+
29+
/**
30+
* Registers a hook that fires if the session transaction is rolled back.
31+
* This is called after the transaction has closed, so a new transaction can be initiated as part of this hook.
32+
*/
33+
fun onRollback(work: (error: Throwable) -> Unit)
2834
}

misk-jdbc/src/main/kotlin/misk/jdbc/Transacter.kt

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@ interface Transacter {
1717
* Prefer using [transactionWithSession] instead of this method as it has more functionality such
1818
* as commit hooks.
1919
*/
20-
@Deprecated(
21-
"Use transactionWithSession instead",
22-
replaceWith = ReplaceWith("transactionWithSession(work)")
23-
)
20+
@Deprecated("Use transactionWithSession instead", replaceWith = ReplaceWith("transactionWithSession(work)"))
2421
fun <T> transaction(work: (connection: Connection) -> T): T
2522

2623
/**
@@ -40,6 +37,7 @@ class RealTransacter(private val dataSourceService: DataSourceService) : Transac
4037

4138
override val inTransaction: Boolean get() = transacting.get()
4239

40+
@Deprecated("Use transactionWithSession instead", replaceWith = ReplaceWith("transactionWithSession(work)"))
4341
override fun <T> transaction(work: (connection: Connection) -> T): T =
4442
transactionWithSession { session -> session.useConnection(work) }
4543

@@ -61,14 +59,17 @@ class RealTransacter(private val dataSourceService: DataSourceService) : Transac
6159
if (connection.autoCommit) {
6260
connection.autoCommit = false
6361
}
64-
// Do stuff
6562

63+
// Do stuff
6664
session = JDBCSession(connection)
67-
val result = work(session!!)
65+
val result = runCatching { work(session) }
66+
.onFailure { e -> session.onSessionClose { session.executeRollbackHooks(e) } }
67+
.getOrThrow()
68+
6869
// COMMIT
69-
session!!.executePreCommitHooks()
70+
session.executePreCommitHooks()
7071
connection.commit()
71-
session!!.executePostCommitHooks()
72+
session.executePostCommitHooks()
7273
result
7374
}
7475
} finally {

misk-jdbc/src/test/kotlin/misk/jdbc/RealTransacterTest.kt

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import org.assertj.core.api.Assertions.assertThat
1717
import org.assertj.core.api.Assertions.assertThatExceptionOfType
1818
import org.junit.jupiter.api.Disabled
1919
import org.junit.jupiter.api.Test
20+
import org.junit.jupiter.api.assertThrows
2021
import wisp.config.Config
2122
import wisp.deployment.TESTING
2223
import java.sql.Connection
@@ -167,6 +168,39 @@ abstract class RealTransacterTest {
167168
assertThat(afterCount).isEqualTo(1)
168169
}
169170

171+
@Test
172+
fun rollbackHooksCalledOnRollbackOnly() {
173+
val rollbackHooksTriggered = mutableListOf<String>()
174+
175+
// Happy path.
176+
transacter.transactionWithSession { session ->
177+
session.onRollback { _ ->
178+
rollbackHooksTriggered.add("never")
179+
error("this should never have happened")
180+
}
181+
}
182+
183+
assertThat(rollbackHooksTriggered).isEmpty()
184+
185+
// Rollback path.
186+
assertThrows<IllegalStateException> {
187+
transacter.transactionWithSession { session ->
188+
session.onRollback { error ->
189+
assertThat(error).hasMessage("bad things happened here")
190+
assertThat(transacter.inTransaction).isFalse
191+
rollbackHooksTriggered.add("first")
192+
}
193+
session.onRollback { error ->
194+
assertThat(error).hasMessage("bad things happened here")
195+
assertThat(transacter.inTransaction).isFalse
196+
rollbackHooksTriggered.add("second")
197+
}
198+
error("bad things happened here")
199+
}
200+
}
201+
assertThat(rollbackHooksTriggered).containsExactly("first", "second")
202+
}
203+
170204
@Test
171205
fun `session close hooks execute when there are no exceptions`() {
172206
var sessionCloseHook1Executed = false

misk-jooq/api/misk-jooq.api

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,12 @@ public final class misk/jooq/JooqSession : misk/jdbc/Session {
4141
public final fun component1 ()Lorg/jooq/DSLContext;
4242
public final fun executePostCommitHooks ()V
4343
public final fun executePreCommitHooks ()V
44+
public final fun executeRollbackHooks (Ljava/lang/Throwable;)V
4445
public final fun executeSessionCloseHooks ()V
4546
public final fun getCtx ()Lorg/jooq/DSLContext;
4647
public fun onPostCommit (Lkotlin/jvm/functions/Function0;)V
4748
public fun onPreCommit (Lkotlin/jvm/functions/Function0;)V
49+
public fun onRollback (Lkotlin/jvm/functions/Function1;)V
4850
public fun onSessionClose (Lkotlin/jvm/functions/Function0;)V
4951
public fun useConnection (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
5052
}

misk-jooq/src/main/kotlin/misk/jooq/JooqSession.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ import org.jooq.DSLContext
66
import java.lang.UnsupportedOperationException
77
import java.sql.Connection
88
import java.util.concurrent.ConcurrentHashMap
9+
import java.util.concurrent.ConcurrentLinkedQueue
910
import java.util.concurrent.ConcurrentMap
1011

1112
class JooqSession(val ctx: DSLContext) : Session {
1213
private val hooks: ConcurrentMap<HookType, List<() -> Unit>> = ConcurrentHashMap()
14+
private val rollbackHooks: ConcurrentLinkedQueue<(error: Throwable) -> Unit> = ConcurrentLinkedQueue()
1315

1416
override fun <T> useConnection(work: (Connection) -> T): T {
1517
return ctx.connectionResult(work)
@@ -27,6 +29,10 @@ class JooqSession(val ctx: DSLContext) : Session {
2729
hooks.add(HookType.SESSION_CLOSE, work)
2830
}
2931

32+
override fun onRollback(work: (error: Throwable) -> Unit) {
33+
rollbackHooks.add(work)
34+
}
35+
3036
fun executePreCommitHooks() {
3137
hooks[HookType.PRE]?.forEach {
3238
it()
@@ -47,6 +53,10 @@ class JooqSession(val ctx: DSLContext) : Session {
4753
hooks[HookType.SESSION_CLOSE]?.forEach { it() }
4854
}
4955

56+
fun executeRollbackHooks(error: Throwable) {
57+
rollbackHooks.forEach { it(error) }
58+
}
59+
5060
fun ConcurrentMap<HookType, List<() -> Unit>>.add(hookType: HookType, work: () -> Unit) {
5161
merge(hookType, listOf(work)) { oldList, newList ->
5262
mutableListOf<() -> Unit>().apply {

misk-jooq/src/main/kotlin/misk/jooq/JooqTransacter.kt

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ class JooqTransacter @JvmOverloads constructor(
8787
return try {
8888
dslContext(dataSourceService, clock, dataSourceConfig, options).transactionResult { configuration ->
8989
jooqSession = JooqSession(DSL.using(configuration))
90-
callback(jooqSession!!).also { jooqSession!!.executePreCommitHooks() }
91-
}.also {
92-
jooqSession?.executePostCommitHooks()
93-
}
90+
runCatching { callback(jooqSession).also { jooqSession.executePreCommitHooks() } }
91+
.onFailure { jooqSession.onSessionClose { jooqSession.executeRollbackHooks(it) } }
92+
.getOrElse { throw it } // JooqExtensions.kt shadows Result<T>.getOrThrow()... This is equivalent.
93+
}.also { jooqSession?.executePostCommitHooks() }
9494
} finally {
9595
jooqSession?.executeSessionCloseHooks()
9696
}
@@ -108,25 +108,25 @@ class JooqTransacter @JvmOverloads constructor(
108108
RenderMapping().withSchemata(
109109
MappedSchema()
110110
.withInput(jooqCodeGenSchemaName)
111-
.withOutput(datasourceConfig.database)
112-
)
111+
.withOutput(datasourceConfig.database),
112+
),
113113
)
114114

115115
val connectionProvider = IsolationLevelAwareConnectionProvider(
116116
dataSourceConnectionProvider = DataSourceConnectionProvider(dataSourceService.dataSource),
117-
transacterOptions = options
117+
transacterOptions = options,
118118
)
119119

120120
return DSL.using(connectionProvider, datasourceConfig.type.toSqlDialect(), settings)
121121
.apply {
122122
configuration().set(
123123
DefaultTransactionProvider(
124124
configuration().connectionProvider(),
125-
false
126-
)
125+
false,
126+
),
127127
).apply {
128128
val executeListeners = mutableListOf(
129-
DefaultExecuteListenerProvider(AvoidUsingSelectStarListener())
129+
DefaultExecuteListenerProvider(AvoidUsingSelectStarListener()),
130130
)
131131
if ("true" == datasourceConfig.show_sql) {
132132
executeListeners.add(DefaultExecuteListenerProvider(JooqSQLLogger()))
@@ -138,8 +138,8 @@ class JooqTransacter @JvmOverloads constructor(
138138
JooqTimestampRecordListener(
139139
clock = clock,
140140
createdAtColumnName = jooqTimestampRecordListenerOptions.createdAtColumnName,
141-
updatedAtColumnName = jooqTimestampRecordListenerOptions.updatedAtColumnName
142-
)
141+
updatedAtColumnName = jooqTimestampRecordListenerOptions.updatedAtColumnName,
142+
),
143143
)
144144
}
145145
}.apply(jooqConfigExtension)

0 commit comments

Comments
 (0)