Skip to content

Commit c40b292

Browse files
committed
Add ThreadContextElement tests
1 parent 66d60ff commit c40b292

File tree

3 files changed

+261
-62
lines changed

3 files changed

+261
-62
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package kotlinx.coroutines
2+
3+
import kotlinx.coroutines.testing.*
4+
import kotlin.coroutines.*
5+
import kotlin.test.*
6+
import kotlinx.coroutines.internal.*
7+
8+
class ThreadContextElementTest : TestBase() {
9+
interface TestThreadContextElement : ThreadContextElement<Int> {
10+
companion object Key : CoroutineContext.Key<TestThreadContextElement>
11+
}
12+
13+
@Test
14+
fun updatesAndRestores() = runTest {
15+
expect(1)
16+
var updateCount = 0
17+
var restoreCount = 0
18+
val threadContextElement = object : TestThreadContextElement {
19+
override fun updateThreadContext(context: CoroutineContext): Int {
20+
updateCount++
21+
return 0
22+
}
23+
24+
override fun restoreThreadContext(context: CoroutineContext, oldState: Int) {
25+
restoreCount++
26+
}
27+
28+
override val key: CoroutineContext.Key<*>
29+
get() = TestThreadContextElement.Key
30+
}
31+
launch(Dispatchers.Unconfined + threadContextElement) {
32+
assertEquals(1, updateCount)
33+
assertEquals(0, restoreCount)
34+
}
35+
assertEquals(1, updateCount)
36+
assertEquals(1, restoreCount)
37+
finish(2)
38+
}
39+
40+
@Test
41+
fun testUndispatched() = runTest {
42+
val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!!
43+
val data = MyData()
44+
val element = MyElement(data)
45+
val job = GlobalScope.launch(
46+
context = Dispatchers.Default + exceptionHandler + element,
47+
start = CoroutineStart.UNDISPATCHED
48+
) {
49+
assertSame(data, threadContextElementThreadLocal.get())
50+
yield()
51+
assertSame(data, threadContextElementThreadLocal.get())
52+
}
53+
assertNull(threadContextElementThreadLocal.get())
54+
job.join()
55+
assertNull(threadContextElementThreadLocal.get())
56+
}
57+
}
58+
59+
internal class MyData
60+
61+
// declare thread local variable holding MyData
62+
internal val threadContextElementThreadLocal = commonThreadLocal<MyData?>(Symbol("ThreadContextElementTest"))
63+
64+
// declare context element holding MyData
65+
internal class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
66+
// declare companion object for a key of this element in coroutine context
67+
companion object Key : CoroutineContext.Key<MyElement>
68+
69+
// provide the key of the corresponding context element
70+
override val key: CoroutineContext.Key<MyElement>
71+
get() = Key
72+
73+
// this is invoked before coroutine is resumed on current thread
74+
override fun updateThreadContext(context: CoroutineContext): MyData? {
75+
val oldState = threadContextElementThreadLocal.get()
76+
threadContextElementThreadLocal.set(data)
77+
return oldState
78+
}
79+
80+
// this is invoked after coroutine has suspended on current thread
81+
override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
82+
threadContextElementThreadLocal.set(oldState)
83+
}
84+
}
85+

Diff for: kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt renamed to kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt

+35-62
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import kotlin.coroutines.*
66
import kotlin.test.*
77
import kotlinx.coroutines.flow.*
88

