Skip to content

Commit 995f8bd

Browse files
committed
Move ThreadContextElement to common
1 parent 2d9f944 commit 995f8bd

File tree

9 files changed

+485
-139
lines changed

9 files changed

+485
-139
lines changed

Diff for: kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api

+4
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ abstract interface <#A: kotlin/Any?> kotlinx.coroutines/CompletableDeferred : ko
159159
abstract fun complete(#A): kotlin/Boolean // kotlinx.coroutines/CompletableDeferred.complete|complete(1:0){}[0]
160160
abstract fun completeExceptionally(kotlin/Throwable): kotlin/Boolean // kotlinx.coroutines/CompletableDeferred.completeExceptionally|completeExceptionally(kotlin.Throwable){}[0]
161161
}
162+
abstract interface <#A: kotlin/Any?> kotlinx.coroutines/ThreadContextElement : kotlin.coroutines/CoroutineContext.Element { // kotlinx.coroutines/ThreadContextElement|null[0]
163+
abstract fun restoreThreadContext(kotlin.coroutines/CoroutineContext, #A) // kotlinx.coroutines/ThreadContextElement.restoreThreadContext|restoreThreadContext(kotlin.coroutines.CoroutineContext;1:0){}[0]
164+
abstract fun updateThreadContext(kotlin.coroutines/CoroutineContext): #A // kotlinx.coroutines/ThreadContextElement.updateThreadContext|updateThreadContext(kotlin.coroutines.CoroutineContext){}[0]
165+
}
162166
abstract interface <#A: kotlin/Throwable & kotlinx.coroutines/CopyableThrowable<#A>> kotlinx.coroutines/CopyableThrowable { // kotlinx.coroutines/CopyableThrowable|null[0]
163167
abstract fun createCopy(): #A? // kotlinx.coroutines/CopyableThrowable.createCopy|createCopy(){}[0]
164168
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package kotlinx.coroutines
2+
3+
import kotlin.coroutines.*
4+
5+
/**
6+
* Defines elements in [CoroutineContext] that are installed into thread context
7+
* every time the coroutine with this element in the context is resumed on a thread.
8+
*
9+
* Implementations of this interface define a type [S] of the thread-local state that they need to store on
10+
* resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage.
11+
*
12+
* Example usage looks like this:
13+
*
14+
* ```
15+
* // Appends "name" of a coroutine to a current thread name when coroutine is executed
16+
* class CoroutineName(val name: String) : ThreadContextElement<String> {
17+
* // declare companion object for a key of this element in coroutine context
18+
* companion object Key : CoroutineContext.Key<CoroutineName>
19+
*
20+
* // provide the key of the corresponding context element
21+
* override val key: CoroutineContext.Key<CoroutineName>
22+
* get() = Key
23+
*
24+
* // this is invoked before coroutine is resumed on current thread
25+
* override fun updateThreadContext(context: CoroutineContext): String {
26+
* val previousName = Thread.currentThread().name
27+
* Thread.currentThread().name = "$previousName # $name"
28+
* return previousName
29+
* }
30+
*
31+
* // this is invoked after coroutine has suspended on current thread
32+
* override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
33+
* Thread.currentThread().name = oldState
34+
* }
35+
* }
36+
*
37+
* // Usage
38+
* launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... }
39+
* ```
40+
*
41+
* Every time this coroutine is resumed on a thread, UI thread name is updated to
42+
* "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when
43+
* this coroutine suspends.
44+
*
45+
* To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function.
46+
*
47+
* ### Reentrancy and thread-safety
48+
*
49+
* Correct implementations of this interface must expect that calls to [restoreThreadContext]
50+
* may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations.
51+
* See [CopyableThreadContextElement] for advanced interleaving details.
52+
*
53+
* All implementations of [ThreadContextElement] should be thread-safe and guard their internal mutable state
54+
* within an element accordingly.
55+
*/
56+
public interface ThreadContextElement<S> : CoroutineContext.Element {
57+
/**
58+
* Updates context of the current thread.
59+
* This function is invoked before the coroutine in the specified [context] is resumed in the current thread
60+
* when the context of the coroutine this element.
61+
* The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext].
62+
* This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which
63+
* context is updated in an undefined state and may crash an application.
64+
*
65+
* @param context the coroutine context.
66+
*/
67+
public fun updateThreadContext(context: CoroutineContext): S
68+
69+
/**
70+
* Restores context of the current thread.
71+
* This function is invoked after the coroutine in the specified [context] is suspended in the current thread
72+
* if [updateThreadContext] was previously invoked on resume of this coroutine.
73+
* The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should
74+
* be restored in the thread-local state by this function.
75+
* This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which
76+
* context is updated in an undefined state and may crash an application.
77+
*
78+
* @param context the coroutine context.
79+
* @param oldState the value returned by the previous invocation of [updateThreadContext].
80+
*/
81+
public fun restoreThreadContext(context: CoroutineContext, oldState: S)
82+
}
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,56 @@
11
package kotlinx.coroutines.internal
22

