Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #3076 parEvalMap resource scoping #3512

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions core/shared/src/main/scala/fs2/Pull.scala
Original file line number Diff line number Diff line change
Expand Up @@ -464,10 +464,7 @@ object Pull extends PullLowPriority {
def extendScopeTo[F[_], O](
s: Stream[F, O]
)(implicit F: MonadError[F, Throwable]): Pull[F, Nothing, Stream[F, O]] =
for {
scope <- Pull.getScope[F]
lease <- Pull.eval(scope.lease)
} yield s.onFinalize(lease.cancel.redeemWith(F.raiseError(_), _ => F.unit))
Pull.getScope[F].map(scope => Stream.bracket(scope.lease)(_.cancel.rethrow) *> s)

/** Repeatedly uses the output of the pull as input for the next step of the
* pull. Halts when a step terminates with `None` or `Pull.raiseError`.
Expand Down
169 changes: 95 additions & 74 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -238,29 +238,31 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
*/
def broadcastThrough[F2[x] >: F[x]: Concurrent, O2](pipes: Pipe[F2, O, O2]*): Stream[F2, O2] = {
assert(pipes.nonEmpty, s"pipes should not be empty")
Stream.force {
for {
// topic: contains the chunk that the pipes are processing at one point.
// until and unless all pipes are finished with it, won't move to next one
topic <- Topic[F2, Chunk[O]]
// Coordination: neither the producer nor any consumer starts
// until and unless all consumers are subscribed to topic.
allReady <- CountDownLatch[F2](pipes.length)
} yield {
val checkIn = allReady.release >> allReady.await

def dump(pipe: Pipe[F2, O, O2]): Stream[F2, O2] =
Stream.resource(topic.subscribeAwait(1)).flatMap { sub =>
// Wait until all pipes are ready before consuming.
// Crucial: checkin is not passed to the pipe,
// so pipe cannot interrupt it and alter the latch count
Stream.exec(checkIn) ++ pipe(sub.unchunks)
}
extendScopeThrough { source =>
Stream.force {
for {
// topic: contains the chunk that the pipes are processing at one point.
// until and unless all pipes are finished with it, won't move to next one
topic <- Topic[F2, Chunk[O]]
// Coordination: neither the producer nor any consumer starts
// until and unless all consumers are subscribed to topic.
allReady <- CountDownLatch[F2](pipes.length)
} yield {
val checkIn = allReady.release >> allReady.await

def dump(pipe: Pipe[F2, O, O2]): Stream[F2, O2] =
Stream.resource(topic.subscribeAwait(1)).flatMap { sub =>
// Wait until all pipes are ready before consuming.
// Crucial: checkin is not passed to the pipe,
// so pipe cannot interrupt it and alter the latch count
Stream.exec(checkIn) ++ pipe(sub.unchunks)
}

val dumpAll: Stream[F2, O2] = Stream(pipes: _*).map(dump).parJoinUnbounded
// Wait until all pipes are checked in before pulling
val pump = Stream.exec(allReady.await) ++ topic.publish(chunks)
dumpAll.concurrently(pump)
val dumpAll: Stream[F2, O2] = Stream(pipes: _*).map(dump).parJoinUnbounded
// Wait until all pipes are checked in before pulling
val pump = Stream.exec(allReady.await) ++ topic.publish(source.chunks)
dumpAll.concurrently(pump)
}
}
}
}
Expand Down Expand Up @@ -2323,64 +2325,65 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
channel: F2[Channel[F2, F2[Either[Throwable, O2]]]],
isOrdered: Boolean,
f: O => F2[O2]
)(implicit F: Concurrent[F2]): Stream[F2, O2] = {
val action =
(
Semaphore[F2](concurrency),
channel,
Deferred[F2, Unit],
Deferred[F2, Unit]
).mapN { (semaphore, channel, stop, end) =>
def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = {
def ordered: F2[Either[Throwable, O2] => F2[Unit]] = {
def send(v: Deferred[F2, Either[Throwable, O2]]) =
(el: Either[Throwable, O2]) => v.complete(el).void

Deferred[F2, Either[Throwable, O2]]
.flatTap(value => channel.send(release *> value.get))
.map(send)
}
)(implicit F: Concurrent[F2]): Stream[F2, O2] =
extendScopeThrough { source =>
Stream.force {
(
Semaphore[F2](concurrency),
channel,
Deferred[F2, Unit],
Deferred[F2, Unit]
).mapN { (semaphore, channel, stop, end) =>
def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = {
def ordered: F2[Either[Throwable, O2] => F2[Unit]] = {
def send(v: Deferred[F2, Either[Throwable, O2]]) =
(el: Either[Throwable, O2]) => v.complete(el).void

Deferred[F2, Either[Throwable, O2]]
.flatTap(value => channel.send(release *> value.get))
.map(send)
}

def unordered: Either[Throwable, O2] => F2[Unit] =
(el: Either[Throwable, O2]) => release <* channel.send(F.pure(el))
def unordered: Either[Throwable, O2] => F2[Unit] =
(el: Either[Throwable, O2]) => release <* channel.send(F.pure(el))

if (isOrdered) ordered else F.pure(unordered)
}
if (isOrdered) ordered else F.pure(unordered)
}

val releaseAndCheckCompletion =
semaphore.release *>
semaphore.available.flatMap {
case `concurrency` => channel.close *> end.complete(()).void
case _ => F.unit
}
val releaseAndCheckCompletion =
semaphore.release *>
semaphore.available.flatMap {
case `concurrency` => channel.close *> end.complete(()).void
case _ => F.unit
}

def forkOnElem(el: O): F2[Unit] =
F.uncancelable { poll =>
poll(semaphore.acquire) <*
Deferred[F2, Unit].flatMap { pushed =>
val init = initFork(pushed.complete(()).void)
poll(init).onCancel(releaseAndCheckCompletion).flatMap { send =>
val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get
F.start(stop.get.race(action) *> releaseAndCheckCompletion)
def forkOnElem(el: O): F2[Unit] =
F.uncancelable { poll =>
poll(semaphore.acquire) <*
Deferred[F2, Unit].flatMap { pushed =>
val init = initFork(pushed.complete(()).void)
poll(init).onCancel(releaseAndCheckCompletion).flatMap { send =>
val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get
F.start(stop.get.race(action) *> releaseAndCheckCompletion)
}
}
}
}
}

val background =
Stream.exec(semaphore.acquire) ++
interruptWhen(stop.get.map(_.asRight[Throwable]))
.foreach(forkOnElem)
.onFinalizeCase {
case ExitCase.Succeeded => releaseAndCheckCompletion
case _ => stop.complete(()) *> releaseAndCheckCompletion
}
val background =
Stream.exec(semaphore.acquire) ++
source
.interruptWhen(stop.get.map(_.asRight[Throwable]))
.foreach(forkOnElem)
.onFinalizeCase {
case ExitCase.Succeeded => releaseAndCheckCompletion
case _ => stop.complete(()) *> releaseAndCheckCompletion
}

val foreground = channel.stream.evalMap(_.rethrow)
foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background)
val foreground = channel.stream.evalMap(_.rethrow)
foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background)
}
}

Stream.force(action)
}
}

/** Concurrent zip.
*
Expand Down Expand Up @@ -2455,12 +2458,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
*/
def prefetchN[F2[x] >: F[x]: Concurrent](
n: Int
): Stream[F2, O] =
): Stream[F2, O] = extendScopeThrough { source =>
Stream.eval(Channel.bounded[F2, Chunk[O]](n)).flatMap { chan =>
chan.stream.unchunks.concurrently {
chunks.through(chan.sendAll)
source.chunks.through(chan.sendAll)
}
}
}

