1
1
package kotlinx.coroutines
2
2
3
+ import kotlinx.coroutines.internal.*
4
+ import kotlinx.coroutines.internal.CoroutineStackFrame
5
+ import kotlinx.coroutines.internal.NO_THREAD_ELEMENTS
3
6
import kotlinx.coroutines.internal.ScopeCoroutine
7
+ import kotlinx.coroutines.internal.restoreThreadContext
8
+ import kotlinx.coroutines.internal.updateThreadContext
4
9
import kotlin.coroutines.*
10
+ import kotlin.jvm.*
5
11
6
12
@PublishedApi // Used from kotlinx-coroutines-test via suppress, not part of ABI
7
13
internal actual val DefaultDelay : Delay
@@ -18,14 +24,106 @@ public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineCo
18
24
}
19
25
20
26
// No debugging facilities on Wasm and JS
21
- internal actual inline fun <T > withCoroutineContext (context : CoroutineContext , countOrElement : Any? , block : () -> T ): T = block()
22
- internal actual inline fun <T > withContinuationContext (continuation : Continuation <* >, countOrElement : Any? , block : () -> T ): T = block()
27
+ internal actual inline fun <T > withCoroutineContext (context : CoroutineContext , countOrElement : Any? , block : () -> T ): T {
28
+ val oldValue = updateThreadContext(context, countOrElement)
29
+ try {
30
+ return block()
31
+ } finally {
32
+ restoreThreadContext(context, oldValue)
33
+ }
34
+ }
35
+
36
+ internal actual inline fun <T > withContinuationContext (continuation : Continuation <* >, countOrElement : Any? , block : () -> T ): T {
37
+ val context = continuation.context
38
+ val oldValue = updateThreadContext(context, countOrElement)
39
+ val undispatchedCompletion = if (oldValue != = NO_THREAD_ELEMENTS ) {
40
+ // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them
41
+ continuation.updateUndispatchedCompletion(context, oldValue)
42
+ } else {
43
+ null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context
44
+ }
45
+ try {
46
+ return block()
47
+ } finally {
48
+ if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) {
49
+ restoreThreadContext(context, oldValue)
50
+ }
51
+ }
52
+ }
53
+
54
+ internal fun Continuation <* >.updateUndispatchedCompletion (context : CoroutineContext , oldValue : Any? ): UndispatchedCoroutine <* >? {
55
+ if (this !is CoroutineStackFrame ) return null
56
+ /*
57
+ * Fast-path to detect whether we have undispatched coroutine at all in our stack.
58
+ *
59
+ * Implementation note.
60
+ * If we ever find that stackwalking for thread-locals is way too slow, here is another idea:
61
+ * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance
62
+ * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker`
63
+ * from the context when creating dispatched coroutine in `withContext`.
64
+ * Another option is to "unmark it" instead of removing to save an allocation.
65
+ * Both options should work, but it requires more careful studying of the performance
66
+ * and, mostly, maintainability impact.
67
+ */
68
+ val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker ] != = null
69
+ if (! potentiallyHasUndispatchedCoroutine) return null
70
+ val completion = undispatchedCompletion()
71
+ completion?.saveThreadContext(context, oldValue)
72
+ return completion
73
+ }
74
+
75
+ internal tailrec fun CoroutineStackFrame.undispatchedCompletion (): UndispatchedCoroutine <* >? {
76
+ // Find direct completion of this continuation
77
+ val completion: CoroutineStackFrame = when (this ) {
78
+ is DispatchedCoroutine <* > -> return null
79
+ else -> callerFrame ? : return null // something else -- not supported
80
+ }
81
+ if (completion is UndispatchedCoroutine <* >) return completion // found UndispatchedCoroutine!
82
+ return completion.undispatchedCompletion() // walk up the call stack with tail call
83
+ }
84
+
85
+ /* *
86
+ * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack.
87
+ * Used as a performance optimization to avoid stack walking where it is not necessary.
88
+ */
89
+ private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
90
+ override val key: CoroutineContext .Key <* >
91
+ get() = this
92
+ }
93
+
23
94
internal actual fun Continuation <* >.toDebugString (): String = toString()
24
95
internal actual val CoroutineContext .coroutineName: String? get() = null // not supported on Wasm and JS
25
96
26
97
internal actual class UndispatchedCoroutine <in T > actual constructor(
27
98
context : CoroutineContext ,
28
99
uCont : Continuation <T >
29
100
) : ScopeCoroutine<T>(context, uCont) {
30
- override fun afterResume (state : Any? ) = uCont.resumeWith(recoverResult(state, uCont))
101
+
102
+ private var savedContext: CoroutineContext ? = null
103
+ private var savedOldValue: Any? = null
104
+
105
+ fun saveThreadContext (context : CoroutineContext , oldValue : Any? ) {
106
+ savedContext = context
107
+ savedOldValue = oldValue
108
+ }
109
+
110
+ fun clearThreadContext (): Boolean {
111
+ if (savedContext == null ) return false
112
+ savedContext = null
113
+ savedOldValue = null
114
+ return true
115
+ }
116
+
117
+ override fun afterResume (state : Any? ) {
118
+ savedContext?.let { context ->
119
+ restoreThreadContext(context, savedOldValue)
120
+ savedContext = null
121
+ savedOldValue = null
122
+ }
123
+ // resume undispatched -- update context but stay on the same dispatcher
124
+ val result = recoverResult(state, uCont)
125
+ withContinuationContext(uCont, null ) {
126
+ uCont.resumeWith(result)
127
+ }
128
+ }
31
129
}
0 commit comments