
ChannelFlow with unfused flowOn does not update/restore ThreadContextElement after collect #4403

Open
@Zincfox

Description


Describe the bug

When collecting a channelFlow to which a ThreadContextElement was applied via flowOn, with fusion prevented, the collector's own ThreadContextElement is not properly restored after collect returns. Until the next suspension point, the thread-local keeps the value set by the flowOn element.

The specific conditions appear to be:

  • Use of channelFlow {} (a plain flow {} works as expected)
  • Mapping the ChannelFlow to prevent fusion (with fusion it works as expected)
  • Applying flowOn to the resulting mapped flow
  • Reading the thread-local after collect {} returns but before the next suspension point (adding a yield() before the check makes the value correct again, since the context is updated during suspension)
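The control cases named in the conditions above can be sketched as follows. This is a minimal sketch assuming kotlinx-coroutines-core 1.10.1 on the JVM; `controlThreadLocal`, `observedAfterPlainFlow` and `observedAfterFusedChannelFlow` are hypothetical names. Each function returns the thread-local value read directly after collect returns; according to this report, both should return "1", unlike the unfused channelFlow case in the reproducers below.

```kotlin
import kotlinx.coroutines.asContextElement
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.runBlocking

// Hypothetical thread-local used only for these control cases.
val controlThreadLocal = ThreadLocal<String?>()

// Control case 1: plain flow {} instead of channelFlow {}, same unfused map + flowOn shape.
fun observedAfterPlainFlow(): String? =
    runBlocking(controlThreadLocal.asContextElement(value = "1")) {
        flow<Int> { emit(1) }
            .map { it } // same shape as the reproducer
            .flowOn(controlThreadLocal.asContextElement(value = "2"))
            .collect { }
        // Value seen right after collect, before any further suspension.
        controlThreadLocal.get()
    }

// Control case 2: channelFlow {} without map, so flowOn is fused into the channelFlow.
fun observedAfterFusedChannelFlow(): String? =
    runBlocking(controlThreadLocal.asContextElement(value = "1")) {
        channelFlow<Int> { send(1) }
            .flowOn(controlThreadLocal.asContextElement(value = "2"))
            .collect { }
        controlThreadLocal.get()
    }
```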

I could not find a matching issue here apart from possibly #4121, which I admittedly did not fully understand. However, I believe this is a different issue with similar symptoms (possibly ChannelFlow-specific), because a ThreadContextElement is present in the parent CoroutineContext, and the problem does not occur when using a normal Flow instead of a ChannelFlow, when allowing fusion, or when adding a yield() call before observing the result.

Version

kotlinx-coroutines:1.10.1
kotlinx-coroutines-test:1.10.1

build.gradle.kts:

plugins {
    kotlin("jvm") version "2.1.10"
}
//[...]
dependencies {
//[...]
    implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.10.1")
    testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.10.1")
//[...]
}
tasks.test {
    useJUnitPlatform()
}
kotlin {
    jvmToolchain(17)
}

Provide a Reproducer

Minimal reproducer (ThreadLocalElement)
import kotlinx.coroutines.asContextElement
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.yield
import kotlin.test.Test
import kotlin.test.assertEquals

class CoroutineThreadLocalTest {

    companion object {

        val threadLocal = ThreadLocal<String?>()
    }

    @Test
    fun testChannelFlowThreadLocal_minimal() = runBlocking(threadLocal.asContextElement(value = "1")) {
        channelFlow<Int> {
            send(1)
        }.map { //.map prevents fusion of flowOn into channelFlow
            it
        }.flowOn(
            threadLocal.asContextElement(value="2")
        ).collect {  }

        //yield() //adding yield() here makes this test pass
        //!!fails!!
        assertEquals("1", threadLocal.get(), "Unexpected threadlocal value after channelFlow collection")
    }
}
Extended reproducer (ThreadLocalElement)
import kotlinx.coroutines.asContextElement
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.yield
import kotlin.test.Test
import kotlin.test.assertEquals

class CoroutineThreadLocalTest {

    companion object {

        val threadLocal = ThreadLocal<String?>()
    }

    @Test
    fun testChannelFlowThreadLocal() = runBlocking(threadLocal.asContextElement(value = "1")) {
        //passes
        assertEquals("1", threadLocal.get(), "Unexpected initial threadlocal value")

        val constructedFlow = channelFlow<Int> {
            //passes
            assertEquals("2", threadLocal.get(), "Unexpected threadlocal value inside channelFlow before send")
            send(1)
            //passes
            assertEquals("2", threadLocal.get(), "Unexpected threadlocal value inside channelFlow after send")
        }

        //passes
        assertEquals("1", threadLocal.get(), "Unexpected threadlocal value after channelFlow construction")

        val transformedFlow = constructedFlow.map {
            //passes
            assertEquals("2", threadLocal.get(), "Unexpected threadlocal value inside map")
            it
        } //.map prevents fusion of flowOn into channelFlow

        //passes
        assertEquals("1", threadLocal.get(), "Unexpected threadlocal value after channelFlow mapping")

        val flowWithThreadLocal = transformedFlow.flowOn(threadLocal.asContextElement(value = "2"))

        //passes
        assertEquals("1", threadLocal.get(), "Unexpected threadlocal value after channelFlow threadlocal-binding")

        flowWithThreadLocal.collect {
            //passes
            assertEquals("1", threadLocal.get(), "Unexpected threadlocal value inside collect")
        }
        //yield() //adding yield here makes this test pass
        //!!fails!!
        assertEquals("1", threadLocal.get(), "Unexpected threadlocal value after channelFlow collection")
    }
}
Instrumented reproducer
import kotlinx.coroutines.ThreadContextElement
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.yield
import kotlin.coroutines.CoroutineContext
import kotlin.test.Test
import kotlin.test.assertEquals

class InstrumentedCollectChannelFlowTest {
    class ThreadContextElementInstrumentation(
        val value: String,
        val threadLocal: ThreadLocal<String?> = ThreadLocal(),
    ) : ThreadContextElement<String?> {

        companion object {
            val KEY = object : CoroutineContext.Key<ThreadContextElementInstrumentation> {}

            var indentationLevel = 0

            fun printIndented(text: String) {
                println("  ".repeat(indentationLevel)+text)
            }
        }

        override val key: CoroutineContext.Key<*> = KEY

        override fun updateThreadContext(context: CoroutineContext): String? {
            val oldValue = threadLocal.get()
            printIndented("Updating thread context: $oldValue => {$value")
            threadLocal.set(value)
            indentationLevel++
            return oldValue
        }

        override fun restoreThreadContext(context: CoroutineContext, oldState: String?) {
            indentationLevel--
            printIndented("Restoring thread context: $value} => $oldState")
            threadLocal.set(oldState)
        }
    }

    companion object {

        val threadLocal = ThreadLocal<String?>()
    }

    @Test
    fun testChannelFlowThreadLocalInstrumented() = runBlocking(ThreadContextElementInstrumentation(
        value="1",
        threadLocal = threadLocal,
    )) {
        ThreadContextElementInstrumentation.printIndented("toplevel threadlocal value check initial (==1)")
        //passes
        assertEquals("1", threadLocal.get(), "Unexpected initial threadlocal value")

        val constructedFlow = channelFlow<Int> {
            ThreadContextElementInstrumentation.printIndented("channelFlow threadlocal value check before send (==2)")
            //passes
            assertEquals("2", threadLocal.get(), "Unexpected threadlocal value inside channelFlow before send")
            send(1)
            ThreadContextElementInstrumentation.printIndented("channelFlow threadlocal value check after send (==2)")
            //passes
            assertEquals("2", threadLocal.get(), "Unexpected threadlocal value inside channelFlow after send")
        }

        ThreadContextElementInstrumentation.printIndented("toplevel threadlocal value check after channelFlow construction (==1)")
        //passes
        assertEquals("1", threadLocal.get(), "Unexpected threadlocal value after channelFlow construction")

        val transformedFlow = constructedFlow.map {
            ThreadContextElementInstrumentation.printIndented("channelFlow map threadlocal value check (==2)")
            //passes
            assertEquals("2", threadLocal.get(), "Unexpected threadlocal value inside map")
            it
        } //.map prevents fusion of flowOn into channelFlow

        ThreadContextElementInstrumentation.printIndented("toplevel threadlocal value check after mapping (==1)")
        //passes
        assertEquals("1", threadLocal.get(), "Unexpected threadlocal value after channelFlow mapping")

        val flowWithThreadLocal = transformedFlow.flowOn(ThreadContextElementInstrumentation(
            value="2",
            threadLocal = threadLocal,
        ))

        ThreadContextElementInstrumentation.printIndented("toplevel threadlocal value check after flowOn (==1)")
        //passes
        assertEquals("1", threadLocal.get(), "Unexpected threadlocal value after channelFlow threadlocal-binding")

        flowWithThreadLocal.collect {
            ThreadContextElementInstrumentation.printIndented("channelFlow collect threadlocal value check (==1)")
            //passes
            assertEquals("1", threadLocal.get(), "Unexpected threadlocal value inside collect")
        }

        //ThreadContextElementInstrumentation.printIndented("toplevel threadlocal before yield")
        //yield() //adding yield here makes this test pass

        ThreadContextElementInstrumentation.printIndented("toplevel final threadlocal value check (==1)")
        //!!fails!!
        assertEquals("1", threadLocal.get(), "Unexpected threadlocal value after channelFlow collection")
    }
}

Results in:

Updating thread context: null => {1
  toplevel threadlocal value check initial (==1)
  toplevel threadlocal value check after channelFlow construction (==1)
  toplevel threadlocal value check after mapping (==1)
  toplevel threadlocal value check after flowOn (==1)
  Updating thread context: 1 => {2
  Restoring thread context: 2} => 1
Restoring thread context: 1} => null
Updating thread context: null => {2
  channelFlow threadlocal value check before send (==2)
  channelFlow threadlocal value check after send (==2)
Restoring thread context: 2} => null
Updating thread context: null => {2
  channelFlow map threadlocal value check (==2)
  Updating thread context: 2 => {1
    channelFlow collect threadlocal value check (==1)
  Restoring thread context: 1} => 2
//<<<< without yield
  toplevel final threadlocal value check (==1)
Restoring thread context: 2} => null
//====
  toplevel threadlocal before yield
Restoring thread context: 2} => null
Updating thread context: null => {1
  toplevel final threadlocal value check (==1)
Restoring thread context: 1} => null
// with yield >>>>
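Since adding yield() after collect makes the tests pass, a possible stopgap is to suspend once right after collecting. This is only a sketch based on the with-yield trace above, observed with runBlocking; `collectThenYield`, `workaroundThreadLocal` and `observedAfterCollectThenYield` are hypothetical names, not kotlinx-coroutines API, and the helper does not fix the underlying restore problem.

```kotlin
import kotlinx.coroutines.asContextElement
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.yield

// Hypothetical helper (not part of kotlinx-coroutines): collect the flow, then suspend
// once via yield(). In the report's trace, resuming after yield() re-applies the
// caller's own ThreadContextElements before the caller continues.
suspend fun <T> Flow<T>.collectThenYield(action: suspend (T) -> Unit) {
    collect { value -> action(value) }
    yield()
}

// Hypothetical thread-local for this sketch.
val workaroundThreadLocal = ThreadLocal<String?>()

// Same shape as the minimal reproducer, but collected through the helper above.
fun observedAfterCollectThenYield(): String? =
    runBlocking(workaroundThreadLocal.asContextElement(value = "1")) {
        channelFlow<Int> { send(1) }
            .map { it } // prevents fusion of flowOn into channelFlow
            .flowOn(workaroundThreadLocal.asContextElement(value = "2"))
            .collectThenYield { }
        workaroundThreadLocal.get()
    }
```

The yield() appears to help only because the caller is re-dispatched, and that dispatch re-applies the caller's context elements; whether this holds for other dispatchers (for example Dispatchers.Unconfined) has not been checked.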

Issue origin

  • Originally discovered while instrumenting a private project with OpenTelemetry traces, at a call to Flow<T>.toSet(): the Flow was constructed by a private HTTP request library, and the context leaked out of the toSet() collection point.
  • Added tests to the private HTTP request library (which targets common Kotlin); the issue did not occur when targeting JS, as that platform uses a custom CoroutineInterceptor instead of the KotlinContextElement (based on ThreadContextElement) from opentelemetry-java.
  • Reported to opentelemetry-java as "Kotlin extension: Collecting ChannelFlow can result in mismatching Contexts" (see there also for a reproducer using their library).
  • Further analysis and instrumentation showed that the ThreadContextElement itself is apparently not correctly restored/updated after collection, which leaves either a user error (mine or opentelemetry-java's) or a bug in kotlinx-coroutines as the cause.
  • This seems distinct from "ThreadLocal.asContextElement may not be cleaned up when used with Dispatchers.Main.immediate" #4121: the setup here includes a ThreadContextElement set as the parent across the whole test, and that element does update/restore correctly once the yield() call is added. Does the statement "accessing thread-local from coroutine without the corresponding context element returns undefined value" still apply here?
  • Hence this new issue.