/** Prints each element of this stream to standard out, converting each element to a `String` via `Show`. */
def printlns[F2[x] >: F[x], O2 >: O](implicit
Expand Down Expand Up @@ -2921,6 +2925,23 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
)(f: (Stream[F, O], Stream[F2, O2]) => Stream[F2, O3]): Stream[F2, O3] =
f(this, s2)

/** Transforms this stream, explicitly extending the current scope through the given pipe.
*
* Use this when implementing a pipe where the resulting stream is not directly constructed from
* the source stream, e.g. when sending the source stream through a Channel and returning the
* channel's stream.
*/
def extendScopeThrough[F2[x] >: F[x], O2](
f: Stream[F, O] => Stream[F2, O2]
)(implicit F: MonadError[F2, Throwable]): Stream[F2, O2] =
this.pull.peek
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it guaranteed to be safe (in the context of scopes) to use .peek before Pull.extendScopeTo?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming scopes work properly for simple streams. 🤞 Do you have a specific scenario in mind?

The fundamental issue is that I need to get ahold of a scope before we go adding more finalizers to source to avoid deadlocking in parEvalMap and I don't see a way to do that without uncons. peek is just a convenience helper.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a specific scenario in mind?

No, not really, it's just me being paranoid :)

I don't see a way to do it without uncons either, I was thinking about swapping the order of those Pulls:

