Skip to content

Commit 3f7cbf3

Browse files
committed
fix(Topic): Fix race condition on racing subscribe and close.
1 parent 156ad3c commit 3f7cbf3

File tree

2 files changed

+69
-46
lines changed

2 files changed

+69
-46
lines changed

core/shared/src/main/scala/fs2/concurrent/Topic.scala

Lines changed: 65 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ package fs2
2323
package concurrent
2424

2525
import cats.effect._
26-
import cats.effect.implicits._
2726
import cats.syntax.all._
2827
import scala.collection.immutable.LongMap
2928

@@ -151,7 +150,7 @@ object Topic {
151150
/** Constructs a Topic */
152151
def apply[F[_], A](implicit F: Concurrent[F]): F[Topic[F, A]] =
153152
(
154-
F.ref(LongMap.empty[Channel[F, A]] -> 1L),
153+
F.ref(State.initial[F, A]),
155154
SignallingRef[F, Int](0),
156155
F.deferred[Unit]
157156
).mapN { case (state, subscriberCount, signalClosure) =>
@@ -161,11 +160,11 @@ object Topic {
161160
lm.foldLeft(F.unit) { case (op, (_, b)) => op >> f(b) }
162161

163162
def publish1(a: A): F[Either[Topic.Closed, Unit]] =
164-
signalClosure.tryGet.flatMap {
165-
case Some(_) => Topic.closed.pure[F]
166-
case None =>
167-
state.get
168-
.flatMap { case (subs, _) => foreach(subs)(_.send(a).void) }
163+
state.get.flatMap {
164+
case State.Closed() =>
165+
Topic.closed.pure[F]
166+
case State.Active(subs, _) =>
167+
foreach(subs)(_.send(a).void)
169168
.as(Topic.rightUnit)
170169
}
171170

@@ -180,31 +179,44 @@ object Topic {
180179
.flatMap(subscribeAwaitImpl)
181180

182181
def subscribeAwaitImpl(chan: Channel[F, A]): Resource[F, Stream[F, A]] = {
183-
val subscribe = state.modify { case (subs, id) =>
184-
(subs.updated(id, chan), id + 1) -> id
185-
} <* subscriberCount.update(_ + 1)
186-
187-
def unsubscribe(id: Long) =
188-
state.modify { case (subs, nextId) =>
189-
// _After_ we remove the bounded channel for this
190-
// subscriber, we need to drain it to unblock to
191-
// publish loop which might have already enqueued
192-
// something.
193-
def drainChannel: F[Unit] =
194-
subs.get(id).traverse_ { chan =>
195-
chan.close >> chan.stream.compile.drain
196-
}
197-
198-
(subs - id, nextId) -> drainChannel
199-
}.flatten >> subscriberCount.update(_ - 1)
200-
201-
Resource.eval(signalClosure.tryGet).flatMap {
202-
case Some(_) => Resource.pure(Stream.empty)
203-
case None =>
204-
Resource
205-
.make(subscribe)(unsubscribe)
206-
.as(chan.stream)
207-
}
182+
val subscribe: F[Option[Long]] =
183+
state.flatModify {
184+
case State.Active(subs, nextId) =>
185+
val newState = State.Active(subs.updated(nextId, chan), nextId + 1)
186+
val action = subscriberCount.update(_ + 1)
187+
val result = Some(nextId)
188+
newState -> action.as(result)
189+
case closed @ State.Closed() =>
190+
closed -> F.pure(None)
191+
}
192+
193+
def unsubscribe(id: Long): F[Unit] =
194+
state.flatModify {
195+
case State.Active(subs, nextId) =>
196+
// _After_ we remove the bounded channel for this
197+
// subscriber, we need to drain it to unblock to
198+
// publish loop which might have already enqueued
199+
// something.
200+
def drainChannel: F[Unit] =
201+
subs.get(id).traverse_ { chan =>
202+
chan.close >> chan.stream.compile.drain
203+
}
204+
205+
State.Active(subs - id, nextId) -> (drainChannel *> subscriberCount.update(_ - 1))
206+
207+
case closed @ State.Closed() =>
208+
closed -> F.unit
209+
}
210+
211+
Resource
212+
.make(subscribe) {
213+
case Some(id) => unsubscribe(id)
214+
case None => F.unit
215+
}
216+
.map {
217+
case Some(_) => chan.stream
218+
case None => Stream.empty
219+
}
208220
}
209221

210222
def publish: Pipe[F, A, Nothing] = { in =>
@@ -223,22 +235,33 @@ object Topic {
223235
def subscribers: Stream[F, Int] = subscriberCount.discrete
224236

225237
def close: F[Either[Topic.Closed, Unit]] =
226-
signalClosure
227-
.complete(())
228-
.flatMap { completedNow =>
229-
val result = if (completedNow) Topic.rightUnit else Topic.closed
230-
231-
state.get
232-
.flatMap { case (subs, _) => foreach(subs)(_.close.void) }
233-
.as(result)
234-
}
235-
.uncancelable
238+
state.flatModify {
239+
case State.Active(subs, _) =>
240+
val action = foreach(subs)(_.close.void) *> signalClosure.complete(())
241+
(State.Closed(), action.as(Topic.rightUnit))
242+
case closed @ State.Closed() =>
243+
(closed, Topic.closed.pure[F])
244+
}
236245

237246
def closed: F[Unit] = signalClosure.get
238247
def isClosed: F[Boolean] = signalClosure.tryGet.map(_.isDefined)
239248
}
240249
}
241250

251+
private sealed trait State[F[_], A]
252+
253+
private object State {
254+
case class Active[F[_], A](
255+
subscribers: LongMap[Channel[F, A]],
256+
nextId: Long
257+
) extends State[F, A]
258+
259+
case class Closed[F[_], A]() extends State[F, A]
260+
261+
def initial[F[_], A]: State[F, A] =
262+
Active(LongMap.empty, 1L)
263+
}
264+
242265
private final val closed: Either[Closed, Unit] = Left(Closed)
243266
private final val rightUnit: Either[Closed, Unit] = Right(())
244267
}

core/shared/src/test/scala/fs2/concurrent/TopicSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class TopicSuite extends Fs2Suite {
187187
}
188188

189189
// https://github.com/typelevel/fs2/issues/3642
190-
test("subscribe and close concurrently".flaky) {
190+
test("subscribe and close concurrently") {
191191
val check: IO[Unit] =
192192
for {
193193
t <- Topic[IO, Int]
@@ -200,11 +200,11 @@ class TopicSuite extends Fs2Suite {
200200
_ <- fiber.join.timeout(5.seconds) // checking termination of the subscription stream
201201
} yield ()
202202

203-
check.replicateA_(100000)
203+
check.replicateA_(10000)
204204
}
205205

206206
// https://github.com/typelevel/fs2/issues/3642
207-
test("subscribeAwait and close concurrently".flaky) {
207+
test("subscribeAwait and close concurrently") {
208208
val check: IO[Unit] =
209209
for {
210210
t <- Topic[IO, Int]
@@ -218,6 +218,6 @@ class TopicSuite extends Fs2Suite {
218218
_ <- fiber.join.timeout(5.seconds) // checking termination of the subscription stream
219219
} yield ()
220220

221-
check.replicateA_(100000)
221+
check.replicateA_(10000)
222222
}
223223
}

0 commit comments

Comments
 (0)