diff --git a/core/shared/src/main/scala/fs2/Pull.scala b/core/shared/src/main/scala/fs2/Pull.scala index c583263b15..6a241b8fb5 100644 --- a/core/shared/src/main/scala/fs2/Pull.scala +++ b/core/shared/src/main/scala/fs2/Pull.scala @@ -464,10 +464,7 @@ object Pull extends PullLowPriority { def extendScopeTo[F[_], O]( s: Stream[F, O] )(implicit F: MonadError[F, Throwable]): Pull[F, Nothing, Stream[F, O]] = - for { - scope <- Pull.getScope[F] - lease <- Pull.eval(scope.lease) - } yield s.onFinalize(lease.cancel.redeemWith(F.raiseError(_), _ => F.unit)) + Pull.getScope[F].map(scope => Stream.bracket(scope.lease)(_.cancel.rethrow) *> s) /** Repeatedly uses the output of the pull as input for the next step of the * pull. Halts when a step terminates with `None` or `Pull.raiseError`. diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index b70a1372ac..f65d72e428 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -238,29 +238,31 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, */ def broadcastThrough[F2[x] >: F[x]: Concurrent, O2](pipes: Pipe[F2, O, O2]*): Stream[F2, O2] = { assert(pipes.nonEmpty, s"pipes should not be empty") - Stream.force { - for { - // topic: contains the chunk that the pipes are processing at one point. - // until and unless all pipes are finished with it, won't move to next one - topic <- Topic[F2, Chunk[O]] - // Coordination: neither the producer nor any consumer starts - // until and unless all consumers are subscribed to topic. - allReady <- CountDownLatch[F2](pipes.length) - } yield { - val checkIn = allReady.release >> allReady.await - - def dump(pipe: Pipe[F2, O, O2]): Stream[F2, O2] = - Stream.resource(topic.subscribeAwait(1)).flatMap { sub => - // Wait until all pipes are ready before consuming. - // Crucial: checkin is not passed to the pipe, - // so pipe cannot interrupt it and alter the latch count - Stream.exec(checkIn) ++ pipe(sub.unchunks) - } + extendScopeThrough { source => + Stream.force { + for { + // topic: contains the chunk that the pipes are processing at one point. + // until and unless all pipes are finished with it, won't move to next one + topic <- Topic[F2, Chunk[O]] + // Coordination: neither the producer nor any consumer starts + // until and unless all consumers are subscribed to topic. + allReady <- CountDownLatch[F2](pipes.length) + } yield { + val checkIn = allReady.release >> allReady.await + + def dump(pipe: Pipe[F2, O, O2]): Stream[F2, O2] = + Stream.resource(topic.subscribeAwait(1)).flatMap { sub => + // Wait until all pipes are ready before consuming. + // Crucial: checkin is not passed to the pipe, + // so pipe cannot interrupt it and alter the latch count + Stream.exec(checkIn) ++ pipe(sub.unchunks) + } - val dumpAll: Stream[F2, O2] = Stream(pipes: _*).map(dump).parJoinUnbounded - // Wait until all pipes are checked in before pulling - val pump = Stream.exec(allReady.await) ++ topic.publish(chunks) - dumpAll.concurrently(pump) + val dumpAll: Stream[F2, O2] = Stream(pipes: _*).map(dump).parJoinUnbounded + // Wait until all pipes are checked in before pulling + val pump = Stream.exec(allReady.await) ++ topic.publish(source.chunks) + dumpAll.concurrently(pump) + } } } } @@ -2323,64 +2325,65 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, channel: F2[Channel[F2, F2[Either[Throwable, O2]]]], isOrdered: Boolean, f: O => F2[O2] - )(implicit F: Concurrent[F2]): Stream[F2, O2] = { - val action = - ( - Semaphore[F2](concurrency), - channel, - Deferred[F2, Unit], - Deferred[F2, Unit] - ).mapN { (semaphore, channel, stop, end) => - def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = { - def ordered: F2[Either[Throwable, O2] => F2[Unit]] = { - def send(v: Deferred[F2, Either[Throwable, O2]]) = - (el: Either[Throwable, O2]) => v.complete(el).void - - Deferred[F2, Either[Throwable, O2]] - .flatTap(value => channel.send(release *> value.get)) - .map(send) - } + )(implicit F: Concurrent[F2]): Stream[F2, O2] = + extendScopeThrough { source => + Stream.force { + ( + Semaphore[F2](concurrency), + channel, + Deferred[F2, Unit], + Deferred[F2, Unit] + ).mapN { (semaphore, channel, stop, end) => + def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = { + def ordered: F2[Either[Throwable, O2] => F2[Unit]] = { + def send(v: Deferred[F2, Either[Throwable, O2]]) = + (el: Either[Throwable, O2]) => v.complete(el).void + + Deferred[F2, Either[Throwable, O2]] + .flatTap(value => channel.send(release *> value.get)) + .map(send) + } - def unordered: Either[Throwable, O2] => F2[Unit] = - (el: Either[Throwable, O2]) => release <* channel.send(F.pure(el)) + def unordered: Either[Throwable, O2] => F2[Unit] = + (el: Either[Throwable, O2]) => release <* channel.send(F.pure(el)) - if (isOrdered) ordered else F.pure(unordered) - } + if (isOrdered) ordered else F.pure(unordered) + } - val releaseAndCheckCompletion = - semaphore.release *> - semaphore.available.flatMap { - case `concurrency` => channel.close *> end.complete(()).void - case _ => F.unit - } + val releaseAndCheckCompletion = + semaphore.release *> + semaphore.available.flatMap { + case `concurrency` => channel.close *> end.complete(()).void + case _ => F.unit + } - def forkOnElem(el: O): F2[Unit] = - F.uncancelable { poll => - poll(semaphore.acquire) <* - Deferred[F2, Unit].flatMap { pushed => - val init = initFork(pushed.complete(()).void) - poll(init).onCancel(releaseAndCheckCompletion).flatMap { send => - val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get - F.start(stop.get.race(action) *> releaseAndCheckCompletion) + def forkOnElem(el: O): F2[Unit] = + F.uncancelable { poll => + poll(semaphore.acquire) <* + Deferred[F2, Unit].flatMap { pushed => + val init = initFork(pushed.complete(()).void) + poll(init).onCancel(releaseAndCheckCompletion).flatMap { send => + val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get + F.start(stop.get.race(action) *> releaseAndCheckCompletion) + } } - } - } + } - val background = - Stream.exec(semaphore.acquire) ++ - interruptWhen(stop.get.map(_.asRight[Throwable])) - .foreach(forkOnElem) - .onFinalizeCase { - case ExitCase.Succeeded => releaseAndCheckCompletion - case _ => stop.complete(()) *> releaseAndCheckCompletion - } + val background = + Stream.exec(semaphore.acquire) ++ + source + .interruptWhen(stop.get.map(_.asRight[Throwable])) + .foreach(forkOnElem) + .onFinalizeCase { + case ExitCase.Succeeded => releaseAndCheckCompletion + case _ => stop.complete(()) *> releaseAndCheckCompletion + } - val foreground = channel.stream.evalMap(_.rethrow) - foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background) + val foreground = channel.stream.evalMap(_.rethrow) + foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background) + } } - - Stream.force(action) - } + } /** Concurrent zip. * @@ -2455,12 +2458,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, */ def prefetchN[F2[x] >: F[x]: Concurrent]( n: Int - ): Stream[F2, O] = + ): Stream[F2, O] = extendScopeThrough { source => Stream.eval(Channel.bounded[F2, Chunk[O]](n)).flatMap { chan => chan.stream.unchunks.concurrently { - chunks.through(chan.sendAll) + source.chunks.through(chan.sendAll) } } + } /** Prints each element of this stream to standard out, converting each element to a `String` via `Show`. */ def printlns[F2[x] >: F[x], O2 >: O](implicit @@ -2921,6 +2925,23 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, )(f: (Stream[F, O], Stream[F2, O2]) => Stream[F2, O3]): Stream[F2, O3] = f(this, s2) + /** Transforms this stream, explicitly extending the current scope through the given pipe. + * + * Use this when implementing a pipe where the resulting stream is not directly constructed from + * the source stream, e.g. when sending the source stream through a Channel and returning the + * channel's stream. + */ + def extendScopeThrough[F2[x] >: F[x], O2]( + f: Stream[F, O] => Stream[F2, O2] + )(implicit F: MonadError[F2, Throwable]): Stream[F2, O2] = + this.pull.peek + .flatMap { + case Some((_, tl)) => Pull.extendScopeTo(f(tl)) + case None => Pull.extendScopeTo(f(Stream.empty)) + } + .flatMap(_.underlying) + .stream + /** Fails this stream with a `TimeoutException` if it does not complete within given `timeout`. */ def timeout[F2[x] >: F[x]: Temporal]( timeout: FiniteDuration diff --git a/core/shared/src/test/scala/fs2/ParEvalMapSuite.scala b/core/shared/src/test/scala/fs2/ParEvalMapSuite.scala index d394969623..3588434175 100644 --- a/core/shared/src/test/scala/fs2/ParEvalMapSuite.scala +++ b/core/shared/src/test/scala/fs2/ParEvalMapSuite.scala @@ -293,4 +293,34 @@ class ParEvalMapSuite extends Fs2Suite { .timeout(2.seconds) } } + + group("issue-3076, parEvalMap* runs resource finaliser before usage") { + test("parEvalMap") { + Deferred[IO, Unit] + .flatMap { d => + Stream + .bracket(IO.unit)(_ => d.complete(()).void) + .parEvalMap(2)(_ => IO.sleep(1.second)) + .evalMap(_ => IO.sleep(1.second) >> d.complete(())) + .timeout(5.seconds) + .compile + .last + } + .assertEquals(Some(true)) + } + + test("broadcastThrough") { + Deferred[IO, Unit] + .flatMap { d => + Stream + .bracket(IO.unit)(_ => d.complete(()).void) + .broadcastThrough(identity[Stream[IO, Unit]]) + .evalMap(_ => IO.sleep(1.second) >> d.complete(())) + .timeout(5.seconds) + .compile + .last + } + .assertEquals(Some(true)) + } + } } diff --git a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala index d7cc21d93b..fe27fc5492 100644 --- a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala +++ b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala @@ -1328,6 +1328,20 @@ class StreamCombinatorsSuite extends Fs2Suite { ) .assertEquals(4.seconds) } + + test("scope propagation") { + Deferred[IO, Unit] + .flatMap { d => + Stream + .bracket(IO.unit)(_ => d.complete(()).void) + .prefetch + .evalMap(_ => IO.sleep(1.second) >> d.complete(())) + .timeout(5.seconds) + .compile + .last + } + .assertEquals(Some(true)) + } } test("range") {