def extendScopeThrough[F2[x] >: F[x], O2](
      f: Stream[F, O] => Stream[F2, O2]
  )(implicit F: MonadError[F2, Throwable]): Stream[F2, O2] =
    Pull
      .extendScopeTo(this.covary[F2])
      .flatMap { stream =>
        stream.pull.peek.flatMap {
          case Some((_, tl)) => f(tl).underlying // <---------------
          case None          => f(Stream.empty).underlying
        }
      }
      .stream

but the problem is tl is not a Stream[F, O] (it's a Stream[F2, O]), and .covary[F2] at the beginning has to be there because we only have MonadError for F2.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just make it f: Stream[F2, O] => Stream[F2, O2]. Requires type annotations at the call sites, but otherwise, this version also passes the tests.

Unless we can find an observable difference I'd rather keep the current version to avoid the extra annotations.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree :)

.flatMap {
case Some((_, tl)) => Pull.extendScopeTo(f(tl))
case None => Pull.extendScopeTo(f(Stream.empty))
}
.flatMap(_.underlying)
.stream

/** Fails this stream with a `TimeoutException` if it does not complete within given `timeout`. */
def timeout[F2[x] >: F[x]: Temporal](
timeout: FiniteDuration
Expand Down
30 changes: 30 additions & 0 deletions core/shared/src/test/scala/fs2/ParEvalMapSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,34 @@ class ParEvalMapSuite extends Fs2Suite {
.timeout(2.seconds)
}
}

group("issue-3076, parEvalMap* runs resource finaliser before usage") {
  // Regression tests for scope extension: the bracket's finaliser must not
  // run until downstream is done with the resource. Each test completes the
  // Deferred from the finaliser, then tries to complete it again downstream;
  // `complete` yields true only on the first completion, so `Some(true)`
  // means downstream won the race, i.e. the finaliser had not run early.
  test("parEvalMap") {
    val program =
      for {
        finalized <- Deferred[IO, Unit]
        result <- Stream
          .bracket(IO.unit)(_ => finalized.complete(()).void)
          .parEvalMap(2)(_ => IO.sleep(1.second))
          // If the finaliser already fired, this `complete` returns false.
          .evalMap(_ => IO.sleep(1.second) >> finalized.complete(()))
          .timeout(5.seconds)
          .compile
          .last
      } yield result
    program.assertEquals(Some(true))
  }

  test("broadcastThrough") {
    val program =
      for {
        finalized <- Deferred[IO, Unit]
        result <- Stream
          .bracket(IO.unit)(_ => finalized.complete(()).void)
          .broadcastThrough(identity[Stream[IO, Unit]])
          .evalMap(_ => IO.sleep(1.second) >> finalized.complete(()))
          .timeout(5.seconds)
          .compile
          .last
      } yield result
    program.assertEquals(Some(true))
  }
}
}
14 changes: 14 additions & 0 deletions core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,20 @@ class StreamCombinatorsSuite extends Fs2Suite {
)
.assertEquals(4.seconds)
}

test("scope propagation") {
  // Verifies that `prefetch` extends the source's scope: the bracket's
  // finaliser completes the Deferred, and the downstream stage tries to
  // complete it too. `Some(true)` means downstream completed it first,
  // i.e. the resource was still alive while its output was being used.
  val program =
    for {
      finalized <- Deferred[IO, Unit]
      result <- Stream
        .bracket(IO.unit)(_ => finalized.complete(()).void)
        .prefetch
        .evalMap(_ => IO.sleep(1.second) >> finalized.complete(()))
        .timeout(5.seconds)
        .compile
        .last
    } yield result
  program.assertEquals(Some(true))
}
}

test("range") {
Expand Down
Loading