Skip to content

Ability to specify the number of permits to acquire and release #1553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
269 changes: 190 additions & 79 deletions kotlinx-coroutines-core/common/src/sync/Semaphore.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public interface Semaphore {
public val availablePermits: Int

/**
* Acquires a permit from this semaphore, suspending until one is available.
* Acquires the given number of permits from this semaphore, suspending until ones are available.
* All suspending acquirers are processed in first-in-first-out (FIFO) order.
*
* This suspending function is cancellable. If the [Job] of the current coroutine is cancelled or completed while this
Expand All @@ -36,23 +36,34 @@ public interface Semaphore {
* Use [CoroutineScope.isActive] or [CoroutineScope.ensureActive] to periodically
* check for cancellation in tight loops if needed.
*
* Use [tryAcquire] to try acquire a permit of this semaphore without suspension.
* Use [tryAcquire] to try acquire the given number of permits of this semaphore without suspension.
*
* @param permits the number of permits to acquire
*
* @throws [IllegalArgumentException] if [permits] is less than or equal to zero.
*/
public suspend fun acquire()
public suspend fun acquire(permits: Int = 1)

/**
* Tries to acquire a permit from this semaphore without suspension.
* Tries to acquire the given number of permits from this semaphore without suspension.
*
* @param permits the number of permits to acquire
* @return `true` if all permits were acquired, `false` otherwise.
*
* @return `true` if a permit was acquired, `false` otherwise.
* @throws [IllegalArgumentException] if [permits] is less than or equal to zero.
*/
public fun tryAcquire(): Boolean
public fun tryAcquire(permits: Int = 1): Boolean

/**
* Releases a permit, returning it into this semaphore. Resumes the first
* suspending acquirer if there is one at the point of invocation.
* Throws [IllegalStateException] if the number of [release] invocations is greater than the number of preceding [acquire].
* Releases the given number of permits, returning them into this semaphore. Resumes the first
* suspending acquirer if there is one at the point of invocation and the requested number of permits is available.
*
* @param permits the number of permits to release
*
* @throws [IllegalArgumentException] if [permits] is less than or equal to zero.
* @throws [IllegalStateException] if the number of [release] invocations is greater than the number of preceding [acquire].
*/
public fun release()
public fun release(permits: Int = 1)
}

/**
Expand Down Expand Up @@ -96,8 +107,8 @@ private class SemaphoreImpl(
* and the maximum number of waiting acquirers cannot be greater than 2^31 in any
* real application.
*/
private val _availablePermits = atomic(permits - acquiredPermits)
override val availablePermits: Int get() = max(_availablePermits.value, 0)
private val permitsBalance = atomic(permits - acquiredPermits)
override val availablePermits: Int get() = max(permitsBalance.value, 0)