3+
import kotlinx.coroutines.*
34
import kotlin.coroutines.*
5+
import kotlin.jvm.*
46

5-
internal expect fun threadContextElements(context: CoroutineContext): Any
7+
@JvmField
8+
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")
9+
10+
// Used when there are >= 2 active elements in the context
11+
@Suppress("UNCHECKED_CAST")
12+
internal class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
13+
private val values = arrayOfNulls<Any>(n)
14+
private val elements = arrayOfNulls<ThreadContextElement<Any?>>(n)
15+
private var i = 0
16+
17+
fun append(element: ThreadContextElement<*>, value: Any?) {
18+
values[i] = value
19+
elements[i++] = element as ThreadContextElement<Any?>
20+
}
21+
22+
fun restore(context: CoroutineContext) {
23+
for (i in elements.indices.reversed()) {
24+
elements[i]!!.restoreThreadContext(context, values[i])
25+
}
26+
}
27+
}
28+
29+
// Counts ThreadContextElements in the context
30+
// Any? here is Int | ThreadContextElement (when count is one)
31+
private val countAll =
32+
fun (countOrElement: Any?, element: CoroutineContext.Element): Any? {
33+
if (element is ThreadContextElement<*>) {
34+
val inCount = countOrElement as? Int ?: 1
35+
return if (inCount == 0) element else inCount + 1
36+
}
37+
return countOrElement
38+
}
39+
40+
// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one
41+
internal val findOne =
42+
fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? {
43+
if (found != null) return found
44+
return element as? ThreadContextElement<*>
45+
}
46+
47+
// Updates state for ThreadContextElements in the context using the given ThreadState
48+
internal val updateState =
49+
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
50+
if (element is ThreadContextElement<*>) {
51+
state.append(element, element.updateThreadContext(state.context))
52+
}
53+
return state
54+
}
55+
56+
internal fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!!

Diff for: kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt

+101-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
package kotlinx.coroutines
22

3+
import kotlinx.coroutines.internal.*
4+
import kotlinx.coroutines.internal.CoroutineStackFrame
5+
import kotlinx.coroutines.internal.NO_THREAD_ELEMENTS
36
import kotlinx.coroutines.internal.ScopeCoroutine
7+
import kotlinx.coroutines.internal.restoreThreadContext
8+
import kotlinx.coroutines.internal.updateThreadContext
49
import kotlin.coroutines.*
10+
import kotlin.jvm.*
511

612
@PublishedApi // Used from kotlinx-coroutines-test via suppress, not part of ABI
713
internal actual val DefaultDelay: Delay
@@ -18,14 +24,106 @@ public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineCo
1824
}
1925

