Skip to content

Remove reference counters in the concurrent doubly-linked list used in BufferedChannel and Semaphore #4302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: develop
Choose a base branch
from
7 changes: 5 additions & 2 deletions benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkBenchmark.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import kotlin.coroutines.*
@State(Scope.Benchmark)
@Fork(1)
open class ChannelSinkBenchmark {
@Param("${Channel.RENDEZVOUS}", "${Channel.BUFFERED}")
var capacity: Int = 0

private val tl = ThreadLocal.withInitial({ 42 })
private val tl2 = ThreadLocal.withInitial({ 239 })

Expand Down Expand Up @@ -42,15 +45,15 @@ open class ChannelSinkBenchmark {
.fold(0) { a, b -> a + b }
}

private fun Channel.Factory.range(start: Int, count: Int, context: CoroutineContext) = GlobalScope.produce(context) {
private fun Channel.Factory.range(start: Int, count: Int, context: CoroutineContext) = GlobalScope.produce(context, capacity) {
for (i in start until (start + count))
send(i)
}

// Migrated from deprecated operators, are good only for stressing channels

private fun <E> ReceiveChannel<E>.filter(context: CoroutineContext = Dispatchers.Unconfined, predicate: suspend (E) -> Boolean): ReceiveChannel<E> =
GlobalScope.produce(context, onCompletion = { cancel() }) {
GlobalScope.produce(context, capacity, onCompletion = { cancel() }) {
for (e in this@filter) {
if (predicate(e)) send(e)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import kotlin.coroutines.*
@State(Scope.Benchmark)
@Fork(2)
open class ChannelSinkDepthBenchmark {
@Param("${Channel.RENDEZVOUS}", "${Channel.BUFFERED}")
var capacity: Int = 0

private val tl = ThreadLocal.withInitial({ 42 })

private val unconfinedOneElement = Dispatchers.Unconfined + tl.asContextElement()
Expand Down Expand Up @@ -45,7 +48,7 @@ open class ChannelSinkDepthBenchmark {
}

private fun Channel.Factory.range(start: Int, count: Int, context: CoroutineContext) =
GlobalScope.produce(context) {
GlobalScope.produce(context, capacity) {
for (i in start until (start + count))
send(i)
}
Expand All @@ -57,7 +60,7 @@ open class ChannelSinkDepthBenchmark {
context: CoroutineContext = Dispatchers.Unconfined,
predicate: suspend (Int) -> Boolean
): ReceiveChannel<Int> =
GlobalScope.produce(context, onCompletion = { cancel() }) {
GlobalScope.produce(context, capacity, onCompletion = { cancel() }) {
deeplyNestedFilter(this, callTraceDepth, predicate)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import kotlin.coroutines.*
@State(Scope.Benchmark)
@Fork(1)
open class ChannelSinkNoAllocationsBenchmark {
@Param("${Channel.RENDEZVOUS}", "${Channel.BUFFERED}")
var capacity: Int = 0

private val unconfined = Dispatchers.Unconfined

@Benchmark
Expand All @@ -26,7 +29,7 @@ open class ChannelSinkNoAllocationsBenchmark {
return size
}

private fun Channel.Factory.range(context: CoroutineContext) = GlobalScope.produce(context) {
private fun Channel.Factory.range(context: CoroutineContext) = GlobalScope.produce(context, capacity) {
for (i in 0 until 100_000)
send(Unit) // no allocations
}
Expand Down
150 changes: 59 additions & 91 deletions kotlinx-coroutines-core/common/src/channels/BufferedChannel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,16 @@ internal open class BufferedChannel<E>(
private val receiveSegment: AtomicRef<ChannelSegment<E>>
private val bufferEndSegment: AtomicRef<ChannelSegment<E>>

/*
These values are used in [ChannelSegment.isLeftmostOrProcessed].
They help to detect when the `prev` reference of the segment should be cleaned.
*/
internal val sendSegmentId: Long get() = sendSegment.value.id
internal val receiveSegmentId: Long get() = receiveSegment.value.id

init {
@Suppress("LeakingThis")
val firstSegment = ChannelSegment(id = 0, prev = null, channel = this, pointers = 3)
val firstSegment = ChannelSegment(id = 0, prev = null, channel = this)
sendSegment = atomic(firstSegment)
receiveSegment = atomic(firstSegment)
// If this channel is rendezvous or has unlimited capacity, the algorithm never
Expand Down Expand Up @@ -143,9 +150,9 @@ internal open class BufferedChannel<E>(
the segment and the index in it. */
segment: ChannelSegment<E>,
index: Int,
/** The element to be inserted. */
/* The element to be inserted. */
element: E,
/** The global index of the cell. */
/* The global index of the cell. */
s: Long
) = suspendCancellableCoroutineReusable sc@{ cont ->
sendImplOnNoWaiter( // <-- this is an inline function
Expand Down Expand Up @@ -299,16 +306,8 @@ internal open class BufferedChannel<E>(
// the channel is already closed, storing a waiter is illegal, so
// the algorithm stores the `INTERRUPTED_SEND` token in this case.
when (updateCellSend(segment, i, element, s, waiter, closed)) {
RESULT_RENDEZVOUS -> {
// A rendezvous with a receiver has happened.
// The previous segments are no longer needed
// for the upcoming requests, so the algorithm
// resets the link to the previous segment.
segment.cleanPrev()
return onRendezvousOrBuffered()
}
RESULT_BUFFERED -> {
// The element has been buffered.
RESULT_RENDEZVOUS, RESULT_BUFFERED -> {
// The element has been buffered or a rendezvous with a receiver has happened.
return onRendezvousOrBuffered()
}
RESULT_SUSPEND -> {
Expand All @@ -325,17 +324,11 @@ internal open class BufferedChannel<E>(
}
RESULT_CLOSED -> {
// This channel is closed.
// In case this segment is already or going to be
// processed by a receiver, ensure that all the
// previous segments are unreachable.
if (s < receiversCounter) segment.cleanPrev()
return onClosed()
}
RESULT_FAILED -> {
// Either the cell stores an interrupted receiver,
// or it was poisoned by a concurrent receiver.
// In both cases, all the previous segments are already processed,
segment.cleanPrev()
continue
}
RESULT_SUSPEND_NO_WAITER -> {
Expand Down Expand Up @@ -392,22 +385,16 @@ internal open class BufferedChannel<E>(
// restarting the operation from the beginning on failure.
// Check the `sendImpl(..)` function for the comments.
when (updateCellSend(segment, index, element, s, waiter, false)) {
RESULT_RENDEZVOUS -> {
segment.cleanPrev()
onRendezvousOrBuffered()
}
RESULT_BUFFERED -> {
RESULT_RENDEZVOUS, RESULT_BUFFERED -> {
onRendezvousOrBuffered()
}
RESULT_SUSPEND -> {
waiter.prepareSenderForSuspension(segment, index)
}
RESULT_CLOSED -> {
if (s < receiversCounter) segment.cleanPrev()
onClosed()
}
RESULT_FAILED -> {
segment.cleanPrev()
sendImpl(
element = element,
waiter = waiter,
Expand Down Expand Up @@ -857,14 +844,9 @@ internal open class BufferedChannel<E>(
when {
updCellResult === FAILED -> {
// The cell is poisoned; restart from the beginning.
// To avoid memory leaks, we also need to reset
// the `prev` pointer of the working segment.
if (r < sendersCounter) segment.cleanPrev()
}
else -> { // element
// A buffered element was retrieved from the cell.
// Clean the reference to the previous segment.
segment.cleanPrev()
@Suppress("UNCHECKED_CAST")
onUndeliveredElement?.callUndeliveredElementCatchingException(updCellResult as E)?.let { throw it }
}
Expand Down Expand Up @@ -938,9 +920,6 @@ internal open class BufferedChannel<E>(
// but failed: either the opposite request has
// already been cancelled or the cell is poisoned.
// Restart from the beginning in this case.
// To avoid memory leaks, we also need to reset
// the `prev` pointer of the working segment.
if (r < sendersCounter) segment.cleanPrev()
continue
}
updCellResult === SUSPEND_NO_WAITER -> {
Expand All @@ -951,8 +930,6 @@ internal open class BufferedChannel<E>(
else -> { // element
// Either a buffered element was retrieved from the cell
// or a rendezvous with a waiting sender has happened.
// Clean the reference to the previous segment before finishing.
segment.cleanPrev()
@Suppress("UNCHECKED_CAST")
onElementRetrieved(updCellResult as E)
}
Expand Down Expand Up @@ -987,7 +964,6 @@ internal open class BufferedChannel<E>(
waiter.prepareReceiverForSuspension(segment, index)
}
updCellResult === FAILED -> {
if (r < sendersCounter) segment.cleanPrev()
receiveImpl(
waiter = waiter,
onElementRetrieved = onElementRetrieved,
Expand All @@ -996,7 +972,6 @@ internal open class BufferedChannel<E>(
)
}
else -> {
segment.cleanPrev()
@Suppress("UNCHECKED_CAST")
onElementRetrieved(updCellResult as E)
}
Expand Down Expand Up @@ -2299,7 +2274,6 @@ internal open class BufferedChannel<E>(
// Otherwise, if the required segment is removed, the operation restarts.
if (receiveSegment.value.id < id) return false else continue
}
segment.cleanPrev() // all the previous segments are no longer needed.
// Does the `r`-th cell contain waiting sender or buffered element?
val i = (r % SEGMENT_SIZE).toInt()
if (isCellNonEmpty(segment, i, r)) return true
Expand Down Expand Up @@ -2398,12 +2372,6 @@ internal open class BufferedChannel<E>(
// This channel is already closed or cancelled; help to complete
// the closing or cancellation procedure.
completeCloseOrCancel()
// Clean the `prev` reference of the provided segment
// if all the previous cells are already covered by senders.
// It is important to clean the `prev` reference only in
// this case, as the closing/cancellation procedure may
// need correct value to traverse the linked list from right to left.
if (startFrom.id * SEGMENT_SIZE < receiversCounter) startFrom.cleanPrev()
// As the required segment is not found and cannot be allocated, return `null`.
null
} else {
Expand All @@ -2415,12 +2383,6 @@ internal open class BufferedChannel<E>(
// segment with `id` not lower than the required one.
// Skip the sequence of removed cells in O(1).
updateSendersCounterIfLower(segment.id * SEGMENT_SIZE)
// Clean the `prev` reference of the provided segment
// if all the previous cells are already covered by senders.
// It is important to clean the `prev` reference only in
// this case, as the closing/cancellation procedure may
// need correct value to traverse the linked list from right to left.
if (segment.id * SEGMENT_SIZE < receiversCounter) segment.cleanPrev()
// As the required segment is not found and cannot be allocated, return `null`.
null
} else {
Expand Down Expand Up @@ -2453,33 +2415,21 @@ internal open class BufferedChannel<E>(
// This channel is already closed or cancelled; help to complete
// the closing or cancellation procedure.
completeCloseOrCancel()
// Clean the `prev` reference of the provided segment
// if all the previous cells are already covered by senders.
// It is important to clean the `prev` reference only in
// this case, as the closing/cancellation procedure may
// need correct value to traverse the linked list from right to left.
if (startFrom.id * SEGMENT_SIZE < sendersCounter) startFrom.cleanPrev()
// As the required segment is not found and cannot be allocated, return `null`.
null
} else {
// Get the found segment.
val segment = it.segment
// Advance the `bufferEnd` segment if required.
if (!isRendezvousOrUnlimited && id <= bufferEndCounter / SEGMENT_SIZE) {
bufferEndSegment.moveForward(segment)
moveSegmentBufferEndToSpecifiedOrLast(id, bufferEndSegment.value)
}
// Is the required segment removed?
if (segment.id > id) {
// The required segment has been removed; `segment` is the first
// segment with `id` not lower than the required one.
// Skip the sequence of removed cells in O(1).
updateReceiversCounterIfLower(segment.id * SEGMENT_SIZE)
// Clean the `prev` reference of the provided segment
// if all the previous cells are already covered by senders.
// It is important to clean the `prev` reference only in
// this case, as the closing/cancellation procedure may
// need correct value to traverse the linked list from right to left.
if (segment.id * SEGMENT_SIZE < sendersCounter) segment.cleanPrev()
// As the required segment is already removed, return `null`.
null
} else {
Expand Down Expand Up @@ -2535,30 +2485,22 @@ internal open class BufferedChannel<E>(
}

/**
* Updates [bufferEndSegment] to the one with the specified [id] or
* to the last existing segment, if the required segment is not yet created.
*
* Unlike [findSegmentBufferEnd], this function does not allocate new segments.
* Serves as a wrapper for an inline function [AtomicRef.moveToSpecifiedOrLast]
*/
private fun moveSegmentBufferEndToSpecifiedOrLast(id: Long, startFrom: ChannelSegment<E>) {
// Start searching the required segment from the specified one.
var segment: ChannelSegment<E> = startFrom
while (segment.id < id) {
segment = segment.next ?: break
}
// Skip all removed segments and try to update `bufferEndSegment`
// to the first non-removed one. This part should succeed eventually,
// as the tail segment is never removed.
while (true) {
while (segment.isRemoved) {
segment = segment.next ?: break
}
// Try to update `bufferEndSegment`. On failure,
// the found segment is already removed, so it
// should be skipped.
if (bufferEndSegment.moveForward(segment)) return
}
}
private fun moveSegmentBufferEndToSpecifiedOrLast(id: Long, startFrom: ChannelSegment<E>) =
bufferEndSegment.moveToSpecifiedOrLast(id, startFrom)

/**
* Serves as a wrapper for an inline function [AtomicRef.moveToSpecifiedOrLast]
*/
private fun moveSegmentReceiveToSpecifiedOrLast(id: Long, startFrom: ChannelSegment<E>) =
receiveSegment.moveToSpecifiedOrLast(id, startFrom)

/**
* Serves as a wrapper for an inline function [AtomicRef.moveToSpecifiedOrLast]
*/
private fun moveSegmentSendToSpecifiedOrLast(id: Long, startFrom: ChannelSegment<E>) =
sendSegment.moveToSpecifiedOrLast(id, startFrom)

/**
* Updates the `senders` counter if its value
Expand Down Expand Up @@ -2588,6 +2530,17 @@ internal open class BufferedChannel<E>(
if (receivers.compareAndSet(cur, value)) return
}

/**
This method is used in the physical removal of the segment. It helps to move pointers forward from
the segment which was physically removed.
*/
internal fun movePointersForwardFromRemovedSegment(removedSegment: ChannelSegment<E>) {
if (!removedSegment.isRemoved) return
if (removedSegment == sendSegment.value) moveSegmentSendToSpecifiedOrLast(removedSegment.id, removedSegment)
if (removedSegment == receiveSegment.value) moveSegmentReceiveToSpecifiedOrLast(removedSegment.id, removedSegment)
if (removedSegment == bufferEndSegment.value) moveSegmentBufferEndToSpecifiedOrLast(removedSegment.id, removedSegment)
}

// ###################
// # Debug Functions #
// ###################
Expand Down Expand Up @@ -2799,7 +2752,7 @@ internal open class BufferedChannel<E>(
* to update [BufferedChannel.sendSegment], [BufferedChannel.receiveSegment],
* and [BufferedChannel.bufferEndSegment] correctly.
*/
internal class ChannelSegment<E>(id: Long, prev: ChannelSegment<E>?, channel: BufferedChannel<E>?, pointers: Int) : Segment<ChannelSegment<E>>(id, prev, pointers) {
internal class ChannelSegment<E>(id: Long, prev: ChannelSegment<E>?, channel: BufferedChannel<E>?) : Segment<ChannelSegment<E>>(id, prev) {
private val _channel: BufferedChannel<E>? = channel
val channel get() = _channel!! // always non-null except for `NULL_SEGMENT`

Expand Down Expand Up @@ -2841,6 +2794,22 @@ internal class ChannelSegment<E>(id: Long, prev: ChannelSegment<E>?, channel: Bu

internal fun getAndSetState(index: Int, update: Any?) = data[index * 2 + 1].getAndSet(update)

/**
* Shows if all segments going before this segment have been processed.
* When the value is true, the [prev] reference of the segment should be `null`.
*/
override val isLeftmostOrProcessed: Boolean get() = id <= channel.sendSegmentId && id <= channel.receiveSegmentId

/**
* Removes the segment physically from the segment list.
*
* After the physical removal is finished and there are channel pointers referencing the removed segment,
* the [BufferedChannel.movePointersForwardFromRemovedSegment] method is invoked to move them further on the segment list.
*/
override fun remove() {
super.remove()
channel.movePointersForwardFromRemovedSegment(this)
}

// ########################
// # Cancellation Support #
Expand Down Expand Up @@ -2926,10 +2895,9 @@ internal fun <E> createSegmentFunction(): KFunction2<Long, ChannelSegment<E>, Ch
private fun <E> createSegment(id: Long, prev: ChannelSegment<E>) = ChannelSegment(
id = id,
prev = prev,
channel = prev.channel,
pointers = 0
channel = prev.channel
)
private val NULL_SEGMENT = ChannelSegment<Any?>(id = -1, prev = null, channel = null, pointers = 0)
private val NULL_SEGMENT = ChannelSegment<Any?>(id = -1, prev = null, channel = null)

/**
* Number of cells in each segment.
Expand Down
Loading