// The queue of waiting acquirers is essentially an infinite array based on `SegmentQueue`;
// each segment contains a fixed number of slots. To determine a slot for each enqueue
Expand All @@ -107,105 +118,205 @@ private class SemaphoreImpl(
private val enqIdx = atomic(0L)
private val deqIdx = atomic(0L)

override fun tryAcquire(): Boolean {
_availablePermits.loop { p ->
if (p <= 0) return false
if (_availablePermits.compareAndSet(p, p - 1)) return true
/**
* The remaining permits from release operations, which could not be spent, because the next slot was not defined
*/
internal val accumulator = atomic(0)

override fun tryAcquire(permits: Int): Boolean {
require(permits > 0) { "The number of acquired permits must be greater than 0" }
permitsBalance.loop { p ->
if (p < permits) return false
if (permitsBalance.compareAndSet(p, p - permits)) return true
}
}

override suspend fun acquire() {
val p = _availablePermits.getAndDecrement()
if (p > 0) return // permit acquired
addToQueueAndSuspend()
override suspend fun acquire(permits: Int) {
require(permits > 0) { "The number of acquired permits must be greater than 0" }
val p = permitsBalance.getAndAdd(-permits)
if (p >= permits) return // permits are acquired
tryToAddToQueue(permits)
}

override fun release() {
val p = incPermits()
override fun release(permits: Int) {
require(permits > 0) { "The number of released permits must be greater than 0" }
val p = incPermits(permits)
if (p >= 0) return // no waiters
resumeNextFromQueue()
tryToResumeFromQueue(permits)
}

fun incPermits() = _availablePermits.getAndUpdate { cur ->
check(cur < permits) { "The number of released permits cannot be greater than $permits" }
cur + 1
internal fun incPermits(delta: Int = 1) = permitsBalance.getAndUpdate { cur ->
assert { delta >= 1 }
check(cur + delta <= permits) { "The number of released permits cannot be greater than $permits" }
cur + delta
}

private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutine<Unit> sc@ { cont ->
private suspend fun tryToAddToQueue(permits: Int) = suspendAtomicCancellableCoroutine<Unit> sc@{ cont ->
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how permits is used by this method. I'm very surprised that tests pass. Seems like some tests are missing.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous commit didn't contain significant changes, sorry. Please, recheck.

val last = this.tail
val enqIdx = enqIdx.getAndIncrement()
val segment = getSegment(last, enqIdx / SEGMENT_SIZE)
val i = (enqIdx % SEGMENT_SIZE).toInt()
if (segment === null || segment.get(i) === RESUMED || !segment.cas(i, null, cont)) {
// already resumed
val enqueueId = enqIdx.getAndIncrement()
val segmentId = enqueueId / SEGMENT_SIZE
val segment = getSegment(last, segmentId)
if (segment == null) {
// The segment is already removed
// Probably, this is the unreachable case
cont.resume(Unit)
return@sc
} else {
val slotId = (enqueueId % SEGMENT_SIZE).toInt()
val prevSlot = segment.slots[slotId].getAndSet(Slot(State.SUSPEND, permits, cont))
// The assertion is true, cause [RESUMED] can be set up only after [SUSPEND]
// and [CANCELLED] can be set up only in the handler, which will be added next
assert { prevSlot == null }
cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(this, segment, slotId, permits).asHandler)
}
cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(this, segment, i).asHandler)
// Help to resume slots, if accumulator has permits
tryToResumeFromQueue(0)
}

@Suppress("UNCHECKED_CAST")
internal fun resumeNextFromQueue() {
try_again@while (true) {
val first = this.head
val deqIdx = deqIdx.getAndIncrement()
val segment = getSegmentAndMoveHead(first, deqIdx / SEGMENT_SIZE) ?: continue@try_again
val i = (deqIdx % SEGMENT_SIZE).toInt()
val cont = segment.getAndSet(i, RESUMED)
if (cont === null) return // just resumed
if (cont === CANCELLED) continue@try_again
(cont as CancellableContinuation<Unit>).resume(Unit)
internal fun tryToResumeFromQueue(permits: Int) {
accumulator.getAndAdd(permits) // add thread permits to common accumulator
var remain = accumulator.getAndSet(0) // try to take possession of all the accumulated permits at the moment
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^ The pair of getAndAdd and getAndAdd is not efficient. Should be replaced with one atomic CAS-loop.

if (remain == 0) {
// another thread stole permits
return
}
try_again@ while (true) {
val first = this.head
val dequeueId = deqIdx.value
val segmentId = dequeueId / SEGMENT_SIZE
val segment = getSegmentAndMoveHead(first, segmentId)
if (segment == null) {
// The segment is already removed
// Try to help to increment [deqIdx] once, because multiple threads can increment the [deqIdx] in parallel otherwise
deqIdx.compareAndSet(dequeueId, dequeueId + 1)
continue@try_again
}
val slotId = (dequeueId % SEGMENT_SIZE).toInt()
val slot = segment.slots[slotId].value
if (slot == null) {
// If the slot is not defined yet we can't spent permits for it, so return [remain] to [accumulator]
accumulator.addAndGet(remain)
return
}
if (slot.state == State.CANCELLED) {
// The slot was cancelled in the another thread
// Try to help to increment [deqIdx] once, because multiple threads can increment the [deqIdx] in parallel otherwise
if (deqIdx.compareAndSet(dequeueId, dequeueId + 1)) {
removeSegmentIfNeeded(segment, dequeueId + 1)
}
continue@try_again
}
if (slot.state == State.RESUMED) {
assert { slot.permits == 0 }
// The slot was updated in the another thread
// The another thread was supposed to increment [deqIdx]
continue@try_again
}
val diff = min(slot.permits, remain) // How many permits we can spent for the slot at most
val newPermits = slot.permits - diff
val newState = if (newPermits == 0) State.RESUMED else slot.state
val newSlot = Slot(newState, newPermits, slot.cont)
if (!segment.slots[slotId].compareAndSet(slot, newSlot)) {
// The slot was updated in another thread, let's try again
continue
}
// Here we successfully updated the slot
remain -= diff // remove spent permits
if (newState == State.RESUMED) {
slot.cont.resume(Unit)
removeSegmentIfNeeded(segment, deqIdx.incrementAndGet())
}
if (remain == 0) {
// We spent all available permits, so let's finish
return
}
// We still have permits, so we continue to spent them
}
}

/**
* Remove the segment if needed. The method checks, that all segment's slots were processed
*
* @param segment the segment to validation
* @param dequeueId the current dequeue operation ID
*/
internal fun removeSegmentIfNeeded(segment: SemaphoreSegment, dequeueId: Long) {
val slotId = (dequeueId % SEGMENT_SIZE).toInt()
if (slotId == SEGMENT_SIZE) {
segment.remove()
}
}

override fun toString(): String {
return "Semaphore=(balance=${permitsBalance.value}, accumulator=${accumulator.value})"
}
}

private enum class State {
SUSPEND,
RESUMED,
CANCELLED
}

private data class Slot(
val state: State,
/**
* Remaining permits to resume slot
*/
val permits: Int,
val cont: CancellableContinuation<Unit>
) {
init {
assert { permits >= 0 }
assert { permits != 0 || state == State.RESUMED }
}

override fun toString(): String {
return "Slot($state, $permits)"
}
}

/**
* Cleans the acquirer slot located by the specified index and removes this segment physically if all slots are cleaned.
*/
private class CancelSemaphoreAcquisitionHandler(
private val semaphore: SemaphoreImpl,
private val segment: SemaphoreSegment,
private val index: Int
private val semaphore: SemaphoreImpl,
private val segment: SemaphoreSegment,
private val slotId: Int,
private val permits: Int
) : CancelHandler() {
override fun invoke(cause: Throwable?) {
val p = semaphore.incPermits()
// Don't wait and use [prevSlot.permits] to handle permits, because it start races with release (see StressTest)
val p = semaphore.incPermits(permits)
if (p >= 0) return
if (segment.cancel(index)) return
semaphore.resumeNextFromQueue()
// Copy [slotId] to local variable to prevent exception:
// "Complex data flow is not allowed for calculation of an array element index at the point of loading the reference to this element."
val temp = slotId
val prevSlot = segment.slots[temp].getAndUpdate { Slot(State.CANCELLED, it!!.permits, it.cont) }
// The assertion is true, cause the slot has [SUSPEND] state at least
assert { prevSlot != null }

// Remove this segment if needed
if (segment.cancelledSlots.incrementAndGet() == SEGMENT_SIZE) {
segment.remove()
}
if (prevSlot!!.state == State.RESUMED) {
// The slot has already resumed, so return free permits to semaphore
semaphore.tryToResumeFromQueue(prevSlot.permits)
}
}

override fun toString() = "CancelSemaphoreAcquisitionHandler[$semaphore, $segment, $index]"
override fun toString() = "CancelSemaphoreAcquisitionHandler[$semaphore, $segment, $slotId]"
}

private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?): Segment<SemaphoreSegment>(id, prev) {
val acquirers = atomicArrayOfNulls<Any?>(SEGMENT_SIZE)
private val cancelledSlots = atomic(0)
private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?) : Segment<SemaphoreSegment>(id, prev) {
val slots = atomicArrayOfNulls<Slot>(SEGMENT_SIZE)
val cancelledSlots = atomic(0)
override val removed get() = cancelledSlots.value == SEGMENT_SIZE

@Suppress("NOTHING_TO_INLINE")
inline fun get(index: Int): Any? = acquirers[index].value

@Suppress("NOTHING_TO_INLINE")
inline fun cas(index: Int, expected: Any?, value: Any?): Boolean = acquirers[index].compareAndSet(expected, value)

@Suppress("NOTHING_TO_INLINE")
inline fun getAndSet(index: Int, value: Any?) = acquirers[index].getAndSet(value)

// Cleans the acquirer slot located by the specified index
// and removes this segment physically if all slots are cleaned.
fun cancel(index: Int): Boolean {
// Try to cancel the slot
val cancelled = getAndSet(index, CANCELLED) !== RESUMED
// Remove this segment if needed
if (cancelledSlots.incrementAndGet() == SEGMENT_SIZE)
remove()
return cancelled
override fun toString(): String {
return "SemaphoreSegment(id=$id)"
}

override fun toString() = "SemaphoreSegment[id=$id, hashCode=${hashCode()}]"
}

@SharedImmutable
private val RESUMED = Symbol("RESUMED")
@SharedImmutable
private val CANCELLED = Symbol("CANCELLED")
@SharedImmutable
private val SEGMENT_SIZE = systemProp("kotlinx.coroutines.semaphore.segmentSize", 16)
Loading