2026
// No debugging facilities on Wasm and JS
21-
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block()
22-
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block()
27+
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T {
28+
val oldValue = updateThreadContext(context, countOrElement)
29+
try {
30+
return block()
31+
} finally {
32+
restoreThreadContext(context, oldValue)
33+
}
34+
}
35+
36+
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T {
37+
val context = continuation.context
38+
val oldValue = updateThreadContext(context, countOrElement)
39+
val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) {
40+
// Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them
41+
continuation.updateUndispatchedCompletion(context, oldValue)
42+
} else {
43+
null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context
44+
}
45+
try {
46+
return block()
47+
} finally {
48+
if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) {
49+
restoreThreadContext(context, oldValue)
50+
}
51+
}
52+
}
53+
54+
internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
55+
if (this !is CoroutineStackFrame) return null
56+
/*
57+
* Fast-path to detect whether we have undispatched coroutine at all in our stack.
58+
*
59+
* Implementation note.
60+
* If we ever find that stackwalking for thread-locals is way too slow, here is another idea:
61+
* 1) Store undispatched coroutine right in the `UndispatchedMarker` instance
62+
* 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker`
63+
* from the context when creating dispatched coroutine in `withContext`.
64+
* Another option is to "unmark it" instead of removing to save an allocation.
65+
* Both options should work, but it requires more careful studying of the performance
66+
* and, mostly, maintainability impact.
67+
*/
68+
val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null
69+
if (!potentiallyHasUndispatchedCoroutine) return null
70+
val completion = undispatchedCompletion()
71+
completion?.saveThreadContext(context, oldValue)
72+
return completion
73+
}
74+
75+
internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? {
76+
// Find direct completion of this continuation
77+
val completion: CoroutineStackFrame = when (this) {
78+
is DispatchedCoroutine<*> -> return null
79+
else -> callerFrame ?: return null // something else -- not supported
80+
}
81+
if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine!
82+
return completion.undispatchedCompletion() // walk up the call stack with tail call
83+
}
84+
85+
/**
86+
* Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack.
87+
* Used as a performance optimization to avoid stack walking where it is not necessary.
88+
*/
89+
private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
90+
override val key: CoroutineContext.Key<*>
91+
get() = this
92+
}
93+
2394
internal actual fun Continuation<*>.toDebugString(): String = toString()
2495
internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on Wasm and JS
2596

2697
internal actual class UndispatchedCoroutine<in T> actual constructor(
2798
context: CoroutineContext,
2899
uCont: Continuation<T>
29100
) : ScopeCoroutine<T>(context, uCont) {
30-
override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont))
101+
102+
private var savedContext: CoroutineContext? = null
103+
private var savedOldValue: Any? = null
104+
105+
fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
106+
savedContext = context
107+
savedOldValue = oldValue
108+
}
109+
110+
fun clearThreadContext(): Boolean {
111+
if (savedContext == null) return false
112+
savedContext = null
113+
savedOldValue = null
114+
return true
115+
}
116+
117+
override fun afterResume(state: Any?) {
118+
savedContext?.let { context ->
119+
restoreThreadContext(context, savedOldValue)
120+
savedContext = null
121+
savedOldValue = null
122+
}
123+
// resume undispatched -- update context but stay on the same dispatcher
124+
val result = recoverResult(state, uCont)
125+
withContinuationContext(uCont, null) {
126+
uCont.resumeWith(result)
127+
}
128+
}
31129
}
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,41 @@
11
package kotlinx.coroutines.internal
22

3+
import kotlinx.coroutines.*
34
import kotlin.coroutines.*
45

5-
internal actual fun threadContextElements(context: CoroutineContext): Any = 0
6+
// countOrElement is pre-cached in dispatched continuation
7+
// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements
8+
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
9+
@Suppress("NAME_SHADOWING")
10+
val countOrElement = countOrElement ?: threadContextElements(context)
11+
@Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
12+
return when {
13+
countOrElement == 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements
14+
countOrElement is Int -> {
15+
// slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
16+
context.fold(ThreadState(context, countOrElement), updateState)
17+
}
18+
else -> {
19+
// fast path for one ThreadContextElement (no allocations, no additional context scan)
20+
@Suppress("UNCHECKED_CAST")
21+
val element = countOrElement as ThreadContextElement<Any?>
22+
element.updateThreadContext(context)
23+
}
24+
}
25+
}
26+
27+
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
28+
when {
29+
oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements
30+
oldState is ThreadState -> {
31+
// slow path with multiple stored ThreadContextElements
32+
oldState.restore(context)
33+
}
34+
else -> {
35+
// fast path for one ThreadContextElement, but need to find it
36+
@Suppress("UNCHECKED_CAST")
37+
val element = context.fold(null, findOne) as ThreadContextElement<Any?>
38+
element.restoreThreadContext(context, oldState)
39+
}
40+
}
41+
}

0 commit comments

Comments
 (0)