9-
class ThreadContextElementTest : TestBase() {
9+
class ThreadContextElementJvmTest : TestBase() {
1010

1111
@Test
1212
fun testExample() = runTest {
@@ -15,23 +15,23 @@ class ThreadContextElementTest : TestBase() {
1515
val mainThread = Thread.currentThread()
1616
val data = MyData()
1717
val element = MyElement(data)
18-
assertNull(myThreadLocal.get())
18+
assertNull(threadContextElementThreadLocal.get())
1919
val job = GlobalScope.launch(element + exceptionHandler) {
2020
assertTrue(mainThread != Thread.currentThread())
2121
assertSame(element, coroutineContext[MyElement])
22-
assertSame(data, myThreadLocal.get())
22+
assertSame(data, threadContextElementThreadLocal.get())
2323
withContext(mainDispatcher) {
2424
assertSame(mainThread, Thread.currentThread())
2525
assertSame(element, coroutineContext[MyElement])
26-
assertSame(data, myThreadLocal.get())
26+
assertSame(data, threadContextElementThreadLocal.get())
2727
}
2828
assertTrue(mainThread != Thread.currentThread())
2929
assertSame(element, coroutineContext[MyElement])
30-
assertSame(data, myThreadLocal.get())
30+
assertSame(data, threadContextElementThreadLocal.get())
3131
}
32-
assertNull(myThreadLocal.get())
32+
assertNull(threadContextElementThreadLocal.get())
3333
job.join()
34-
assertNull(myThreadLocal.get())
34+
assertNull(threadContextElementThreadLocal.get())
3535
}
3636

3737
@Test
@@ -43,13 +43,13 @@ class ThreadContextElementTest : TestBase() {
4343
context = Dispatchers.Default + exceptionHandler + element,
4444
start = CoroutineStart.UNDISPATCHED
4545
) {
46-
assertSame(data, myThreadLocal.get())
46+
assertSame(data, threadContextElementThreadLocal.get())
4747
yield()
48-
assertSame(data, myThreadLocal.get())
48+
assertSame(data, threadContextElementThreadLocal.get())
4949
}
50-
assertNull(myThreadLocal.get())
50+
assertNull(threadContextElementThreadLocal.get())
5151
job.join()
52-
assertNull(myThreadLocal.get())
52+
assertNull(threadContextElementThreadLocal.get())
5353
}
5454

5555
@Test
@@ -58,22 +58,22 @@ class ThreadContextElementTest : TestBase() {
5858
newSingleThreadContext("withContext").use {
5959
val data = MyData()
6060
GlobalScope.async(Dispatchers.Default + MyElement(data)) {
61-
assertSame(data, myThreadLocal.get())
61+
assertSame(data, threadContextElementThreadLocal.get())
6262
expect(2)
6363

6464
val newData = MyData()
6565
GlobalScope.async(it + MyElement(newData)) {
66-
assertSame(newData, myThreadLocal.get())
66+
assertSame(newData, threadContextElementThreadLocal.get())
6767
expect(3)
6868
}.await()
6969

7070
withContext(it + MyElement(newData)) {
71-
assertSame(newData, myThreadLocal.get())
71+
assertSame(newData, threadContextElementThreadLocal.get())
7272
expect(4)
7373
}
7474

7575
GlobalScope.async(it) {
76-
assertNull(myThreadLocal.get())
76+
assertNull(threadContextElementThreadLocal.get())
7777
expect(5)
7878
}.await()
7979

@@ -126,31 +126,31 @@ class ThreadContextElementTest : TestBase() {
126126
newFixedThreadPoolContext(nThreads = 4, name = "withContext").use {
127127
withContext(it + CopyForChildCoroutineElement(MyData())) {
128128
val forBlockData = MyData()
129-
myThreadLocal.setForBlock(forBlockData) {
130-
assertSame(myThreadLocal.get(), forBlockData)
129+
threadContextElementThreadLocal.setForBlock(forBlockData) {
130+
assertSame(threadContextElementThreadLocal.get(), forBlockData)
131131
launch {
132-
assertSame(myThreadLocal.get(), forBlockData)
132+
assertSame(threadContextElementThreadLocal.get(), forBlockData)
133133
}
134134
launch {
135-
assertSame(myThreadLocal.get(), forBlockData)
135+
assertSame(threadContextElementThreadLocal.get(), forBlockData)
136136
// Modify value in child coroutine. Writes to the ThreadLocal and
137137
// the (copied) ThreadLocalElement's memory are not visible to peer or
138138
// ancestor coroutines, so this write is both threadsafe and coroutinesafe.
139139
val innerCoroutineData = MyData()
140-
myThreadLocal.setForBlock(innerCoroutineData) {
141-
assertSame(myThreadLocal.get(), innerCoroutineData)
140+
threadContextElementThreadLocal.setForBlock(innerCoroutineData) {
141+
assertSame(threadContextElementThreadLocal.get(), innerCoroutineData)
142142
}
143-
assertSame(myThreadLocal.get(), forBlockData) // Asserts value was restored.
143+
assertSame(threadContextElementThreadLocal.get(), forBlockData) // Asserts value was restored.
144144
}
145145
launch {
146146
val innerCoroutineData = MyData()
147-
myThreadLocal.setForBlock(innerCoroutineData) {
148-
assertSame(myThreadLocal.get(), innerCoroutineData)
147+
threadContextElementThreadLocal.setForBlock(innerCoroutineData) {
148+
assertSame(threadContextElementThreadLocal.get(), innerCoroutineData)
149149
}
150-
assertSame(myThreadLocal.get(), forBlockData)
150+
assertSame(threadContextElementThreadLocal.get(), forBlockData)
151151
}
152152
}
153-
assertNull(myThreadLocal.get()) // Asserts value was restored to its origin
153+
assertNull(threadContextElementThreadLocal.get()) // Asserts value was restored to its origin
154154
}
155155
}
156156
}
@@ -193,58 +193,31 @@ class ThreadContextElementTest : TestBase() {
193193
@Test
194194
fun testThreadLocalFlowOn() = runTest {
195195
val myData = MyData()
196-
myThreadLocal.set(myData)
196+
threadContextElementThreadLocal.set(myData)
197197
expect(1)
198198
flow {
199-
assertEquals(myData, myThreadLocal.get())
199+
assertEquals(myData, threadContextElementThreadLocal.get())
200200
emit(1)
201201
}
202-
.flowOn(myThreadLocal.asContextElement() + Dispatchers.Default)
202+
.flowOn(threadContextElementThreadLocal.asContextElement() + Dispatchers.Default)
203203
.single()
204-
myThreadLocal.set(null)
204+
threadContextElementThreadLocal.set(null)
205205
finish(2)
206206
}
207207
}
208208

209-
class MyData
210-
211-
// declare thread local variable holding MyData
212-
private val myThreadLocal = ThreadLocal<MyData?>()
213-
214-
// declare context element holding MyData
215-
class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
216-
// declare companion object for a key of this element in coroutine context
217-
companion object Key : CoroutineContext.Key<MyElement>
218-
219-
// provide the key of the corresponding context element
220-
override val key: CoroutineContext.Key<MyElement>
221-
get() = Key
222-
223-
// this is invoked before coroutine is resumed on current thread
224-
override fun updateThreadContext(context: CoroutineContext): MyData? {
225-
val oldState = myThreadLocal.get()
226-
myThreadLocal.set(data)
227-
return oldState
228-
}
229-
230-
// this is invoked after coroutine has suspended on current thread
231-
override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
232-
myThreadLocal.set(oldState)
233-
}
234-
}
235-
236209
/**
237210
* A [ThreadContextElement] that implements copy semantics in [copyForChild].
238211
*/
239-
class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement<MyData?> {
212+
internal class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement<MyData?> {
240213
companion object Key : CoroutineContext.Key<CopyForChildCoroutineElement>
241214

242215
override val key: CoroutineContext.Key<CopyForChildCoroutineElement>
243216
get() = Key
244217

245218
override fun updateThreadContext(context: CoroutineContext): MyData? {
246-
val oldState = myThreadLocal.get()
247-
myThreadLocal.set(data)
219+
val oldState = threadContextElementThreadLocal.get()
220+
threadContextElementThreadLocal.set(data)
248221
return oldState
249222
}
250223

@@ -253,7 +226,7 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle
253226
}
254227

255228
override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
256-
myThreadLocal.set(oldState)
229+
threadContextElementThreadLocal.set(oldState)
257230
}
258231

259232
/**
@@ -268,7 +241,7 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle
268241
* thread and calls [restoreThreadContext].
269242
*/
270243
override fun copyForChild(): CopyForChildCoroutineElement {
271-
return CopyForChildCoroutineElement(myThreadLocal.get())
244+
return CopyForChildCoroutineElement(threadContextElementThreadLocal.get())
272245
}
273246
}
274247

0 commit comments

Comments
 (0)