1
1
package kotlinx.coroutines
2
2
3
+ import kotlinx.atomicfu.atomic
4
+ import kotlinx.atomicfu.loop
3
5
import kotlinx.coroutines.testing.*
4
6
import kotlin.test.*
5
7
import kotlinx.coroutines.flow.*
8
+ import kotlinx.coroutines.internal.CommonThreadLocal
9
+ import kotlinx.coroutines.internal.Symbol
10
+ import kotlinx.coroutines.internal.commonThreadLocal
6
11
import kotlin.coroutines.*
7
12
8
13
class ThreadContextElementTest : TestBase () {
@@ -38,20 +43,18 @@ class ThreadContextElementTest: TestBase() {
38
43
*/
39
44
@Test
40
45
fun testWithContextJobAccess () = runTest {
41
- val executor = Executors .newSingleThreadExecutor()
42
46
// Emulate non-equal dispatchers
43
- val executor1 = object : ExecutorService by executor {}
44
- val executor2 = object : ExecutorService by executor {}
45
- val dispatcher1 = executor1.asCoroutineDispatcher()
46
- val dispatcher2 = executor2.asCoroutineDispatcher()
47
+ val dispatcher = Dispatchers .Default .limitedParallelism(1 )
48
+ val dispatcher1 = dispatcher.limitedParallelism(1 , " dispatcher1" )
49
+ val dispatcher2 = dispatcher.limitedParallelism(1 , " dispatcher2" )
47
50
val captor = JobCaptor ()
48
51
val manuallyCaptured = mutableListOf<String >()
49
52
50
53
fun registerUpdate (job : Job ? ) = manuallyCaptured.add(" Update: $job " )
51
54
fun registerRestore (job : Job ? ) = manuallyCaptured.add(" Restore: $job " )
52
55
53
56
var rootJob: Job ? = null
54
- runBlocking (captor + dispatcher1) {
57
+ withContext (captor + dispatcher1) {
55
58
rootJob = coroutineContext.job
56
59
registerUpdate(rootJob)
57
60
var undispatchedJob: Job ? = null
@@ -84,7 +87,6 @@ class ThreadContextElementTest: TestBase() {
84
87
val expected = manuallyCaptured.filter { it.startsWith(" Update: " ) }.joinToString(separator = " \n " )
85
88
val actual = captor.capturees.filter { it.startsWith(" Update: " ) }.joinToString(separator = " \n " )
86
89
assertEquals(expected, actual)
87
- executor.shutdownNow()
88
90
}
89
91
90
92
@Test
@@ -96,7 +98,7 @@ class ThreadContextElementTest: TestBase() {
96
98
assertEquals(myData, myThreadLocal.get())
97
99
emit(1 )
98
100
}
99
- .flowOn(myThreadLocal.asContextElement () + Dispatchers .Default )
101
+ .flowOn(myThreadLocal.asCtxElement () + Dispatchers .Default )
100
102
.single()
101
103
myThreadLocal.set(null )
102
104
finish(2 )
@@ -105,7 +107,7 @@ class ThreadContextElementTest: TestBase() {
105
107
106
108
class MyData
107
109
108
- class JobCaptor (val capturees : MutableList <String > = CopyOnWriteArrayList ()) : ThreadContextElement<Unit> {
110
+ private class JobCaptor (val capturees : CopyOnWriteList <String > = CopyOnWriteList ()) : ThreadContextElement<Unit> {
109
111
110
112
companion object Key : CoroutineContext.Key<MyElement>
111
113
@@ -121,7 +123,7 @@ class JobCaptor(val capturees: MutableList<String> = CopyOnWriteArrayList()) : T
121
123
}
122
124
123
125
// declare thread local variable holding MyData
124
- private val myThreadLocal = ThreadLocal <MyData ?>()
126
+ internal val myThreadLocal = commonThreadLocal <MyData ?>(Symbol ( " myElement " ) )
125
127
126
128
// declare context element holding MyData
127
129
class MyElement (val data : MyData ) : ThreadContextElement<MyData?> {
@@ -144,3 +146,44 @@ class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
144
146
myThreadLocal.set(oldState)
145
147
}
146
148
}
149
+
150
+
151
+ private class CommonThreadLocalContextElement <T >(
152
+ private val threadLocal : CommonThreadLocal <T >,
153
+ private val value : T = threadLocal.get()
154
+ ): ThreadContextElement<T>, CoroutineContext.Key<CommonThreadLocalContextElement<T>> {
155
+ // provide the key of the corresponding context element
156
+ override val key: CoroutineContext .Key <CommonThreadLocalContextElement <T >>
157
+ get() = this
158
+
159
+ // this is invoked before coroutine is resumed on current thread
160
+ override fun updateThreadContext (context : CoroutineContext ): T {
161
+ val oldState = threadLocal.get()
162
+ threadLocal.set(value)
163
+ return oldState
164
+ }
165
+
166
+ // this is invoked after coroutine has suspended on current thread
167
+ override fun restoreThreadContext (context : CoroutineContext , oldState : T ) {
168
+ threadLocal.set(oldState)
169
+ }
170
+ }
171
+
172
+ // overload resolution issues if this is called `asContextElement`
173
+ internal fun <T > CommonThreadLocal<T>.asCtxElement (value : T = get()): ThreadContextElement <T > =
174
+ CommonThreadLocalContextElement (this , value)
175
+
176
+ private class CopyOnWriteList <T > private constructor(list : List <T >) {
177
+ private val field = atomic(list)
178
+
179
+ constructor () : this (emptyList())
180
+
181
+ fun add (value : T ) {
182
+ field.loop { current ->
183
+ val new = current + value
184
+ if (field.compareAndSet(current, new)) return
185
+ }
186
+ }
187
+
188
+ fun filter (predicate : (T ) -> Boolean ): List <T > = field.value.filter(predicate)
189
+ }
0 commit comments