Skip to content

Commit b3d20a1

Browse files
authored
feat: add use case for migrating a single conversation to MLS (#4103)
1 parent 8dbbade commit b3d20a1

4 files changed

Lines changed: 269 additions & 10 deletions

File tree

logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,8 @@ import com.wire.kalium.logic.feature.conversation.ConversationsRecoveryManagerIm
274274
import com.wire.kalium.logic.feature.conversation.MLSConversationsRecoveryManager
275275
import com.wire.kalium.logic.feature.conversation.MLSConversationsRecoveryManagerImpl
276276
import com.wire.kalium.logic.feature.conversation.MLSFaultyKeysConversationsRepairUseCaseImpl
277+
import com.wire.kalium.logic.feature.conversation.MigrateConversationToMLSUseCase
278+
import com.wire.kalium.logic.feature.conversation.MigrateConversationToMLSUseCaseImpl
277279
import com.wire.kalium.logic.feature.conversation.ObserveOtherUserSecurityClassificationLabelUseCase
278280
import com.wire.kalium.logic.feature.conversation.ObserveOtherUserSecurityClassificationLabelUseCaseImpl
279281
import com.wire.kalium.logic.feature.conversation.ObserveSecurityClassificationLabelUseCase
@@ -2892,6 +2894,13 @@ public class UserSessionScope internal constructor(
28922894
kaliumConfigs = kaliumConfigs,
28932895
)
28942896

2897+
public val migrateConversationToMLS: MigrateConversationToMLSUseCase
2898+
get() = MigrateConversationToMLSUseCaseImpl(
2899+
mlsMigrator = mlsMigrator,
2900+
conversationRepository = conversationRepository,
2901+
coreCryptoTransactionProvider = cryptoTransactionProvider
2902+
)
2903+
28952904
public val longWork: LongWorkScope = LongWorkScope(
28962905
{ this },
28972906
{ slowSyncRepository.slowSyncStatus.map { it is SlowSyncStatus.Ongoing } }
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Wire
3+
* Copyright (C) 2026 Wire Swiss GmbH
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU General Public License as published by
7+
* the Free Software Foundation, either version 3 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU General Public License
16+
* along with this program. If not, see http://www.gnu.org/licenses/.
17+
*/
18+
package com.wire.kalium.logic.feature.conversation
19+
20+
import com.wire.kalium.common.error.CoreFailure
21+
import com.wire.kalium.common.functional.Either
22+
import com.wire.kalium.common.functional.flatMap
23+
import com.wire.kalium.common.functional.fold
24+
import com.wire.kalium.logic.data.client.CryptoTransactionProvider
25+
import com.wire.kalium.logic.data.conversation.Conversation
26+
import com.wire.kalium.logic.data.conversation.ConversationRepository
27+
import com.wire.kalium.logic.data.id.ConversationId
28+
import com.wire.kalium.logic.feature.mlsmigration.MLSMigrator
29+
30+
/**
31+
* This use case will migrate a given conversation to use the MLS protocol
32+
*/
33+
public interface MigrateConversationToMLSUseCase {
34+
/**
35+
* @param conversationId the id of the conversation
36+
* @return the [Result] indicating a successful operation, otherwise a [CoreFailure]
37+
*/
38+
public suspend operator fun invoke(conversationId: ConversationId): Result
39+
40+
public sealed interface Result {
41+
public data object Success : Result
42+
public data class Failure(val cause: CoreFailure) : Result
43+
}
44+
}
45+
46+
internal class MigrateConversationToMLSUseCaseImpl(
47+
val mlsMigrator: MLSMigrator,
48+
val conversationRepository: ConversationRepository,
49+
val coreCryptoTransactionProvider: CryptoTransactionProvider
50+
) : MigrateConversationToMLSUseCase {
51+
override suspend fun invoke(conversationId: ConversationId): MigrateConversationToMLSUseCase.Result {
52+
return conversationRepository.getConversationProtocolInfo(conversationId)
53+
.flatMap { protocolInfo ->
54+
when (protocolInfo) {
55+
is Conversation.ProtocolInfo.MLS -> Either.Right(Unit)
56+
is Conversation.ProtocolInfo.Mixed -> {
57+
coreCryptoTransactionProvider.transaction {
58+
mlsMigrator.finalise(it, conversationId)
59+
}
60+
}
61+
Conversation.ProtocolInfo.Proteus -> {
62+
coreCryptoTransactionProvider.transaction { context ->
63+
mlsMigrator.migrate(context, conversationId).flatMap {
64+
mlsMigrator.finalise(context, conversationId)
65+
}
66+
}
67+
}
68+
}
69+
}.fold({
70+
MigrateConversationToMLSUseCase.Result.Failure(it)
71+
}, {
72+
MigrateConversationToMLSUseCase.Result.Success
73+
})
74+
}
75+
}

logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ import com.wire.kalium.common.functional.Either
3333
import com.wire.kalium.common.functional.flatMap
3434
import com.wire.kalium.common.functional.flatMapLeft
3535
import com.wire.kalium.common.functional.fold
36-
import com.wire.kalium.common.functional.foldToEitherWhileRight
3736
import com.wire.kalium.common.functional.right
3837
import com.wire.kalium.common.logger.kaliumLogger
3938
import com.wire.kalium.cryptography.CryptoTransactionContext
@@ -45,6 +44,8 @@ internal interface MLSMigrator {
4544
suspend fun migrateProteusConversations(): Either<CoreFailure, Unit>
4645
suspend fun finaliseProteusConversations(): Either<CoreFailure, Unit>
4746
suspend fun finaliseAllProteusConversations(): Either<CoreFailure, Unit>
47+
suspend fun migrate(transactionContext: CryptoTransactionContext, conversationId: ConversationId): Either<CoreFailure, Unit>
48+
suspend fun finalise(transactionContext: CryptoTransactionContext, conversationId: ConversationId): Either<CoreFailure, Unit>
4849
}
4950

5051
@Suppress("LongParameterList")
@@ -68,9 +69,10 @@ internal class MLSMigratorImpl(
6869
conversationRepository.getConversationIds(Conversation.Type.Group.Regular, Protocol.PROTEUS, teamId)
6970
.flatMap { conversations ->
7071
transactionProvider.transaction("migrateProteusConversations") { transactionContext ->
71-
conversations.foldToEitherWhileRight(Unit) { conversationId, _ ->
72+
conversations.forEach { conversationId ->
7273
migrate(transactionContext, conversationId)
7374
}
75+
Either.Right(Unit)
7476
}
7577
}
7678
}
@@ -83,9 +85,10 @@ internal class MLSMigratorImpl(
8385
transactionProvider.transaction("finaliseAllProteusConversations") { transactionContext ->
8486
conversationRepository.getConversationIds(Conversation.Type.Group.Regular, Protocol.MIXED, teamId)
8587
.flatMap {
86-
it.foldToEitherWhileRight(Unit) { conversationId, _ ->
88+
it.forEach { conversationId ->
8789
finalise(transactionContext, conversationId)
8890
}
91+
Either.Right(Unit)
8992
}
9093
}
9194
}
@@ -99,15 +102,16 @@ internal class MLSMigratorImpl(
99102
transactionProvider.transaction("finaliseProteusConversations") { transactionContext ->
100103
conversationRepository.getTeamConversationIdsReadyToCompleteMigration(teamId)
101104
.flatMap {
102-
it.foldToEitherWhileRight(Unit) { conversationId, _ ->
105+
it.forEach { conversationId ->
103106
finalise(transactionContext, conversationId)
104107
}
108+
Either.Right(Unit)
105109
}
106110
}
107111
}
108112
}
109113

110-
private suspend fun migrate(transactionContext: CryptoTransactionContext, conversationId: ConversationId): Either<CoreFailure, Unit> {
114+
override suspend fun migrate(transactionContext: CryptoTransactionContext, conversationId: ConversationId): Either<CoreFailure, Unit> {
111115
kaliumLogger.i("migrating ${conversationId.toLogString()} to mixed")
112116
return updateConversationProtocol(transactionContext, conversationId, Protocol.MIXED, localOnly = false)
113117
.flatMap { updated ->
@@ -126,18 +130,18 @@ internal class MLSMigratorImpl(
126130
}
127131
kaliumLogger.i("migrating ${conversationId.toLogString()} to mls")
128132
establishConversation(transactionContext, conversationId)
129-
}.flatMapLeft {
130-
kaliumLogger.w("failed to migrate ${conversationId.toLogString()} to mixed: $it")
131-
Either.Right(Unit)
133+
}.flatMapLeft { failure ->
134+
kaliumLogger.w("failed to migrate ${conversationId.toLogString()} to mixed: $failure")
135+
Either.Left(failure)
132136
}
133137
}
134138

135-
private suspend fun finalise(transactionContext: CryptoTransactionContext, conversationId: ConversationId): Either<CoreFailure, Unit> {
139+
override suspend fun finalise(transactionContext: CryptoTransactionContext, conversationId: ConversationId): Either<CoreFailure, Unit> {
136140
kaliumLogger.i("finalising ${conversationId.toLogString()} to mls")
137141
return updateConversationProtocol(transactionContext, conversationId, Protocol.MLS, localOnly = false)
138142
.fold({ failure ->
139143
kaliumLogger.w("failed to finalise ${conversationId.toLogString()} to mls: $failure")
140-
Either.Right(Unit)
144+
Either.Left(failure)
141145
}, { updated ->
142146
if (updated) {
143147
systemMessageInserter.insertProtocolChangedSystemMessage(
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
/*
2+
* Wire
3+
* Copyright (C) 2026 Wire Swiss GmbH
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU General Public License as published by
7+
* the Free Software Foundation, either version 3 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU General Public License
16+
* along with this program. If not, see http://www.gnu.org/licenses/.
17+
*/
18+
package com.wire.kalium.logic.feature.conversation
19+
20+
import com.wire.kalium.common.error.CoreFailure
21+
import com.wire.kalium.common.error.StorageFailure
22+
import com.wire.kalium.common.functional.Either
23+
import com.wire.kalium.logic.data.conversation.Conversation
24+
import com.wire.kalium.logic.data.conversation.ConversationRepository
25+
import com.wire.kalium.logic.data.id.ConversationId
26+
import com.wire.kalium.logic.feature.mlsmigration.MLSMigrator
27+
import com.wire.kalium.logic.framework.TestConversation
28+
import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangement
29+
import com.wire.kalium.logic.util.arrangement.provider.CryptoTransactionProviderArrangementMokkeryImpl
30+
import dev.mokkery.answering.returns
31+
import dev.mokkery.everySuspend
32+
import dev.mokkery.matcher.any
33+
import dev.mokkery.matcher.eq
34+
import dev.mokkery.mock
35+
import dev.mokkery.verify.VerifyMode
36+
import dev.mokkery.verifySuspend
37+
import kotlinx.coroutines.test.runTest
38+
import kotlin.test.Test
39+
import kotlin.test.assertEquals
40+
import kotlin.test.assertIs
41+
42+
class MigrateConversationToMLSUseCaseTest {
43+
44+
@Test
45+
fun givenMLSConversation_whenMigratingToMLS_thenReturnSuccessWithoutTransaction() = runTest {
46+
val (arrangement, useCase) = Arrangement()
47+
.withConversationProtocolInfo(TestConversation.MLS_PROTOCOL_INFO)
48+
.arrange()
49+
50+
val result = useCase(CONVERSATION_ID)
51+
52+
assertIs<MigrateConversationToMLSUseCase.Result.Success>(result)
53+
verifySuspend(VerifyMode.not) {
54+
arrangement.cryptoTransactionProvider.transaction<Unit>(any(), any())
55+
}
56+
verifySuspend(VerifyMode.not) {
57+
arrangement.mlsMigrator.migrate(any(), any())
58+
}
59+
verifySuspend(VerifyMode.not) {
60+
arrangement.mlsMigrator.finalise(any(), any())
61+
}
62+
}
63+
64+
@Test
65+
fun givenMixedConversation_whenMigratingToMLS_thenFinaliseConversation() = runTest {
66+
val (arrangement, useCase) = Arrangement()
67+
.withConversationProtocolInfo(TestConversation.MIXED_PROTOCOL_INFO)
68+
.withFinaliseResult(Either.Right(Unit))
69+
.arrange()
70+
71+
val result = useCase(CONVERSATION_ID)
72+
73+
assertIs<MigrateConversationToMLSUseCase.Result.Success>(result)
74+
verifySuspend {
75+
arrangement.mlsMigrator.finalise(eq(arrangement.transactionContext), eq(CONVERSATION_ID))
76+
}
77+
verifySuspend(VerifyMode.not) {
78+
arrangement.mlsMigrator.migrate(any(), any())
79+
}
80+
}
81+
82+
@Test
83+
fun givenProteusConversation_whenMigratingToMLS_thenMigrateAndFinaliseConversation() = runTest {
84+
val (arrangement, useCase) = Arrangement()
85+
.withConversationProtocolInfo(Conversation.ProtocolInfo.Proteus)
86+
.withMigrateResult(Either.Right(Unit))
87+
.withFinaliseResult(Either.Right(Unit))
88+
.arrange()
89+
90+
val result = useCase(CONVERSATION_ID)
91+
92+
assertIs<MigrateConversationToMLSUseCase.Result.Success>(result)
93+
verifySuspend {
94+
arrangement.mlsMigrator.migrate(eq(arrangement.transactionContext), eq(CONVERSATION_ID))
95+
}
96+
verifySuspend {
97+
arrangement.mlsMigrator.finalise(eq(arrangement.transactionContext), eq(CONVERSATION_ID))
98+
}
99+
}
100+
101+
@Test
102+
fun givenProteusConversationAndMigrationFails_whenMigratingToMLS_thenReturnFailureAndSkipFinalise() = runTest {
103+
val (arrangement, useCase) = Arrangement()
104+
.withConversationProtocolInfo(Conversation.ProtocolInfo.Proteus)
105+
.withMigrateResult(Either.Left(FAILURE))
106+
.arrange()
107+
108+
val result = useCase(CONVERSATION_ID)
109+
110+
assertIs<MigrateConversationToMLSUseCase.Result.Failure>(result)
111+
assertEquals(FAILURE, result.cause)
112+
verifySuspend(VerifyMode.not) {
113+
arrangement.mlsMigrator.finalise(any(), any())
114+
}
115+
}
116+
117+
@Test
118+
fun givenConversationProtocolInfoFails_whenMigratingToMLS_thenReturnFailure() = runTest {
119+
val (_, useCase) = Arrangement()
120+
.withConversationProtocolInfoFailure(FAILURE)
121+
.arrange()
122+
123+
val result = useCase(CONVERSATION_ID)
124+
125+
assertIs<MigrateConversationToMLSUseCase.Result.Failure>(result)
126+
assertEquals(FAILURE, result.cause)
127+
}
128+
129+
private class Arrangement : CryptoTransactionProviderArrangement by CryptoTransactionProviderArrangementMokkeryImpl() {
130+
val conversationRepository = mock<ConversationRepository>()
131+
val mlsMigrator = mock<MLSMigrator>()
132+
133+
suspend fun withConversationProtocolInfo(protocolInfo: Conversation.ProtocolInfo) = apply {
134+
everySuspend {
135+
conversationRepository.getConversationProtocolInfo(eq(CONVERSATION_ID))
136+
} returns Either.Right(protocolInfo)
137+
}
138+
139+
suspend fun withConversationProtocolInfoFailure(failure: StorageFailure) = apply {
140+
everySuspend {
141+
conversationRepository.getConversationProtocolInfo(eq(CONVERSATION_ID))
142+
} returns Either.Left(failure)
143+
}
144+
145+
suspend fun withMigrateResult(result: Either<CoreFailure, Unit>) = apply {
146+
everySuspend {
147+
mlsMigrator.migrate(eq(transactionContext), eq(CONVERSATION_ID))
148+
} returns result
149+
}
150+
151+
suspend fun withFinaliseResult(result: Either<CoreFailure, Unit>) = apply {
152+
everySuspend {
153+
mlsMigrator.finalise(eq(transactionContext), eq(CONVERSATION_ID))
154+
} returns result
155+
}
156+
157+
suspend fun arrange(): Pair<Arrangement, MigrateConversationToMLSUseCase> {
158+
withTransactionReturning(Either.Right(Unit))
159+
return this to MigrateConversationToMLSUseCaseImpl(
160+
mlsMigrator = mlsMigrator,
161+
conversationRepository = conversationRepository,
162+
coreCryptoTransactionProvider = cryptoTransactionProvider
163+
)
164+
}
165+
}
166+
167+
private companion object {
168+
val CONVERSATION_ID = ConversationId("conversation-id", "domain.example")
169+
val FAILURE = StorageFailure.DataNotFound
170+
}
171+
}

0 commit comments

Comments
 (0)