Skip to content

Commit 83856b3

Browse files
authored
fix(mls): reset epoch when the conversation is reset (#3658)
Add support for including the `epoch` in MLS group state updates to ensure accurate state transitions. Update related test cases and database schema accordingly.
1 parent 59f0084 commit 83856b3

File tree

8 files changed

+58
-21
lines changed

8 files changed

+58
-21
lines changed

logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ interface MLSConversationRepository {
223223
suspend fun updateGroupIdAndState(
224224
conversationId: ConversationId,
225225
newGroupId: GroupID,
226+
newEpoch: Long,
226227
groupState: ConversationEntity.GroupState = ConversationEntity.GroupState.PENDING_JOIN
227228
): Either<CoreFailure, Unit>
228229
}
@@ -837,12 +838,14 @@ internal class MLSConversationDataSource(
837838
override suspend fun updateGroupIdAndState(
838839
conversationId: ConversationId,
839840
newGroupId: GroupID,
841+
newEpoch: Long,
840842
groupState: ConversationEntity.GroupState
841843
): Either<CoreFailure, Unit> =
842844
wrapStorageRequest {
843845
conversationDAO.updateMLSGroupIdAndState(
844846
conversationId.toDao(),
845847
idMapper.toCryptoModel(newGroupId),
848+
newEpoch,
846849
groupState
847850
)
848851
}

logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MLSResetConversationEventHandler.kt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ internal class MLSResetConversationEventHandlerImpl(
4343
event.newGroupID
4444
).getOrElse { false }
4545

46+
val newEpoch = if (hasEstablishedMLSGroup) {
47+
mlsContext.conversationEpoch(event.newGroupID.value).toLong()
48+
} else {
49+
0L
50+
}
51+
4652
val newState = if (hasEstablishedMLSGroup) {
4753
// already have the group, no need to join
4854
// can mean that the welcome event arrived before the reset
@@ -51,10 +57,12 @@ internal class MLSResetConversationEventHandlerImpl(
5157
// update local db with the new group id and set the conversation as not established
5258
ConversationEntity.GroupState.PENDING_AFTER_RESET
5359
}
60+
5461
mlsConversationRepository.updateGroupIdAndState(
5562
event.conversationId,
5663
event.newGroupID,
57-
groupState = newState
64+
groupState = newState,
65+
newEpoch = newEpoch,
5866
)
5967
}
6068
}

logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,10 @@ import com.wire.kalium.cryptography.E2EIClient
3030
import com.wire.kalium.cryptography.ExternalSenderKey
3131
import com.wire.kalium.cryptography.GroupInfoBundle
3232
import com.wire.kalium.cryptography.GroupInfoEncryptionType
33-
import com.wire.kalium.cryptography.MLSClient
3433
import com.wire.kalium.cryptography.RatchetTreeType
3534
import com.wire.kalium.cryptography.RotateBundle
3635
import com.wire.kalium.cryptography.WelcomeBundle
3736
import com.wire.kalium.cryptography.WireIdentity
38-
import com.wire.kalium.logic.data.client.MLSClientProvider
3937
import com.wire.kalium.logic.data.client.toCrypto
4038
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CIPHER_SUITE
4139
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CRYPTO_CLIENT_ID
@@ -1679,13 +1677,13 @@ class MLSConversationRepositoryTest {
16791677

16801678
suspend fun withUpdateMLSGroupIdAndStateSuccessful() = apply {
16811679
coEvery {
1682-
conversationDAO.updateMLSGroupIdAndState(any(), any(), any())
1680+
conversationDAO.updateMLSGroupIdAndState(any(), any(), any(), any())
16831681
}.returns(Unit)
16841682
}
16851683

16861684
suspend fun withUpdateMLSGroupIdAndStateFailing(failure: StorageFailure.Generic) = apply {
16871685
coEvery {
1688-
conversationDAO.updateMLSGroupIdAndState(any(), any(), any())
1686+
conversationDAO.updateMLSGroupIdAndState(any(), any(), any(), any())
16891687
}.throws(failure.rootCause)
16901688
}
16911689

@@ -1785,13 +1783,14 @@ class MLSConversationRepositoryTest {
17851783
.withUpdateMLSGroupIdAndStateSuccessful()
17861784
.arrange()
17871785

1788-
val result = mlsConversationRepository.updateGroupIdAndState(conversationId, newGroupId)
1786+
val result = mlsConversationRepository.updateGroupIdAndState(conversationId, newGroupId, 0L)
17891787

17901788
result.shouldSucceed()
17911789
coVerify {
17921790
arrangement.conversationDAO.updateMLSGroupIdAndState(
17931791
conversationId.toDao(),
17941792
newGroupId.toCrypto(),
1793+
0L,
17951794
ConversationEntity.GroupState.PENDING_JOIN
17961795
)
17971796
}.wasInvoked(exactly = once)
@@ -1806,13 +1805,14 @@ class MLSConversationRepositoryTest {
18061805
.withUpdateMLSGroupIdAndStateSuccessful()
18071806
.arrange()
18081807

1809-
val result = mlsConversationRepository.updateGroupIdAndState(conversationId, newGroupId, customState)
1808+
val result = mlsConversationRepository.updateGroupIdAndState(conversationId, newGroupId, 0L, customState)
18101809

18111810
result.shouldSucceed()
18121811
coVerify {
18131812
arrangement.conversationDAO.updateMLSGroupIdAndState(
18141813
conversationId.toDao(),
18151814
newGroupId.toCrypto(),
1815+
0L,
18161816
customState
18171817
)
18181818
}.wasInvoked(exactly = once)
@@ -1827,7 +1827,7 @@ class MLSConversationRepositoryTest {
18271827
.withUpdateMLSGroupIdAndStateFailing(storageFailure)
18281828
.arrange()
18291829

1830-
val result = mlsConversationRepository.updateGroupIdAndState(conversationId, newGroupId)
1830+
val result = mlsConversationRepository.updateGroupIdAndState(conversationId, newGroupId, 0L)
18311831

18321832
result.shouldFail {
18331833
assertIs<StorageFailure.Generic>(it)
@@ -1850,13 +1850,14 @@ class MLSConversationRepositoryTest {
18501850
.withUpdateMLSGroupIdAndStateSuccessful()
18511851
.arrange()
18521852

1853-
val result = mlsConversationRepository.updateGroupIdAndState(conversationId, newGroupId, state)
1853+
val result = mlsConversationRepository.updateGroupIdAndState(conversationId, newGroupId, 0L,state)
18541854

18551855
result.shouldSucceed()
18561856
coVerify {
18571857
arrangement.conversationDAO.updateMLSGroupIdAndState(
18581858
conversationId.toDao(),
18591859
newGroupId.toCrypto(),
1860+
0L,
18601861
state
18611862
)
18621863
}.wasInvoked(exactly = once)

logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MLSResetConversationEventHandlerTest.kt

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class MLSResetConversationEventHandlerTest {
5858
}.wasNotInvoked()
5959

6060
coVerify {
61-
arrangement.mlsConversationRepository.updateGroupIdAndState(any(), any(), any())
61+
arrangement.mlsConversationRepository.updateGroupIdAndState(any(), any(), any(), any())
6262
}.wasNotInvoked()
6363
}
6464

@@ -85,17 +85,20 @@ class MLSResetConversationEventHandlerTest {
8585
arrangement.mlsConversationRepository.updateGroupIdAndState(
8686
eq(CONVERSATION_ID),
8787
eq(NEW_GROUP_ID),
88-
eq(ConversationEntity.GroupState.PENDING_AFTER_RESET)
88+
eq(0),
89+
eq(ConversationEntity.GroupState.PENDING_AFTER_RESET),
8990
)
9091
}.wasInvoked(exactly = once)
9192
}
9293

9394
@Test
9495
fun givenNewGroupAlreadyEstablished_whenHandlingEvent_thenShouldUpdateWithEstablishedState() =
9596
runTest {
97+
val newGroupEpoch = 42L
9698
val (arrangement, handler) = arrange {
9799
withLeaveGroupSucceeding()
98100
withHasEstablishedMLSGroupReturning(true)
101+
withNewGroupEpoch(newGroupEpoch)
99102
withUpdateGroupIdAndStateSucceeding()
100103
}
101104

@@ -116,6 +119,7 @@ class MLSResetConversationEventHandlerTest {
116119
arrangement.mlsConversationRepository.updateGroupIdAndState(
117120
eq(CONVERSATION_ID),
118121
eq(NEW_GROUP_ID),
122+
eq(newGroupEpoch),
119123
eq(ConversationEntity.GroupState.ESTABLISHED)
120124
)
121125
}.wasInvoked(exactly = once)
@@ -147,6 +151,7 @@ class MLSResetConversationEventHandlerTest {
147151
arrangement.mlsConversationRepository.updateGroupIdAndState(
148152
eq(CONVERSATION_ID),
149153
eq(NEW_GROUP_ID),
154+
eq(0L),
150155
eq(ConversationEntity.GroupState.PENDING_AFTER_RESET)
151156
)
152157
}.wasInvoked(exactly = once)
@@ -173,6 +178,7 @@ class MLSResetConversationEventHandlerTest {
173178
arrangement.mlsConversationRepository.updateGroupIdAndState(
174179
matches { it == event.conversationId },
175180
matches { it == event.newGroupID },
181+
eq(0L),
176182
matches { it == ConversationEntity.GroupState.PENDING_AFTER_RESET }
177183
)
178184
}.wasInvoked(exactly = once)
@@ -193,16 +199,19 @@ class MLSResetConversationEventHandlerTest {
193199
arrangement.mlsConversationRepository.updateGroupIdAndState(
194200
eq(CONVERSATION_ID),
195201
eq(NEW_GROUP_ID),
202+
eq(0L),
196203
eq(ConversationEntity.GroupState.PENDING_AFTER_RESET)
197204
)
198205
}.wasInvoked(exactly = once)
199206
}
200207

201208
@Test
202-
fun givenAllSucceeds_whenHandlingEvent_thenShouldLeaveGroupAndUpdateState() = runTest {
209+
fun givenAllSucceedsAndGroupIsEstablished_whenHandlingEvent_thenShouldLeaveGroupAndUpdateState() = runTest {
210+
val newGroupEpoch = 44L
203211
val (arrangement, handler) = arrange {
204212
withLeaveGroupSucceeding()
205213
withHasEstablishedMLSGroupReturning(true)
214+
withNewGroupEpoch(newGroupEpoch)
206215
withUpdateGroupIdAndStateSucceeding()
207216
}
208217

@@ -220,6 +229,7 @@ class MLSResetConversationEventHandlerTest {
220229
arrangement.mlsConversationRepository.updateGroupIdAndState(
221230
eq(CONVERSATION_ID),
222231
eq(NEW_GROUP_ID),
232+
eq(newGroupEpoch),
223233
eq(ConversationEntity.GroupState.ESTABLISHED)
224234
)
225235
}.wasInvoked(exactly = once)
@@ -248,6 +258,12 @@ class MLSResetConversationEventHandlerTest {
248258
}.returns(Either.Right(hasGroup))
249259
}
250260

261+
suspend fun withNewGroupEpoch(newGroupEpoch: Long) = apply {
262+
coEvery {
263+
mlsContext.conversationEpoch(any())
264+
}.returns(newGroupEpoch.toULong())
265+
}
266+
251267
suspend fun withHasEstablishedMLSGroupFailing(failure: CoreFailure) = apply {
252268
coEvery {
253269
mlsConversationRepository.hasEstablishedMLSGroup(any(), any())
@@ -256,13 +272,13 @@ class MLSResetConversationEventHandlerTest {
256272

257273
suspend fun withUpdateGroupIdAndStateSucceeding() = apply {
258274
coEvery {
259-
mlsConversationRepository.updateGroupIdAndState(any(), any(), any())
275+
mlsConversationRepository.updateGroupIdAndState(any(), any(), any(), any())
260276
}.returns(Either.Right(Unit))
261277
}
262278

263279
suspend fun withUpdateGroupIdAndStateFailing(failure: CoreFailure) = apply {
264280
coEvery {
265-
mlsConversationRepository.updateGroupIdAndState(any(), any(), any())
281+
mlsConversationRepository.updateGroupIdAndState(any(), any(), any(), any())
266282
}.returns(Either.Left(failure))
267283
}
268284

persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ SET mls_group_id = :new_group_id,
154154
WHEN mls_group_state = 'ESTABLISHED' THEN :mls_group_state
155155
WHEN :mls_group_state IN ('ESTABLISHED', 'PENDING_AFTER_RESET') THEN :mls_group_state
156156
ELSE mls_group_state
157-
END
157+
END,
158+
mls_epoch = :new_epoch
158159
WHERE qualified_id = :conversation_id;
159160

160161
updateConversationNotificationsDateWithTheLastMessage:

persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ interface ConversationDAO {
155155
suspend fun updateMLSGroupIdAndState(
156156
conversationId: QualifiedIDEntity,
157157
newGroupId: String,
158+
newEpoch: Long,
158159
groupState: ConversationEntity.GroupState
159160
)
160161

persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,10 @@ internal class ConversationDAOImpl internal constructor(
253253
override suspend fun updateMLSGroupIdAndState(
254254
conversationId: QualifiedIDEntity,
255255
newGroupId: String,
256+
newEpoch: Long,
256257
groupState: ConversationEntity.GroupState
257258
) = withContext(coroutineContext) {
258-
conversationQueries.updateMLSGroupIdAndState(newGroupId, groupState, conversationId)
259+
conversationQueries.updateMLSGroupIdAndState(newGroupId, groupState, newEpoch, conversationId)
259260
}
260261

261262
override suspend fun updateConversationModifiedDate(qualifiedID: QualifiedIDEntity, date: Instant) = withContext(coroutineContext) {

persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2413,19 +2413,21 @@ class ConversationDAOTest : BaseDatabaseTest() {
24132413
}
24142414

24152415
@Test
2416-
fun givenConversationExists_whenUpdateMLSGroupIdAndState_thenBothFieldsAreUpdated() = runTest(dispatcher) {
2416+
fun givenConversationExists_whenUpdateMLSGroupIdAndState_thenAllFieldsAreUpdated() = runTest(dispatcher) {
24172417
conversationDAO.insertConversation(conversationEntity2)
24182418

24192419
val newGroupId = "updated_group_id"
24202420
val newState = ConversationEntity.GroupState.ESTABLISHED
2421+
val newEpoch = 42L
24212422

2422-
conversationDAO.updateMLSGroupIdAndState(conversationEntity2.id, newGroupId, newState)
2423+
conversationDAO.updateMLSGroupIdAndState(conversationEntity2.id, newGroupId, newEpoch, newState)
24232424

24242425
val updatedConversation = conversationDAO.getConversationById(conversationEntity2.id)
24252426
assertNotNull(updatedConversation)
24262427
val protocolInfo = updatedConversation.protocolInfo as ConversationEntity.ProtocolInfo.MLS
24272428
assertEquals(newGroupId, protocolInfo.groupId)
24282429
assertEquals(newState, protocolInfo.groupState)
2430+
assertEquals(newEpoch, protocolInfo.epoch.toLong())
24292431
}
24302432

24312433
@Test
@@ -2444,23 +2446,26 @@ class ConversationDAOTest : BaseDatabaseTest() {
24442446

24452447
val newGroupId = "new_group_id"
24462448
val newState = ConversationEntity.GroupState.PENDING_CREATION
2449+
val newEpoch = 44L
24472450

2448-
conversationDAO.updateMLSGroupIdAndState(originalConversation.id, newGroupId, newState)
2451+
conversationDAO.updateMLSGroupIdAndState(originalConversation.id, newGroupId, newEpoch, newState)
24492452

24502453
val updatedConversation = conversationDAO.getConversationById(originalConversation.id)
24512454
assertNotNull(updatedConversation)
24522455
val protocolInfo = updatedConversation.protocolInfo as ConversationEntity.ProtocolInfo.MLS
24532456
assertEquals(newGroupId, protocolInfo.groupId)
24542457
assertEquals(newState, protocolInfo.groupState)
2458+
assertEquals(newEpoch, protocolInfo.epoch.toLong())
24552459
}
24562460

24572461
@Test
24582462
fun givenNonExistentConversation_whenUpdateMLSGroupIdAndState_thenNoErrorOccurs() = runTest(dispatcher) {
24592463
val nonExistentId = QualifiedIDEntity("non_existent", "domain.com")
24602464
val newGroupId = "new_group_id"
24612465
val newState = ConversationEntity.GroupState.ESTABLISHED
2466+
val newEpoch = 44L
24622467

2463-
conversationDAO.updateMLSGroupIdAndState(nonExistentId, newGroupId, newState)
2468+
conversationDAO.updateMLSGroupIdAndState(nonExistentId, newGroupId, newEpoch, newState)
24642469

24652470
val conversation = conversationDAO.getConversationById(nonExistentId)
24662471
assertNull(conversation)
@@ -2481,8 +2486,9 @@ class ConversationDAOTest : BaseDatabaseTest() {
24812486
conversationDAO.insertConversation(conversation)
24822487

24832488
val newGroupId = "group_id_$index"
2489+
val newEpoch = index.toLong()
24842490

2485-
conversationDAO.updateMLSGroupIdAndState(conversationId, newGroupId, state)
2491+
conversationDAO.updateMLSGroupIdAndState(conversationId, newGroupId, newEpoch, state)
24862492

24872493
val updatedConversation = conversationDAO.getConversationById(conversationId)
24882494
assertNotNull(updatedConversation)

0 commit comments

Comments
 (0)