From 810af8a435b975875605922e12d77d629e191414 Mon Sep 17 00:00:00 2001 From: Justin Reardon Date: Wed, 25 Dec 2024 23:20:37 -0500 Subject: [PATCH 1/2] Fix #3076 parEvalMap resource scoping Updates parEvalMap* and broadcastThrough to extend the resource scope past the channel/topic used to implement concurrency for these operators. --- core/shared/src/main/scala/fs2/Stream.scala | 87 +++++++++++-------- .../src/test/scala/fs2/ParEvalMapSuite.scala | 30 +++++++ 2 files changed, 83 insertions(+), 34 deletions(-) diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index b70a1372ac..0b052f2c5f 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -238,31 +238,39 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, */ def broadcastThrough[F2[x] >: F[x]: Concurrent, O2](pipes: Pipe[F2, O, O2]*): Stream[F2, O2] = { assert(pipes.nonEmpty, s"pipes should not be empty") - Stream.force { - for { - // topic: contains the chunk that the pipes are processing at one point. - // until and unless all pipes are finished with it, won't move to next one - topic <- Topic[F2, Chunk[O]] - // Coordination: neither the producer nor any consumer starts - // until and unless all consumers are subscribed to topic. - allReady <- CountDownLatch[F2](pipes.length) - } yield { - val checkIn = allReady.release >> allReady.await - - def dump(pipe: Pipe[F2, O, O2]): Stream[F2, O2] = - Stream.resource(topic.subscribeAwait(1)).flatMap { sub => - // Wait until all pipes are ready before consuming. - // Crucial: checkin is not passed to the pipe, - // so pipe cannot interrupt it and alter the latch count - Stream.exec(checkIn) ++ pipe(sub.unchunks) - } + underlying.uncons.flatMap { + case Some((hd, tl)) => + for { + // topic: contains the chunk that the pipes are processing at one point. + // until and unless all pipes are finished with it, won't move to next one + topic <- Pull.eval(Topic[F2, Chunk[O]]) + // Coordination: neither the producer nor any consumer starts + // until and unless all consumers are subscribed to topic. + allReady <- Pull.eval(CountDownLatch[F2](pipes.length)) + + checkIn = allReady.release >> allReady.await + + dump = (pipe: Pipe[F2, O, O2]) => + Stream.resource(topic.subscribeAwait(1)).flatMap { sub => + // Wait until all pipes are ready before consuming. + // Crucial: checkin is not passed to the pipe, + // so pipe cannot interrupt it and alter the latch count + Stream.exec(checkIn) ++ pipe(sub.unchunks) + } - val dumpAll: Stream[F2, O2] = Stream(pipes: _*).map(dump).parJoinUnbounded - // Wait until all pipes are checked in before pulling - val pump = Stream.exec(allReady.await) ++ topic.publish(chunks) - dumpAll.concurrently(pump) - } - } + dumpAll: Stream[F2, O2] <- + Pull.extendScopeTo(Stream(pipes: _*).map(dump).parJoinUnbounded) + + chunksStream = Stream.chunk(hd).append(tl.stream).chunks + + // Wait until all pipes are checked in before pulling + pump = Stream.exec(allReady.await) ++ topic.publish(chunksStream) + + _ <- dumpAll.concurrently(pump).underlying + } yield () + + case None => Pull.done + }.stream } /** Behaves like the identity function, but requests `n` elements at a time from the input. 
@@ -2366,17 +2374,28 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, } } - val background = - Stream.exec(semaphore.acquire) ++ - interruptWhen(stop.get.map(_.asRight[Throwable])) - .foreach(forkOnElem) - .onFinalizeCase { - case ExitCase.Succeeded => releaseAndCheckCompletion - case _ => stop.complete(()) *> releaseAndCheckCompletion - } + underlying.uncons.flatMap { + case Some((hd, tl)) => + for { + foreground <- Pull.extendScopeTo( + channel.stream.evalMap(_.rethrow).onFinalize(stop.complete(()) *> end.get) + ) + background = Stream + .exec(semaphore.acquire) ++ + Stream + .chunk(hd) + .append(tl.stream) + .interruptWhen(stop.get.map(_.asRight[Throwable])) + .foreach(forkOnElem) + .onFinalizeCase { + case ExitCase.Succeeded => releaseAndCheckCompletion + case _ => stop.complete(()) *> releaseAndCheckCompletion + } + _ <- foreground.concurrently(background).underlying + } yield () - val foreground = channel.stream.evalMap(_.rethrow) - foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background) + case None => Pull.done + }.stream } Stream.force(action) diff --git a/core/shared/src/test/scala/fs2/ParEvalMapSuite.scala b/core/shared/src/test/scala/fs2/ParEvalMapSuite.scala index d394969623..3588434175 100644 --- a/core/shared/src/test/scala/fs2/ParEvalMapSuite.scala +++ b/core/shared/src/test/scala/fs2/ParEvalMapSuite.scala @@ -293,4 +293,34 @@ class ParEvalMapSuite extends Fs2Suite { .timeout(2.seconds) } } + + group("issue-3076, parEvalMap* runs resource finaliser before usage") { + test("parEvalMap") { + Deferred[IO, Unit] + .flatMap { d => + Stream + .bracket(IO.unit)(_ => d.complete(()).void) + .parEvalMap(2)(_ => IO.sleep(1.second)) + .evalMap(_ => IO.sleep(1.second) >> d.complete(())) + .timeout(5.seconds) + .compile + .last + } + .assertEquals(Some(true)) + } + + test("broadcastThrough") { + Deferred[IO, Unit] + .flatMap { d => + Stream + .bracket(IO.unit)(_ => d.complete(()).void) + .broadcastThrough(identity[Stream[IO, Unit]]) + .evalMap(_ => IO.sleep(1.second) >> d.complete(())) + .timeout(5.seconds) + .compile + .last + } + .assertEquals(Some(true)) + } + } } From cc55983834c90e0fcb5cef6eb9d6f50fa382f64c Mon Sep 17 00:00:00 2001 From: Justin Reardon Date: Thu, 26 Dec 2024 09:04:49 -0500 Subject: [PATCH 2/2] Abstract scope extension with extendScopeThrough Make extendScopeTo cancellation safe (see #3474) --- core/shared/src/main/scala/fs2/Pull.scala | 5 +- core/shared/src/main/scala/fs2/Stream.scala | 166 +++++++++--------- .../scala/fs2/StreamCombinatorsSuite.scala | 14 ++ 3 files changed, 99 insertions(+), 86 deletions(-) diff --git a/core/shared/src/main/scala/fs2/Pull.scala b/core/shared/src/main/scala/fs2/Pull.scala index c583263b15..6a241b8fb5 100644 --- a/core/shared/src/main/scala/fs2/Pull.scala +++ b/core/shared/src/main/scala/fs2/Pull.scala @@ -464,10 +464,7 @@ object Pull extends PullLowPriority { def extendScopeTo[F[_], O]( s: Stream[F, O] )(implicit F: MonadError[F, Throwable]): Pull[F, Nothing, Stream[F, O]] = - for { - scope <- Pull.getScope[F] - lease <- Pull.eval(scope.lease) - } yield s.onFinalize(lease.cancel.redeemWith(F.raiseError(_), _ => F.unit)) + Pull.getScope[F].map(scope => Stream.bracket(scope.lease)(_.cancel.rethrow) *> s) /** Repeatedly uses the output of the pull as input for the next step of the * pull. Halts when a step terminates with `None` or `Pull.raiseError`. 
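For context, a minimal usage sketch of the pattern this lease change supports (hypothetical: the operator name viaChannel, the buffer size, and the Channel-based body are illustrative and not part of this patch; it assumes Pull.extendScopeTo keeps the signature shown above). The operator's output comes from a Channel rather than from the source directly, so it leases the enclosing scope to keep upstream finalisers from running early; with the bracket-based lease, that lease is also released if the leased stream is interrupted before it completes.

import cats.effect.IO
import fs2.{Pull, Stream}
import fs2.concurrent.Channel

// Hypothetical sketch, not part of this patch.
def viaChannel[O](src: Stream[IO, O]): Stream[IO, O] =
  src.pull.peek // pulls the first chunk and pushes it back, so no elements are dropped
    .flatMap {
      case Some((_, rest)) =>
        // Lease the current scope for as long as the channel-backed stream runs,
        // so resources acquired upstream of viaChannel stay open while elements
        // are still being consumed from the channel.
        Pull.extendScopeTo(
          Stream.eval(Channel.bounded[IO, O](1)).flatMap { chan =>
            chan.stream.concurrently(rest.through(chan.sendAll))
          }
        )
      case None =>
        Pull.extendScopeTo(Stream.empty.covaryAll[IO, O])
    }
    .flatMap(_.pull.echo) // emit the leased stream's elements
    .stream

This mirrors the peek-then-lease shape that extendScopeThrough packages up in the Stream.scala changes that follow.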
diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index 0b052f2c5f..f65d72e428 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -238,19 +238,19 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, */ def broadcastThrough[F2[x] >: F[x]: Concurrent, O2](pipes: Pipe[F2, O, O2]*): Stream[F2, O2] = { assert(pipes.nonEmpty, s"pipes should not be empty") - underlying.uncons.flatMap { - case Some((hd, tl)) => + extendScopeThrough { source => + Stream.force { for { // topic: contains the chunk that the pipes are processing at one point. // until and unless all pipes are finished with it, won't move to next one - topic <- Pull.eval(Topic[F2, Chunk[O]]) + topic <- Topic[F2, Chunk[O]] // Coordination: neither the producer nor any consumer starts // until and unless all consumers are subscribed to topic. - allReady <- Pull.eval(CountDownLatch[F2](pipes.length)) - - checkIn = allReady.release >> allReady.await + allReady <- CountDownLatch[F2](pipes.length) + } yield { + val checkIn = allReady.release >> allReady.await - dump = (pipe: Pipe[F2, O, O2]) => + def dump(pipe: Pipe[F2, O, O2]): Stream[F2, O2] = Stream.resource(topic.subscribeAwait(1)).flatMap { sub => // Wait until all pipes are ready before consuming. // Crucial: checkin is not passed to the pipe, @@ -258,19 +258,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, Stream.exec(checkIn) ++ pipe(sub.unchunks) } - dumpAll: Stream[F2, O2] <- - Pull.extendScopeTo(Stream(pipes: _*).map(dump).parJoinUnbounded) - - chunksStream = Stream.chunk(hd).append(tl.stream).chunks - + val dumpAll: Stream[F2, O2] = Stream(pipes: _*).map(dump).parJoinUnbounded // Wait until all pipes are checked in before pulling - pump = Stream.exec(allReady.await) ++ topic.publish(chunksStream) - - _ <- dumpAll.concurrently(pump).underlying - } yield () - - case None => Pull.done - }.stream + val pump = Stream.exec(allReady.await) ++ topic.publish(source.chunks) + dumpAll.concurrently(pump) + } + } + } } /** Behaves like the identity function, but requests `n` elements at a time from the input. 
@@ -2331,75 +2325,65 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, channel: F2[Channel[F2, F2[Either[Throwable, O2]]]], isOrdered: Boolean, f: O => F2[O2] - )(implicit F: Concurrent[F2]): Stream[F2, O2] = { - val action = - ( - Semaphore[F2](concurrency), - channel, - Deferred[F2, Unit], - Deferred[F2, Unit] - ).mapN { (semaphore, channel, stop, end) => - def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = { - def ordered: F2[Either[Throwable, O2] => F2[Unit]] = { - def send(v: Deferred[F2, Either[Throwable, O2]]) = - (el: Either[Throwable, O2]) => v.complete(el).void - - Deferred[F2, Either[Throwable, O2]] - .flatTap(value => channel.send(release *> value.get)) - .map(send) - } - - def unordered: Either[Throwable, O2] => F2[Unit] = - (el: Either[Throwable, O2]) => release <* channel.send(F.pure(el)) + )(implicit F: Concurrent[F2]): Stream[F2, O2] = + extendScopeThrough { source => + Stream.force { + ( + Semaphore[F2](concurrency), + channel, + Deferred[F2, Unit], + Deferred[F2, Unit] + ).mapN { (semaphore, channel, stop, end) => + def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = { + def ordered: F2[Either[Throwable, O2] => F2[Unit]] = { + def send(v: Deferred[F2, Either[Throwable, O2]]) = + (el: Either[Throwable, O2]) => v.complete(el).void + + Deferred[F2, Either[Throwable, O2]] + .flatTap(value => channel.send(release *> value.get)) + .map(send) + } - if (isOrdered) ordered else F.pure(unordered) - } + def unordered: Either[Throwable, O2] => F2[Unit] = + (el: Either[Throwable, O2]) => release <* channel.send(F.pure(el)) - val releaseAndCheckCompletion = - semaphore.release *> - semaphore.available.flatMap { - case `concurrency` => channel.close *> end.complete(()).void - case _ => F.unit - } + if (isOrdered) ordered else F.pure(unordered) + } - def forkOnElem(el: O): F2[Unit] = - F.uncancelable { poll => - poll(semaphore.acquire) <* - Deferred[F2, Unit].flatMap { pushed => - val init = initFork(pushed.complete(()).void) - poll(init).onCancel(releaseAndCheckCompletion).flatMap { send => - val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get - F.start(stop.get.race(action) *> releaseAndCheckCompletion) - } + val releaseAndCheckCompletion = + semaphore.release *> + semaphore.available.flatMap { + case `concurrency` => channel.close *> end.complete(()).void + case _ => F.unit } - } - underlying.uncons.flatMap { - case Some((hd, tl)) => - for { - foreground <- Pull.extendScopeTo( - channel.stream.evalMap(_.rethrow).onFinalize(stop.complete(()) *> end.get) - ) - background = Stream - .exec(semaphore.acquire) ++ - Stream - .chunk(hd) - .append(tl.stream) - .interruptWhen(stop.get.map(_.asRight[Throwable])) - .foreach(forkOnElem) - .onFinalizeCase { - case ExitCase.Succeeded => releaseAndCheckCompletion - case _ => stop.complete(()) *> releaseAndCheckCompletion + def forkOnElem(el: O): F2[Unit] = + F.uncancelable { poll => + poll(semaphore.acquire) <* + Deferred[F2, Unit].flatMap { pushed => + val init = initFork(pushed.complete(()).void) + poll(init).onCancel(releaseAndCheckCompletion).flatMap { send => + val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get + F.start(stop.get.race(action) *> releaseAndCheckCompletion) } - _ <- foreground.concurrently(background).underlying - } yield () + } + } - case None => Pull.done - }.stream - } + val background = + Stream.exec(semaphore.acquire) ++ + source + .interruptWhen(stop.get.map(_.asRight[Throwable])) + 
.foreach(forkOnElem) + .onFinalizeCase { + case ExitCase.Succeeded => releaseAndCheckCompletion + case _ => stop.complete(()) *> releaseAndCheckCompletion + } - Stream.force(action) - } + val foreground = channel.stream.evalMap(_.rethrow) + foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background) + } + } + } /** Concurrent zip. * @@ -2474,12 +2458,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, */ def prefetchN[F2[x] >: F[x]: Concurrent]( n: Int - ): Stream[F2, O] = + ): Stream[F2, O] = extendScopeThrough { source => Stream.eval(Channel.bounded[F2, Chunk[O]](n)).flatMap { chan => chan.stream.unchunks.concurrently { - chunks.through(chan.sendAll) + source.chunks.through(chan.sendAll) } } + } /** Prints each element of this stream to standard out, converting each element to a `String` via `Show`. */ def printlns[F2[x] >: F[x], O2 >: O](implicit @@ -2940,6 +2925,23 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, )(f: (Stream[F, O], Stream[F2, O2]) => Stream[F2, O3]): Stream[F2, O3] = f(this, s2) + /** Transforms this stream, explicitly extending the current scope through the given pipe. + * + * Use this when implementing a pipe where the resulting stream is not directly constructed from + * the source stream, e.g. when sending the source stream through a Channel and returning the + * channel's stream. + */ + def extendScopeThrough[F2[x] >: F[x], O2]( + f: Stream[F, O] => Stream[F2, O2] + )(implicit F: MonadError[F2, Throwable]): Stream[F2, O2] = + this.pull.peek + .flatMap { + case Some((_, tl)) => Pull.extendScopeTo(f(tl)) + case None => Pull.extendScopeTo(f(Stream.empty)) + } + .flatMap(_.underlying) + .stream + /** Fails this stream with a `TimeoutException` if it does not complete within given `timeout`. */ def timeout[F2[x] >: F[x]: Temporal]( timeout: FiniteDuration diff --git a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala index d7cc21d93b..fe27fc5492 100644 --- a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala +++ b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala @@ -1328,6 +1328,20 @@ class StreamCombinatorsSuite extends Fs2Suite { ) .assertEquals(4.seconds) } + + test("scope propagation") { + Deferred[IO, Unit] + .flatMap { d => + Stream + .bracket(IO.unit)(_ => d.complete(()).void) + .prefetch + .evalMap(_ => IO.sleep(1.second) >> d.complete(())) + .timeout(5.seconds) + .compile + .last + } + .assertEquals(Some(true)) + } } test("range") {
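To show how the new combinator is meant to be used by operator implementations, here is a minimal sketch in the spirit of the prefetchN change above (hypothetical: the operator name buffered and its Channel-based body are illustrative, not part of this patch; it assumes extendScopeThrough keeps the public signature declared here).

import cats.effect.IO
import fs2.Stream
import fs2.concurrent.Channel

// Hypothetical sketch, not part of this patch: the output stream is produced by
// a Channel rather than built directly from the source, so extendScopeThrough
// keeps the source's resources open until the returned stream completes.
def buffered[O](n: Int)(src: Stream[IO, O]): Stream[IO, O] =
  src.extendScopeThrough { source =>
    Stream.eval(Channel.bounded[IO, O](n)).flatMap { chan =>
      chan.stream.concurrently(source.through(chan.sendAll))
    }
  }

Without the scope extension, a finaliser attached to src (for example via Stream.bracket) could run while elements were still in flight downstream, which is the behaviour exercised by the issue-3076 tests and the prefetch "scope propagation" test above.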