Skip to content

Move ThreadContextElement to common #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ abstract interface <#A: kotlin/Any?> kotlinx.coroutines/CompletableDeferred : ko
abstract fun completeExceptionally(kotlin/Throwable): kotlin/Boolean // kotlinx.coroutines/CompletableDeferred.completeExceptionally|completeExceptionally(kotlin.Throwable){}[0]
}

abstract interface <#A: kotlin/Any?> kotlinx.coroutines/ThreadContextElement : kotlin.coroutines/CoroutineContext.Element { // kotlinx.coroutines/ThreadContextElement|null[0]
abstract fun restoreThreadContext(kotlin.coroutines/CoroutineContext, #A) // kotlinx.coroutines/ThreadContextElement.restoreThreadContext|restoreThreadContext(kotlin.coroutines.CoroutineContext;1:0){}[0]
abstract fun updateThreadContext(kotlin.coroutines/CoroutineContext): #A // kotlinx.coroutines/ThreadContextElement.updateThreadContext|updateThreadContext(kotlin.coroutines.CoroutineContext){}[0]
}

abstract interface <#A: kotlin/Throwable & kotlinx.coroutines/CopyableThrowable<#A>> kotlinx.coroutines/CopyableThrowable { // kotlinx.coroutines/CopyableThrowable|null[0]
abstract fun createCopy(): #A? // kotlinx.coroutines/CopyableThrowable.createCopy|createCopy(){}[0]
}
Expand Down
82 changes: 82 additions & 0 deletions kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package kotlinx.coroutines

import kotlin.coroutines.*

/**
* Defines elements in [CoroutineContext] that are installed into thread context
* every time the coroutine with this element in the context is resumed on a thread.
*
* Implementations of this interface define a type [S] of the thread-local state that they need to store on
* resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage.
*
* Example usage looks like this:
*
* ```
* // Appends "name" of a coroutine to a current thread name when coroutine is executed
* class CoroutineName(val name: String) : ThreadContextElement<String> {
* // declare companion object for a key of this element in coroutine context
* companion object Key : CoroutineContext.Key<CoroutineName>
*
* // provide the key of the corresponding context element
* override val key: CoroutineContext.Key<CoroutineName>
* get() = Key
*
* // this is invoked before coroutine is resumed on current thread
* override fun updateThreadContext(context: CoroutineContext): String {
* val previousName = Thread.currentThread().name
* Thread.currentThread().name = "$previousName # $name"
* return previousName
* }
*
* // this is invoked after coroutine has suspended on current thread
* override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
* Thread.currentThread().name = oldState
* }
* }
*
* // Usage
* launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... }
* ```
*
* Every time this coroutine is resumed on a thread, UI thread name is updated to
* "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when
* this coroutine suspends.
*
* To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function.
*
* ### Reentrancy and thread-safety
*
* Correct implementations of this interface must expect that calls to [restoreThreadContext]
* may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations.
* See [CopyableThreadContextElement] for advanced interleaving details.
*
* All implementations of [ThreadContextElement] should be thread-safe and guard their internal mutable state
* within an element accordingly.
*/
public interface ThreadContextElement<S> : CoroutineContext.Element {
/**
* Updates context of the current thread.
* This function is invoked before the coroutine in the specified [context] is resumed in the current thread
* when the context of the coroutine this element.
* The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext].
* This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which
* context is updated in an undefined state and may crash an application.
*
* @param context the coroutine context.
*/
public fun updateThreadContext(context: CoroutineContext): S

/**
* Restores context of the current thread.
* This function is invoked after the coroutine in the specified [context] is suspended in the current thread
* if [updateThreadContext] was previously invoked on resume of this coroutine.
* The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should
* be restored in the thread-local state by this function.
* This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which
* context is updated in an undefined state and may crash an application.
*
* @param context the coroutine context.
* @param oldState the value returned by the previous invocation of [updateThreadContext].
*/
public fun restoreThreadContext(context: CoroutineContext, oldState: S)
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,56 @@
package kotlinx.coroutines.internal

import kotlinx.coroutines.*
import kotlin.coroutines.*
import kotlin.jvm.*

internal expect fun threadContextElements(context: CoroutineContext): Any
@JvmField
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")

// Used when there are >= 2 active elements in the context
@Suppress("UNCHECKED_CAST")
internal class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
private val values = arrayOfNulls<Any>(n)
private val elements = arrayOfNulls<ThreadContextElement<Any?>>(n)
private var i = 0

fun append(element: ThreadContextElement<*>, value: Any?) {
values[i] = value
elements[i++] = element as ThreadContextElement<Any?>
}

fun restore(context: CoroutineContext) {
for (i in elements.indices.reversed()) {
elements[i]!!.restoreThreadContext(context, values[i])
}
}
}

// Counts ThreadContextElements in the context
// Any? here is Int | ThreadContextElement (when count is one)
private val countAll =
fun (countOrElement: Any?, element: CoroutineContext.Element): Any? {
if (element is ThreadContextElement<*>) {
val inCount = countOrElement as? Int ?: 1
return if (inCount == 0) element else inCount + 1
}
return countOrElement
}

// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one
internal val findOne =
fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? {
if (found != null) return found
return element as? ThreadContextElement<*>
}

// Updates state for ThreadContextElements in the context using the given ThreadState
internal val updateState =
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
if (element is ThreadContextElement<*>) {
state.append(element, element.updateThreadContext(state.context))
}
return state
}

internal fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!!
102 changes: 99 additions & 3 deletions kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package kotlinx.coroutines

import kotlinx.coroutines.internal.CoroutineStackFrame
import kotlinx.coroutines.internal.NO_THREAD_ELEMENTS
import kotlinx.coroutines.internal.ScopeCoroutine
import kotlinx.coroutines.internal.restoreThreadContext
import kotlinx.coroutines.internal.updateThreadContext
import kotlin.coroutines.*

@PublishedApi // Used from kotlinx-coroutines-test via suppress, not part of ABI
Expand All @@ -18,16 +22,108 @@ public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineCo
}

// No debugging facilities on Wasm and JS
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block()
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block()
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T {
val oldValue = updateThreadContext(context, countOrElement)
try {
return block()
} finally {
restoreThreadContext(context, oldValue)
}
}

internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T {
val context = continuation.context
val oldValue = updateThreadContext(context, countOrElement)
val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) {
// Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them
continuation.updateUndispatchedCompletion(context, oldValue)
} else {
null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context
}
try {
return block()
} finally {
if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) {
restoreThreadContext(context, oldValue)
}
}
}

internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
if (this !is CoroutineStackFrame) return null
/*
* Fast-path to detect whether we have undispatched coroutine at all in our stack.
*
* Implementation note.
* If we ever find that stackwalking for thread-locals is way too slow, here is another idea:
* 1) Store undispatched coroutine right in the `UndispatchedMarker` instance
* 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker`
* from the context when creating dispatched coroutine in `withContext`.
* Another option is to "unmark it" instead of removing to save an allocation.
* Both options should work, but it requires more careful studying of the performance
* and, mostly, maintainability impact.
*/
val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null
if (!potentiallyHasUndispatchedCoroutine) return null
val completion = undispatchedCompletion()
completion?.saveThreadContext(context, oldValue)
return completion
}

internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? {
// Find direct completion of this continuation
val completion: CoroutineStackFrame = when (this) {
is DispatchedCoroutine<*> -> return null
else -> callerFrame ?: return null // something else -- not supported
}
if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine!
return completion.undispatchedCompletion() // walk up the call stack with tail call
}

/**
* Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack.
* Used as a performance optimization to avoid stack walking where it is not necessary.
*/
private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
override val key: CoroutineContext.Key<*>
get() = this
}

internal actual fun Continuation<*>.toDebugString(): String = toString()
internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on Wasm and JS

internal actual class UndispatchedCoroutine<in T> actual constructor(
context: CoroutineContext,
uCont: Continuation<T>
) : ScopeCoroutine<T>(context, uCont) {
override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont))

private var savedContext: CoroutineContext? = null
private var savedOldValue: Any? = null

fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
savedContext = context
savedOldValue = oldValue
}

fun clearThreadContext(): Boolean {
if (savedContext == null) return false
savedContext = null
savedOldValue = null
return true
}

override fun afterResume(state: Any?) {
savedContext?.let { context ->
restoreThreadContext(context, savedOldValue)
savedContext = null
savedOldValue = null
}
// resume undispatched -- update context but stay on the same dispatcher
val result = recoverResult(state, uCont)
withContinuationContext(uCont, null) {
uCont.resumeWith(result)
}
}
}

internal actual inline fun <T> withThreadLocalContext(context: CoroutineContext, block: () -> T) : T = block()
Original file line number Diff line number Diff line change
@@ -1,5 +1,41 @@
package kotlinx.coroutines.internal

import kotlinx.coroutines.*
import kotlin.coroutines.*

internal actual fun threadContextElements(context: CoroutineContext): Any = 0
// countOrElement is pre-cached in dispatched continuation
// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
@Suppress("NAME_SHADOWING")
val countOrElement = countOrElement ?: threadContextElements(context)
@Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
return when {
countOrElement == 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements
countOrElement is Int -> {
// slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
context.fold(ThreadState(context, countOrElement), updateState)
}
else -> {
// fast path for one ThreadContextElement (no allocations, no additional context scan)
@Suppress("UNCHECKED_CAST")
val element = countOrElement as ThreadContextElement<Any?>
element.updateThreadContext(context)
}
}
}

internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
when {
oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements
oldState is ThreadState -> {
// slow path with multiple stored ThreadContextElements
oldState.restore(context)
}
else -> {
// fast path for one ThreadContextElement, but need to find it
@Suppress("UNCHECKED_CAST")
val element = context.fold(null, findOne) as ThreadContextElement<Any?>
element.restoreThreadContext(context, oldState)
}
}
}
Loading