Skip to content

Commit 5040461

Browse files
committed
Move CopyableThreadContextElement to common
1 parent 877c70f commit 5040461

8 files changed

+349
-360
lines changed

Diff for: kotlinx-coroutines-core/common/src/CoroutineContext.common.kt

+79-14
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,13 @@ package kotlinx.coroutines
33
import kotlinx.coroutines.internal.*
44
import kotlin.coroutines.*
55

6-
/**
7-
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
8-
* [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on)
9-
* and copyable-thread-local facilities on JVM.
10-
*/
11-
public expect fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext
12-
13-
/**
14-
* Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
15-
* @suppress
16-
*/
17-
@InternalCoroutinesApi
18-
public expect fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext
19-
206
@PublishedApi // to have unmangled name when using from other modules via suppress
217
@Suppress("PropertyName")
228
internal expect val DefaultDelay: Delay
239

2410
internal expect fun Continuation<*>.toDebugString(): String
2511
internal expect val CoroutineContext.coroutineName: String?
12+
internal expect fun wrapContextWithDebug(context: CoroutineContext): CoroutineContext
2613

2714
/**
2815
* Executes a block using a given coroutine context.
@@ -98,3 +85,81 @@ internal object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.K
9885
override val key: CoroutineContext.Key<*>
9986
get() = this
10087
}
88+
89+
/**
90+
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
91+
* [ContinuationInterceptor] is specified and
92+
*/
93+
@ExperimentalCoroutinesApi
94+
public fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
95+
val combined = foldCopies(coroutineContext, context, true)
96+
val debug = wrapContextWithDebug(combined)
97+
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
98+
debug + Dispatchers.Default else debug
99+
}
100+
101+
/**
102+
* Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
103+
* @suppress
104+
*/
105+
@InternalCoroutinesApi
106+
public fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
107+
/*
108+
* Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements)
109+
* contains copyable elements.
110+
*/
111+
if (!addedContext.hasCopyableElements()) return this + addedContext
112+
return foldCopies(this, addedContext, false)
113+
}
114+
115+
private fun CoroutineContext.hasCopyableElements(): Boolean =
116+
fold(false) { result, it -> result || it is CopyableThreadContextElement<*> }
117+
118+
/**
119+
* Folds two contexts properly applying [CopyableThreadContextElement] rules when necessary.
120+
* The rules are the following:
121+
* - If neither context has CTCE, the sum of two contexts is returned
122+
* - Every CTCE from the left-hand side context that does not have a matching (by key) element from right-hand side context
123+
* is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`.
124+
* - Every CTCE from the left-hand side context that has a matching element in the right-hand side context is [merged][CopyableThreadContextElement.mergeForChild]
125+
* - Every CTCE from the right-hand side context that hasn't been merged is copied
126+
* - Everything else is added to the resulting context as is.
127+
*/
128+
private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext {
129+
// Do we have something to copy left-hand side?
130+
val hasElementsLeft = originalContext.hasCopyableElements()
131+
val hasElementsRight = appendContext.hasCopyableElements()
132+
133+
// Nothing to fold, so just return the sum of contexts
134+
if (!hasElementsLeft && !hasElementsRight) {
135+
return originalContext + appendContext
136+
}
137+
138+
var leftoverContext = appendContext
139+
val folded = originalContext.fold<CoroutineContext>(EmptyCoroutineContext) { result, element ->
140+
if (element !is CopyableThreadContextElement<*>) return@fold result + element
141+
// Will this element be overwritten?
142+
val newElement = leftoverContext[element.key]
143+
// No, just copy it
144+
if (newElement == null) {
145+
// For 'withContext'-like builders we do not copy as the element is not shared
146+
return@fold result + if (isNewCoroutine) element.copyForChild() else element
147+
}
148+
// Yes, then first remove the element from append context
149+
leftoverContext = leftoverContext.minusKey(element.key)
150+
// Return the sum
151+
@Suppress("UNCHECKED_CAST")
152+
return@fold result + (element as CopyableThreadContextElement<Any?>).mergeForChild(newElement)
153+
}
154+
155+
if (hasElementsRight) {
156+
leftoverContext = leftoverContext.fold<CoroutineContext>(EmptyCoroutineContext) { result, element ->
157+
// We're appending new context element -- we have to copy it, otherwise it may be shared with others
158+
if (element is CopyableThreadContextElement<*>) {
159+
return@fold result + element.copyForChild()
160+
}
161+
return@fold result + element
162+
}
163+
}
164+
return folded + leftoverContext
165+
}

Diff for: kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt

+104
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,107 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
8080
*/
8181
public fun restoreThreadContext(context: CoroutineContext, oldState: S)
8282
}
83+
84+
/**
85+
* A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it.
86+
*
87+
* When an API uses a _mutable_ [ThreadLocal] for consistency, a [CopyableThreadContextElement]
88+
* can give coroutines "coroutine-safe" write access to that `ThreadLocal`.
89+
*
90+
* A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine
91+
* will be visible to _itself_ and any child coroutine launched _after_ that write.
92+
*
93+
* Writes will not be visible to the parent coroutine, peer coroutines, or coroutines that happen
94+
* to use the same thread. Writes made to the `ThreadLocal` by the parent coroutine _after_
95+
* launching a child coroutine will not be visible to that child coroutine.
96+
*
97+
* This can be used to allow a coroutine to use a mutable ThreadLocal API transparently and
98+
* correctly, regardless of the coroutine's structured concurrency.
99+
*
100+
* This example adapts a `ThreadLocal` method trace to be "coroutine local" while the method trace
101+
* is in a coroutine:
102+
*
103+
* ```
104+
* class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
105+
* companion object Key : CoroutineContext.Key<TraceContextElement>
106+
*
107+
* override val key: CoroutineContext.Key<TraceContextElement> = Key
108+
*
109+
* override fun updateThreadContext(context: CoroutineContext): TraceData? {
110+
* val oldState = traceThreadLocal.get()
111+
* traceThreadLocal.set(traceData)
112+
* return oldState
113+
* }
114+
*
115+
* override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
116+
* traceThreadLocal.set(oldState)
117+
* }
118+
*
119+
* override fun copyForChild(): TraceContextElement {
120+
* // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes
121+
* // ThreadLocal writes between resumption of the parent coroutine and the launch of the
122+
* // child coroutine visible to the child.
123+
* return TraceContextElement(traceThreadLocal.get()?.copy())
124+
* }
125+
*
126+
* override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
127+
* // Merge operation defines how to handle situations when both
128+
* // the parent coroutine has an element in the context and
129+
* // an element with the same key was also
130+
* // explicitly passed to the child coroutine.
131+
* // If merging does not require special behavior,
132+
* // the copy of the element can be returned.
133+
* return TraceContextElement(traceThreadLocal.get()?.copy())
134+
* }
135+
* }
136+
* ```
137+
*
138+
* A coroutine using this mechanism can safely call Java code that assumes the corresponding thread local element's
139+
* value is installed into the target thread local.
140+
*
141+
* ### Reentrancy and thread-safety
142+
*
143+
* Correct implementations of this interface must expect that calls to [restoreThreadContext]
144+
* may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations.
145+
*
146+
* Even though an element is copied for each child coroutine, an implementation should be able to handle the following
147+
* interleaving when a coroutine with the corresponding element is launched on a multithreaded dispatcher:
148+
*
149+
* ```
150+
* coroutine.updateThreadContext() // Thread #1
151+
* ... coroutine body ...
152+
* // suspension + immediate dispatch happen here
153+
* coroutine.updateThreadContext() // Thread #2, coroutine is already resumed
154+
* // ... coroutine body after suspension point on Thread #2 ...
155+
* coroutine.restoreThreadContext() // Thread #1, is invoked late because Thread #1 is slow
156+
* coroutine.restoreThreadContext() // Thread #2, may happen in parallel with the previous restore
157+
* ```
158+
*
159+
* All implementations of [CopyableThreadContextElement] should be thread-safe and guard their internal mutable state
160+
* within an element accordingly.
161+
*/
162+
@DelicateCoroutinesApi
163+
@ExperimentalCoroutinesApi
164+
public interface CopyableThreadContextElement<S> : ThreadContextElement<S> {
165+
166+
/**
167+
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
168+
* coroutine's context that is under construction if the added context does not contain an element with the same [key].
169+
*
170+
* This function is called on the element each time a new coroutine inherits a context containing it,
171+
* and the returned value is folded into the context given to the child.
172+
*
173+
* Since this method is called whenever a new coroutine is launched in a context containing this
174+
* [CopyableThreadContextElement], implementations are performance-sensitive.
175+
*/
176+
public fun copyForChild(): CopyableThreadContextElement<S>
177+
178+
/**
179+
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
180+
* coroutine's context that is under construction if the added context does contain an element with the same [key].
181+
*
182+
* This method is invoked on the original element, accepting as the parameter
183+
* the element that is supposed to overwrite it.
184+
*/
185+
public fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext
186+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package kotlinx.coroutines
2+
3+
import kotlinx.coroutines.flow.flow
4+
import kotlinx.coroutines.flow.flowOn
5+
import kotlinx.coroutines.flow.single
6+
import kotlinx.coroutines.internal.Symbol
7+
import kotlinx.coroutines.internal.commonThreadLocal
8+
import kotlinx.coroutines.testing.TestBase
9+
import kotlin.coroutines.CoroutineContext
10+
import kotlin.test.Test
11+
import kotlin.test.assertEquals
12+
import kotlin.test.assertNotSame
13+
14+
class CommonThreadContextMutableCopiesTest : TestBase() {
15+
companion object {
16+
internal val threadLocalData = commonThreadLocal<MutableList<String>>(Symbol("ThreadLocalData"))
17+
}
18+
19+
class MyMutableElement(
20+
val mutableData: MutableList<String>
21+
) : CopyableThreadContextElement<MutableList<String>> {
22+
23+
companion object Key : CoroutineContext.Key<MyMutableElement>
24+
25+
override val key: CoroutineContext.Key<*>
26+
get() = Key
27+
28+
override fun updateThreadContext(context: CoroutineContext): MutableList<String> {
29+
val st = threadLocalData.get()
30+
threadLocalData.set(mutableData)
31+
return st
32+
}
33+
34+
override fun restoreThreadContext(context: CoroutineContext, oldState: MutableList<String>) {
35+
threadLocalData.set(oldState)
36+
}
37+
38+
override fun copyForChild(): MyMutableElement {
39+
return MyMutableElement(ArrayList(mutableData))
40+
}
41+
42+
override fun mergeForChild(overwritingElement: CoroutineContext.Element): MyMutableElement {
43+
overwritingElement as MyMutableElement // <- app-specific, may be another subtype
44+
return MyMutableElement((mutableData.toSet() + overwritingElement.mutableData).toMutableList())
45+
}
46+
}
47+
48+
@Test
49+
fun testDataIsCopied() = runTest {
50+
val root = MyMutableElement(ArrayList())
51+
launch(root) {
52+
val data = threadLocalData.get()
53+
expect(1)
54+
launch(root) {
55+
assertNotSame(data, threadLocalData.get())
56+
assertEquals(data, threadLocalData.get())
57+
finish(2)
58+
}
59+
}
60+
}
61+
62+
@Test
63+
fun testDataIsNotOverwritten() = runTest {
64+
val root = MyMutableElement(ArrayList())
65+
withContext(root) {
66+
expect(1)
67+
val originalData = threadLocalData.get()
68+
threadLocalData.get().add("X")
69+
launch {
70+
threadLocalData.get().add("Y")
71+
// Note here, +root overwrites the data
72+
launch(Dispatchers.Default + root) {
73+
assertEquals(listOf("X", "Y"), threadLocalData.get())
74+
assertNotSame(originalData, threadLocalData.get())
75+
finish(2)
76+
}
77+
}
78+
}
79+
}
80+
81+
@Test
82+
fun testDataIsMerged() = runTest {
83+
val root = MyMutableElement(ArrayList())
84+
withContext(root) {
85+
expect(1)
86+
val originalData = threadLocalData.get()
87+
threadLocalData.get().add("X")
88+
launch {
89+
threadLocalData.get().add("Y")
90+
// Note here, +root overwrites the data
91+
launch(Dispatchers.Default + MyMutableElement(mutableListOf("Z"))) {
92+
assertEquals(listOf("X", "Y", "Z"), threadLocalData.get())
93+
assertNotSame(originalData, threadLocalData.get())
94+
finish(2)
95+
}
96+
}
97+
}
98+
}
99+
100+
@Test
101+
fun testDataIsNotOverwrittenWithContext() = runTest {
102+
val root = MyMutableElement(ArrayList())
103+
withContext(root) {
104+
val originalData = threadLocalData.get()
105+
threadLocalData.get().add("X")
106+
expect(1)
107+
launch {
108+
threadLocalData.get().add("Y")
109+
// Note here, +root overwrites the data
110+
withContext(Dispatchers.Default + root) {
111+
assertEquals(listOf("X", "Y"), threadLocalData.get())
112+
assertNotSame(originalData, threadLocalData.get())
113+
finish(2)
114+
}
115+
}
116+
}
117+
}
118+
119+
@Test
120+
fun testDataIsCopiedForCoroutine() = runTest {
121+
val root = MyMutableElement(ArrayList())
122+
val originalData = root.mutableData
123+
expect(1)
124+
launch(root) {
125+
assertNotSame(originalData, threadLocalData.get())
126+
finish(2)
127+
}
128+
}
129+
130+
@Test
131+
fun testDataIsCopiedThroughFlowOnUndispatched() = runTest {
132+
expect(1)
133+
val root = MyMutableElement(ArrayList())
134+
val originalData = root.mutableData
135+
flow {
136+
assertNotSame(originalData, threadLocalData.get())
137+
emit(1)
138+
}
139+
.flowOn(root)
140+
.single()
141+
finish(2)
142+
}
143+
144+
@Test
145+
fun testDataIsCopiedThroughFlowOnDispatched() = runTest {
146+
expect(1)
147+
val root = MyMutableElement(ArrayList())
148+
val originalData = root.mutableData
149+
flow {
150+
assertNotSame(originalData, threadLocalData.get())
151+
emit(1)
152+
}
153+
.flowOn(root + Dispatchers.Default)
154+
.single()
155+
finish(2)
156+
}
157+
}

Diff for: kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt

+1-10
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,7 @@ import kotlin.coroutines.*
66
internal actual val DefaultDelay: Delay
77
get() = Dispatchers.Default as Delay
88

9-
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
10-
val combined = coroutineContext + context
11-
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
12-
combined + Dispatchers.Default else combined
13-
}
14-
15-
public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
16-
return this + addedContext
17-
}
18-
199
// No debugging facilities on Wasm and JS
2010
internal actual fun Continuation<*>.toDebugString(): String = toString()
2111
internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on Wasm and JS
12+
internal actual fun wrapContextWithDebug(context: CoroutineContext): CoroutineContext = context

0 commit comments

Comments
 (0)