Skip to content

Commit ab181d2

Browse files
committed
Make MPP ThreadContextElement compile and run
1 parent d102823 commit ab181d2

18 files changed

+125
-115
lines changed

Diff for: kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api

+10
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,16 @@ abstract interface <#A: kotlin/Any?> kotlinx.coroutines/CompletableDeferred : ko
186186
abstract fun completeExceptionally(kotlin/Throwable): kotlin/Boolean // kotlinx.coroutines/CompletableDeferred.completeExceptionally|completeExceptionally(kotlin.Throwable){}[0]
187187
}
188188

189+
abstract interface <#A: kotlin/Any?> kotlinx.coroutines/CopyableThreadContextElement : kotlinx.coroutines/ThreadContextElement<#A> { // kotlinx.coroutines/CopyableThreadContextElement|null[0]
190+
abstract fun copyForChild(): kotlinx.coroutines/CopyableThreadContextElement<#A> // kotlinx.coroutines/CopyableThreadContextElement.copyForChild|copyForChild(){}[0]
191+
abstract fun mergeForChild(kotlin.coroutines/CoroutineContext.Element): kotlin.coroutines/CoroutineContext // kotlinx.coroutines/CopyableThreadContextElement.mergeForChild|mergeForChild(kotlin.coroutines.CoroutineContext.Element){}[0]
192+
}
193+
194+
abstract interface <#A: kotlin/Any?> kotlinx.coroutines/ThreadContextElement : kotlin.coroutines/CoroutineContext.Element { // kotlinx.coroutines/ThreadContextElement|null[0]
195+
abstract fun restoreThreadContext(kotlin.coroutines/CoroutineContext, #A) // kotlinx.coroutines/ThreadContextElement.restoreThreadContext|restoreThreadContext(kotlin.coroutines.CoroutineContext;1:0){}[0]
196+
abstract fun updateThreadContext(kotlin.coroutines/CoroutineContext): #A // kotlinx.coroutines/ThreadContextElement.updateThreadContext|updateThreadContext(kotlin.coroutines.CoroutineContext){}[0]
197+
}
198+
189199
abstract interface <#A: kotlin/Throwable & kotlinx.coroutines/CopyableThrowable<#A>> kotlinx.coroutines/CopyableThrowable { // kotlinx.coroutines/CopyableThrowable|null[0]
190200
abstract fun createCopy(): #A? // kotlinx.coroutines/CopyableThrowable.createCopy|createCopy(){}[0]
191201
}

Diff for: kotlinx-coroutines-core/common/src/Builders.common.kt

+3-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import kotlinx.atomicfu.*
99
import kotlinx.coroutines.internal.*
1010
import kotlinx.coroutines.intrinsics.*
1111
import kotlinx.coroutines.selects.*
12+
import kotlin.concurrent.Volatile
1213
import kotlin.contracts.*
1314
import kotlin.coroutines.*
1415
import kotlin.coroutines.intrinsics.*
@@ -206,13 +207,7 @@ private class LazyStandaloneCoroutine(
206207
}
207208

