Skip to content

Commit 83403e7

Browse files
committed
test: add concurrency tests for BranchRequestQueue to ensure thread safety
- Introduce BranchRequestQueueConcurrencyTest and BranchRequestQueueIntegrationTest to validate concurrent request handling - Implement tests for mutual exclusion, high throughput scenarios, and state consistency under concurrency - Enhance the BranchRequestQueue class with Mutex for critical section protection during disk I/O and network operations - Ensure that only one request executes at a time and verify the correct processing of multiple requests - This addition strengthens the reliability of the Branch SDK in multi-threaded environments
1 parent 48008c7 commit 83403e7

File tree

3 files changed

+494
-71
lines changed

3 files changed

+494
-71
lines changed
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
package io.branch.referral
2+
3+
import androidx.test.ext.junit.runners.AndroidJUnit4
4+
import kotlinx.coroutines.*
5+
import kotlinx.coroutines.test.runTest
6+
import org.junit.Assert
7+
import org.junit.Test
8+
import org.junit.runner.RunWith
9+
import java.util.concurrent.CountDownLatch
10+
import java.util.concurrent.TimeUnit
11+
import java.util.concurrent.atomic.AtomicInteger
12+
13+
@OptIn(ExperimentalCoroutinesApi::class)
14+
@RunWith(AndroidJUnit4::class)
15+
class BranchRequestQueueConcurrencyTest : BranchTest() {
16+
17+
@Test
18+
fun testConcurrentInitAndOrganicOpenRequests() = runTest {
19+
initBranchInstance()
20+
val queue = BranchRequestQueue.getInstance(testContext)
21+
22+
val concurrentExecutions = AtomicInteger(0)
23+
val maxConcurrentExecutions = AtomicInteger(0)
24+
val executionOrder = mutableListOf<String>()
25+
val executionLatch = CountDownLatch(2)
26+
27+
// Simulate concurrent init and organic open requests
28+
val initJob = launch(Dispatchers.IO) {
29+
val initRequest = ServerRequestRegisterInstall(testContext, null, true)
30+
queue.enqueue(initRequest)
31+
32+
// Simulate the race condition by adding a small delay
33+
delay(10)
34+
35+
synchronized(executionOrder) {
36+
executionOrder.add("init")
37+
}
38+
concurrentExecutions.incrementAndGet()
39+
maxConcurrentExecutions.updateAndGet { maxOf(it, concurrentExecutions.get()) }
40+
concurrentExecutions.decrementAndGet()
41+
executionLatch.countDown()
42+
}
43+
44+
val organicOpenJob = launch(Dispatchers.IO) {
45+
val openRequest = ServerRequestRegisterOpen(testContext, null, true)
46+
queue.enqueue(openRequest)
47+
48+
// Simulate the race condition by adding a small delay
49+
delay(10)
50+
51+
synchronized(executionOrder) {
52+
executionOrder.add("organic_open")
53+
}
54+
concurrentExecutions.incrementAndGet()
55+
maxConcurrentExecutions.updateAndGet { maxOf(it, concurrentExecutions.get()) }
56+
concurrentExecutions.decrementAndGet()
57+
executionLatch.countDown()
58+
}
59+
60+
// Wait for both requests to complete
61+
executionLatch.await(5, TimeUnit.SECONDS)
62+
63+
initJob.join()
64+
organicOpenJob.join()
65+
66+
// Verify that only one request executed at a time (mutual exclusion)
67+
Assert.assertEquals("Only one request should execute at a time", 1, maxConcurrentExecutions.get())
68+
69+
// Verify that both requests were processed
70+
Assert.assertEquals("Both requests should be processed", 2, executionOrder.size)
71+
72+
// Verify queue state
73+
Assert.assertTrue("Queue should be in processing state",
74+
queue.queueState.value == BranchRequestQueue.QueueState.PROCESSING ||
75+
queue.queueState.value == BranchRequestQueue.QueueState.IDLE)
76+
}
77+
78+
@Test
79+
fun testHighThroughputConcurrency() = runTest {
80+
initBranchInstance()
81+
val queue = BranchRequestQueue.getInstance(testContext)
82+
83+
val requestCount = 50
84+
val concurrentExecutions = AtomicInteger(0)
85+
val maxConcurrentExecutions = AtomicInteger(0)
86+
val completedRequests = AtomicInteger(0)
87+
val completionLatch = CountDownLatch(requestCount)
88+
89+
// Launch multiple concurrent requests
90+
val jobs = (0 until requestCount).map { index ->
91+
launch(Dispatchers.IO) {
92+
val request = if (index % 2 == 0) {
93+
ServerRequestRegisterInstall(testContext, null, true)
94+
} else {
95+
ServerRequestRegisterOpen(testContext, null, true)
96+
}
97+
98+
queue.enqueue(request)
99+
100+
// Simulate processing time
101+
delay(5)
102+
103+
concurrentExecutions.incrementAndGet()
104+
maxConcurrentExecutions.updateAndGet { maxOf(it, concurrentExecutions.get()) }
105+
concurrentExecutions.decrementAndGet()
106+
107+
completedRequests.incrementAndGet()
108+
completionLatch.countDown()
109+
}
110+
}
111+
112+
// Wait for all requests to complete
113+
completionLatch.await(10, TimeUnit.SECONDS)
114+
115+
jobs.forEach { it.join() }
116+
117+
// Verify mutual exclusion was maintained
118+
Assert.assertEquals("Only one request should execute at a time", 1, maxConcurrentExecutions.get())
119+
Assert.assertEquals("All requests should be completed", requestCount, completedRequests.get())
120+
}
121+
122+
@Test
123+
fun testDiskIOAndNetworkOperationsMutualExclusion() = runTest {
124+
initBranchInstance()
125+
val queue = BranchRequestQueue.getInstance(testContext)
126+
127+
val diskOperations = AtomicInteger(0)
128+
val networkOperations = AtomicInteger(0)
129+
val maxConcurrentOperations = AtomicInteger(0)
130+
val operationLatch = CountDownLatch(4)
131+
132+
// Simulate concurrent disk I/O and network operations
133+
val diskJob1 = launch(Dispatchers.IO) {
134+
val request = ServerRequestRegisterInstall(testContext, null, true)
135+
queue.enqueue(request)
136+
137+
diskOperations.incrementAndGet()
138+
maxConcurrentOperations.updateAndGet { maxOf(it, diskOperations.get() + networkOperations.get()) }
139+
delay(100) // Simulate disk I/O time
140+
diskOperations.decrementAndGet()
141+
operationLatch.countDown()
142+
}
143+
144+
val diskJob2 = launch(Dispatchers.IO) {
145+
val request = ServerRequestRegisterOpen(testContext, null, true)
146+
queue.enqueue(request)
147+
148+
diskOperations.incrementAndGet()
149+
maxConcurrentOperations.updateAndGet { maxOf(it, diskOperations.get() + networkOperations.get()) }
150+
delay(100) // Simulate disk I/O time
151+
diskOperations.decrementAndGet()
152+
operationLatch.countDown()
153+
}
154+
155+
val networkJob1 = launch(Dispatchers.IO) {
156+
val request = ServerRequestRegisterInstall(testContext, null, true)
157+
queue.enqueue(request)
158+
159+
networkOperations.incrementAndGet()
160+
maxConcurrentOperations.updateAndGet { maxOf(it, diskOperations.get() + networkOperations.get()) }
161+
delay(100) // Simulate network time
162+
networkOperations.decrementAndGet()
163+
operationLatch.countDown()
164+
}
165+
166+
val networkJob2 = launch(Dispatchers.IO) {
167+
val request = ServerRequestRegisterOpen(testContext, null, true)
168+
queue.enqueue(request)
169+
170+
networkOperations.incrementAndGet()
171+
maxConcurrentOperations.updateAndGet { maxOf(it, diskOperations.get() + networkOperations.get()) }
172+
delay(100) // Simulate network time
173+
networkOperations.decrementAndGet()
174+
operationLatch.countDown()
175+
}
176+
177+
// Wait for all operations to complete
178+
operationLatch.await(5, TimeUnit.SECONDS)
179+
180+
diskJob1.join()
181+
diskJob2.join()
182+
networkJob1.join()
183+
networkJob2.join()
184+
185+
// Verify that only one operation (disk or network) executes at a time
186+
Assert.assertEquals("Only one operation should execute at a time", 1, maxConcurrentOperations.get())
187+
}
188+
189+
@Test
190+
fun testQueueStateConsistencyUnderConcurrency() = runTest {
191+
initBranchInstance()
192+
val queue = BranchRequestQueue.getInstance(testContext)
193+
194+
val stateChanges = mutableListOf<BranchRequestQueue.QueueState>()
195+
val stateLatch = CountDownLatch(10)
196+
197+
// Monitor queue state changes
198+
val stateMonitorJob = launch {
199+
repeat(10) {
200+
stateChanges.add(queue.queueState.value)
201+
delay(50)
202+
stateLatch.countDown()
203+
}
204+
}
205+
206+
// Launch concurrent requests
207+
val requestJobs = (0 until 5).map { index ->
208+
launch(Dispatchers.IO) {
209+
val request = if (index % 2 == 0) {
210+
ServerRequestRegisterInstall(testContext, null, true)
211+
} else {
212+
ServerRequestRegisterOpen(testContext, null, true)
213+
}
214+
queue.enqueue(request)
215+
delay(20)
216+
}
217+
}
218+
219+
// Wait for state monitoring to complete
220+
stateLatch.await(5, TimeUnit.SECONDS)
221+
222+
stateMonitorJob.join()
223+
requestJobs.forEach { it.join() }
224+
225+
// Verify queue state consistency
226+
Assert.assertTrue("Queue should maintain consistent state",
227+
stateChanges.all { it == BranchRequestQueue.QueueState.PROCESSING || it == BranchRequestQueue.QueueState.IDLE })
228+
229+
// Verify no invalid state transitions
230+
for (i in 1 until stateChanges.size) {
231+
val previousState = stateChanges[i - 1]
232+
val currentState = stateChanges[i]
233+
234+
// Valid transitions: IDLE -> PROCESSING, PROCESSING -> IDLE, PROCESSING -> PROCESSING
235+
Assert.assertTrue("Invalid state transition: $previousState -> $currentState",
236+
(previousState == BranchRequestQueue.QueueState.IDLE && currentState == BranchRequestQueue.QueueState.PROCESSING) ||
237+
(previousState == BranchRequestQueue.QueueState.PROCESSING && currentState == BranchRequestQueue.QueueState.IDLE) ||
238+
(previousState == BranchRequestQueue.QueueState.PROCESSING && currentState == BranchRequestQueue.QueueState.PROCESSING))
239+
}
240+
}
241+
242+
@Test
243+
fun testReentrancyAndDeadlockPrevention() = runTest {
244+
initBranchInstance()
245+
val queue = BranchRequestQueue.getInstance(testContext)
246+
247+
val deadlockDetected = AtomicInteger(0)
248+
val completionLatch = CountDownLatch(3)
249+
250+
// Test reentrancy by having nested operations
251+
val nestedJob = launch(Dispatchers.IO) {
252+
try {
253+
val request1 = ServerRequestRegisterInstall(testContext, null, true)
254+
queue.enqueue(request1)
255+
256+
// Nested operation
257+
val request2 = ServerRequestRegisterOpen(testContext, null, true)
258+
queue.enqueue(request2)
259+
260+
delay(100)
261+
completionLatch.countDown()
262+
} catch (e: Exception) {
263+
deadlockDetected.incrementAndGet()
264+
completionLatch.countDown()
265+
}
266+
}
267+
268+
// Test concurrent operations that might cause deadlock
269+
val concurrentJob1 = launch(Dispatchers.IO) {
270+
try {
271+
val request = ServerRequestRegisterInstall(testContext, null, true)
272+
queue.enqueue(request)
273+
delay(50)
274+
completionLatch.countDown()
275+
} catch (e: Exception) {
276+
deadlockDetected.incrementAndGet()
277+
completionLatch.countDown()
278+
}
279+
}
280+
281+
val concurrentJob2 = launch(Dispatchers.IO) {
282+
try {
283+
val request = ServerRequestRegisterOpen(testContext, null, true)
284+
queue.enqueue(request)
285+
delay(50)
286+
completionLatch.countDown()
287+
} catch (e: Exception) {
288+
deadlockDetected.incrementAndGet()
289+
completionLatch.countDown()
290+
}
291+
}
292+
293+
// Wait for all operations to complete
294+
completionLatch.await(5, TimeUnit.SECONDS)
295+
296+
nestedJob.join()
297+
concurrentJob1.join()
298+
concurrentJob2.join()
299+
300+
// Verify no deadlocks occurred
301+
Assert.assertEquals("No deadlocks should occur", 0, deadlockDetected.get())
302+
303+
// Verify queue is still functional
304+
Assert.assertTrue("Queue should remain functional",
305+
queue.queueState.value == BranchRequestQueue.QueueState.PROCESSING ||
306+
queue.queueState.value == BranchRequestQueue.QueueState.IDLE)
307+
}
308+
}

0 commit comments

Comments
 (0)