208209
// Used by withContext when context changes, but dispatcher stays the same
209-
internal expect class UndispatchedCoroutine<in T>(
210-
context: CoroutineContext,
211-
uCont: Continuation<T>
212-
) : ScopeCoroutine<T>
213-
214-
// Used by withContext when context changes, but dispatcher stays the same
215-
internal actual class UndispatchedCoroutine<in T>actual constructor (
210+
internal class UndispatchedCoroutine<in T>(
216211
context: CoroutineContext,
217212
uCont: Continuation<T>
218213
) : ScopeCoroutine<T>(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) {
@@ -249,7 +244,7 @@ internal actual class UndispatchedCoroutine<in T>actual constructor (
249244
* - It's never accessed when we are sure there are no thread context elements
250245
* - It's cleaned up via [ThreadLocal.remove] as soon as the coroutine is suspended or finished.
251246
*/
252-
private val threadStateToRecover = ThreadLocal<Pair<CoroutineContext, Any?>>()
247+
private val threadStateToRecover = commonThreadLocal<Pair<CoroutineContext, Any?>?>(Symbol("UndispatchedCoroutine"))
253248

254249
/*
255250
* Indicates that a coroutine has at least one thread context element associated with it

Diff for: kotlinx-coroutines-core/common/src/CoroutineContext.common.kt

+6-22
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,18 @@ package kotlinx.coroutines
33
import kotlinx.coroutines.internal.*
44
import kotlin.coroutines.*
55

6-
/**
7-
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
8-
* [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on)
9-
* and copyable-thread-local facilities on JVM.
10-
*/
11-
public expect fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext
12-
13-
/**
14-
* Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
15-
* @suppress
16-
*/
17-
@InternalCoroutinesApi
18-
public expect fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext
19-
206
@PublishedApi // to have unmangled name when using from other modules via suppress
217
@Suppress("PropertyName")
228
internal expect val DefaultDelay: Delay
239

24-
// countOrElement -- pre-cached value for ThreadContext.kt
25-
internal expect inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T
26-
internal expect inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T
2710
internal expect fun Continuation<*>.toDebugString(): String
2811
internal expect val CoroutineContext.coroutineName: String?
12+
internal expect fun wrapContextWithDebug(context: CoroutineContext): CoroutineContext
2913

3014
/**
3115
* Executes a block using a given coroutine context.
3216
*/
33-
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T {
17+
internal inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T {
3418
val oldValue = updateThreadContext(context, countOrElement)
3519
try {
3620
return block()
@@ -42,7 +26,7 @@ internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, c
4226
/**
4327
* Executes a block using a context of a given continuation.
4428
*/
45-
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T {
29+
internal inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T {
4630
val context = continuation.context
4731
val oldValue = updateThreadContext(context, countOrElement)
4832
val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) {
@@ -60,7 +44,7 @@ internal actual inline fun <T> withContinuationContext(continuation: Continuatio
6044
}
6145
}
6246

63-
internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
47+
private fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
6448
if (this !is CoroutineStackFrame) return null
6549
/*
6650
* Fast-path to detect whether we have undispatched coroutine at all in our stack.
@@ -81,7 +65,7 @@ internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineCont
8165
return completion
8266
}
8367

84-
internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? {
68+
private tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? {
8569
// Find direct completion of this continuation
8670
val completion: CoroutineStackFrame = when (this) {
8771
is DispatchedCoroutine<*> -> return null
@@ -95,7 +79,7 @@ internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedC
9579
* Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack.
9680
* Used as a performance optimization to avoid stack walking where it is not necessary.
9781
*/
98-
private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
82+
internal object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
9983
override val key: CoroutineContext.Key<*>
10084
get() = this
10185
}

Diff for: kotlinx-coroutines-core/common/src/CoroutineContext.sharedWithJvm.kt

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1+
// This file should be a part of `CoroutineContext.common.kt`, but adding `JvmName` to that fails: KT-75248
2+
@file:JvmName("CoroutineContextKt")
3+
@file:JvmMultifileClass
14
package kotlinx.coroutines
25

6+
import kotlin.coroutines.ContinuationInterceptor
37
import kotlin.coroutines.CoroutineContext
48
import kotlin.coroutines.EmptyCoroutineContext
5-
9+
import kotlin.jvm.JvmMultifileClass
10+
import kotlin.jvm.JvmName
611

712
/**
813
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
9-
* [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on)
10-
* and copyable-thread-local facilities on JVM.
11-
* See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM.
14+
* [ContinuationInterceptor] is specified and
1215
*/
1316
@ExperimentalCoroutinesApi
14-
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
17+
public fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
1518
val combined = foldCopies(coroutineContext, context, true)
16-
val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined
19+
val debug = wrapContextWithDebug(combined)
1720
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
1821
debug + Dispatchers.Default else debug
1922
}
@@ -23,7 +26,7 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext):
2326
* @suppress
2427
*/
2528
@InternalCoroutinesApi
26-
public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
29+
public fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
2730
/*
2831
* Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements)
2932
* contains copyable elements.

Diff for: kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt

+6-8
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import kotlinx.coroutines.ThreadContextElement
44
import kotlin.coroutines.*
55
import kotlin.jvm.JvmField
66

7-
internal expect fun threadContextElements(context: CoroutineContext): Any
8-
97
@JvmField
108
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")
119

@@ -29,9 +27,9 @@ private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
2927
}
3028

3129
// Counts ThreadContextElements in the context
32-
// Any? here is Int | ThreadContextElement (when count is one)
30+
// Any here is Int | ThreadContextElement (when count is one)
3331
private val countAll =
34-
fun (countOrElement: Any?, element: CoroutineContext.Element): Any? {
32+
fun (countOrElement: Any, element: CoroutineContext.Element): Any {
3533
if (element is ThreadContextElement<*>) {
3634
val inCount = countOrElement as? Int ?: 1
3735
return if (inCount == 0) element else inCount + 1
@@ -55,17 +53,15 @@ private val updateState =
5553
return state
5654
}
5755

58-
internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!!
56+
internal expect inline fun isZeroCount(countOrElement: Any?): Boolean
5957

6058
// countOrElement is pre-cached in dispatched continuation
6159
// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements
6260
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
6361
@Suppress("NAME_SHADOWING")
6462
val countOrElement = countOrElement ?: threadContextElements(context)
65-
@Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
6663
return when {
67-
countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements
68-
// ^^^ identity comparison for speed, we know zero always has the same identity
64+
isZeroCount(countOrElement) -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements
6965
countOrElement is Int -> {
7066
// slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
7167
context.fold(ThreadState(context, countOrElement), updateState)
@@ -94,3 +90,5 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
9490
}
9591
}
9692
}
93+
94+
internal fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)

Diff for: kotlinx-coroutines-core/common/src/internal/ThreadLocal.common.kt

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package kotlinx.coroutines.internal
33
internal expect class CommonThreadLocal<T> {
44
fun get(): T
55
fun set(value: T)
6+
fun remove()
67
}
78

89
/**

Diff for: kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt

+53-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package kotlinx.coroutines
22

3+
import kotlinx.atomicfu.atomic
4+
import kotlinx.atomicfu.loop
35
import kotlinx.coroutines.testing.*
46
import kotlin.test.*
57
import kotlinx.coroutines.flow.*
8+
import kotlinx.coroutines.internal.CommonThreadLocal
9+
import kotlinx.coroutines.internal.Symbol
10+
import kotlinx.coroutines.internal.commonThreadLocal
611
import kotlin.coroutines.*
712

813
class ThreadContextElementTest: TestBase() {
@@ -38,20 +43,18 @@ class ThreadContextElementTest: TestBase() {
3843
*/
3944
@Test
4045
fun testWithContextJobAccess() = runTest {
41-
val executor = Executors.newSingleThreadExecutor()
4246
// Emulate non-equal dispatchers
43-
val executor1 = object : ExecutorService by executor {}
44-
val executor2 = object : ExecutorService by executor {}
45-
val dispatcher1 = executor1.asCoroutineDispatcher()
46-
val dispatcher2 = executor2.asCoroutineDispatcher()
47+
val dispatcher = Dispatchers.Default.limitedParallelism(1)
48+
val dispatcher1 = dispatcher.limitedParallelism(1, "dispatcher1")
49+
val dispatcher2 = dispatcher.limitedParallelism(1, "dispatcher2")
4750
val captor = JobCaptor()
4851
val manuallyCaptured = mutableListOf<String>()
4952

5053
fun registerUpdate(job: Job?) = manuallyCaptured.add("Update: $job")
5154
fun registerRestore(job: Job?) = manuallyCaptured.add("Restore: $job")
5255

5356
var rootJob: Job? = null
54-
runBlocking(captor + dispatcher1) {
57+
withContext(captor + dispatcher1) {
5558
rootJob = coroutineContext.job
5659
registerUpdate(rootJob)
5760
var undispatchedJob: Job? = null
@@ -84,7 +87,6 @@ class ThreadContextElementTest: TestBase() {
8487
val expected = manuallyCaptured.filter { it.startsWith("Update: ") }.joinToString(separator = "\n")
8588
val actual = captor.capturees.filter { it.startsWith("Update: ") }.joinToString(separator = "\n")
8689
assertEquals(expected, actual)
87-
executor.shutdownNow()
8890
}
8991

9092
@Test
@@ -96,7 +98,7 @@ class ThreadContextElementTest: TestBase() {
9698
assertEquals(myData, myThreadLocal.get())
9799
emit(1)
98100
}
99-
.flowOn(myThreadLocal.asContextElement() + Dispatchers.Default)
101+
.flowOn(myThreadLocal.asCtxElement() + Dispatchers.Default)
100102
.single()
101103
myThreadLocal.set(null)
102104
finish(2)
@@ -105,7 +107,7 @@ class ThreadContextElementTest: TestBase() {
105107

106108
class MyData
107109

108-
class JobCaptor(val capturees: MutableList<String> = CopyOnWriteArrayList()) : ThreadContextElement<Unit> {
110+
private class JobCaptor(val capturees: CopyOnWriteList<String> = CopyOnWriteList()) : ThreadContextElement<Unit> {
109111

110112
companion object Key : CoroutineContext.Key<MyElement>
111113

@@ -121,7 +123,7 @@ class JobCaptor(val capturees: MutableList<String> = CopyOnWriteArrayList()) : T
121123
}
122124

123125
// declare thread local variable holding MyData
124-
private val myThreadLocal = ThreadLocal<MyData?>()
126+
internal val myThreadLocal = commonThreadLocal<MyData?>(Symbol("myElement"))
125127

126128
// declare context element holding MyData
127129
class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
@@ -144,3 +146,44 @@ class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
144146
myThreadLocal.set(oldState)
145147
}
146148
}
149+
150+
151+
private class CommonThreadLocalContextElement<T>(
152+
private val threadLocal: CommonThreadLocal<T>,
153+
private val value: T = threadLocal.get()
154+
): ThreadContextElement<T>, CoroutineContext.Key<CommonThreadLocalContextElement<T>> {
155+
// provide the key of the corresponding context element
156+
override val key: CoroutineContext.Key<CommonThreadLocalContextElement<T>>
157+
get() = this
158+
159+
// this is invoked before coroutine is resumed on current thread
160+
override fun updateThreadContext(context: CoroutineContext): T {
161+
val oldState = threadLocal.get()
162+
threadLocal.set(value)
163+
return oldState
164+
}
165+
166+
// this is invoked after coroutine has suspended on current thread
167+
override fun restoreThreadContext(context: CoroutineContext, oldState: T) {
168+
threadLocal.set(oldState)
169+
}
170+
}
171+
172+
// overload resolution issues if this is called `asContextElement`
173+
internal fun <T> CommonThreadLocal<T>.asCtxElement(value: T = get()): ThreadContextElement<T> =
174+
CommonThreadLocalContextElement(this, value)
175+
176+
private class CopyOnWriteList<T> private constructor(list: List<T>) {
177+
private val field = atomic(list)
178+
179+
constructor() : this(emptyList())
180+
181+
fun add(value: T) {
182+
field.loop { current ->
183+
val new = current + value
184+
if (field.compareAndSet(current, new)) return
185+
}
186+
}
187+
188+
fun filter(predicate: (T) -> Boolean): List<T> = field.value.filter(predicate)
189+
}

Diff for: kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt

+10-6
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@ package kotlinx.coroutines
22

33
import kotlinx.coroutines.testing.*
44
import kotlinx.coroutines.flow.*
5+
import kotlinx.coroutines.internal.Symbol
6+
import kotlinx.coroutines.internal.commonThreadLocal
57
import kotlin.coroutines.*
68
import kotlin.test.*
79

810
class ThreadContextMutableCopiesTest : TestBase() {
911
companion object {
10-
val threadLocalData: ThreadLocal<MutableList<String>> = ThreadLocal.withInitial { ArrayList() }
12+
internal val threadLocalData = commonThreadLocal<MutableList<String>>(Symbol("ThreadLocalData")).also {
13+
it.set(mutableListOf())
14+
}
1115
}
1216

1317
class MyMutableElement(
@@ -42,7 +46,7 @@ class ThreadContextMutableCopiesTest : TestBase() {
4246
@Test
4347
fun testDataIsCopied() = runTest {
4448
val root = MyMutableElement(ArrayList())
45-
runBlocking(root) {
49+
launch(root) {
4650
val data = threadLocalData.get()
4751
expect(1)
4852
launch(root) {
@@ -56,7 +60,7 @@ class ThreadContextMutableCopiesTest : TestBase() {
5660
@Test
5761
fun testDataIsNotOverwritten() = runTest {
5862
val root = MyMutableElement(ArrayList())
59-
runBlocking(root) {
63+
withContext(root) {
6064
expect(1)
6165
val originalData = threadLocalData.get()
6266
threadLocalData.get().add("X")
@@ -75,7 +79,7 @@ class ThreadContextMutableCopiesTest : TestBase() {
7579
@Test
7680
fun testDataIsMerged() = runTest {
7781
val root = MyMutableElement(ArrayList())
78-
runBlocking(root) {
82+
withContext(root) {
7983
expect(1)
8084
val originalData = threadLocalData.get()
8185
threadLocalData.get().add("X")
@@ -94,7 +98,7 @@ class ThreadContextMutableCopiesTest : TestBase() {
9498
@Test
9599
fun testDataIsNotOverwrittenWithContext() = runTest {
96100
val root = MyMutableElement(ArrayList())
97-
runBlocking(root) {
101+
withContext(root) {
98102
val originalData = threadLocalData.get()
99103
threadLocalData.get().add("X")
100104
expect(1)
@@ -114,7 +118,7 @@ class ThreadContextMutableCopiesTest : TestBase() {
114118
fun testDataIsCopiedForRunBlocking() = runTest {
115119
val root = MyMutableElement(ArrayList())
116120
val originalData = root.mutableData
117-
runBlocking(root) {
121+
withContext(root) {
118122
assertNotSame(originalData, threadLocalData.get())
119123
}
120124
}

0 commit comments

Comments
 (0)