Skip to content

Commit e229f82

Browse files
committed
Abstract scope extension with extendScopeThrough
Make extendScopeTo cancellation safe (see #3474)
1 parent 810af8a commit e229f82

File tree

3 files changed

+106
-86
lines changed

3 files changed

+106
-86
lines changed

core/shared/src/main/scala/fs2/Pull.scala

+1-4
Original file line numberDiff line numberDiff line change
@@ -464,10 +464,7 @@ object Pull extends PullLowPriority {
464464
def extendScopeTo[F[_], O](
465465
s: Stream[F, O]
466466
)(implicit F: MonadError[F, Throwable]): Pull[F, Nothing, Stream[F, O]] =
467-
for {
468-
scope <- Pull.getScope[F]
469-
lease <- Pull.eval(scope.lease)
470-
} yield s.onFinalize(lease.cancel.redeemWith(F.raiseError(_), _ => F.unit))
467+
Pull.getScope[F].map(scope => Stream.bracket(scope.lease)(_.cancel.rethrow) *> s)
471468

472469
/** Repeatedly uses the output of the pull as input for the next step of the
473470
* pull. Halts when a step terminates with `None` or `Pull.raiseError`.

core/shared/src/main/scala/fs2/Stream.scala

+91-82
Original file line numberDiff line numberDiff line change
@@ -238,39 +238,33 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
238238
*/
239239
def broadcastThrough[F2[x] >: F[x]: Concurrent, O2](pipes: Pipe[F2, O, O2]*): Stream[F2, O2] = {
240240
assert(pipes.nonEmpty, s"pipes should not be empty")
241-
underlying.uncons.flatMap {
242-
case Some((hd, tl)) =>
241+
extendScopeThrough { source =>
242+
Stream.force {
243243
for {
244244
// topic: contains the chunk that the pipes are processing at one point.
245245
// until and unless all pipes are finished with it, won't move to next one
246-
topic <- Pull.eval(Topic[F2, Chunk[O]])
246+
topic <- Topic[F2, Chunk[O]]
247247
// Coordination: neither the producer nor any consumer starts
248248
// until and unless all consumers are subscribed to topic.
249-
allReady <- Pull.eval(CountDownLatch[F2](pipes.length))
250-
251-
checkIn = allReady.release >> allReady.await
249+
allReady <- CountDownLatch[F2](pipes.length)
250+
} yield {
251+
val checkIn = allReady.release >> allReady.await
252252

253-
dump = (pipe: Pipe[F2, O, O2]) =>
253+
def dump(pipe: Pipe[F2, O, O2]): Stream[F2, O2] =
254254
Stream.resource(topic.subscribeAwait(1)).flatMap { sub =>
255255
// Wait until all pipes are ready before consuming.
256256
// Crucial: checkin is not passed to the pipe,
257257
// so pipe cannot interrupt it and alter the latch count
258258
Stream.exec(checkIn) ++ pipe(sub.unchunks)
259259
}
260260

261-
dumpAll: Stream[F2, O2] <-
262-
Pull.extendScopeTo(Stream(pipes: _*).map(dump).parJoinUnbounded)
263-
264-
chunksStream = Stream.chunk(hd).append(tl.stream).chunks
265-
261+
val dumpAll: Stream[F2, O2] = Stream(pipes: _*).map(dump).parJoinUnbounded
266262
// Wait until all pipes are checked in before pulling
267-
pump = Stream.exec(allReady.await) ++ topic.publish(chunksStream)
268-
269-
_ <- dumpAll.concurrently(pump).underlying
270-
} yield ()
271-
272-
case None => Pull.done
273-
}.stream
263+
val pump = Stream.exec(allReady.await) ++ topic.publish(source.chunks)
264+
dumpAll.concurrently(pump)
265+
}
266+
}
267+
}
274268
}
275269

276270
/** Behaves like the identity function, but requests `n` elements at a time from the input.
@@ -548,6 +542,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
548542
)(implicit F: Concurrent[F2]): Stream[F2, O] =
549543
concurrentlyAux(that).flatMap { case (startBack, fore) => startBack >> fore }
550544

545+
def concurrentlyExtendingThatScope[F2[x] >: F[x], O2](
546+
that: Stream[F2, O2]
547+
)(implicit F: Concurrent[F2]): Stream[F2, O] =
548+
that.extendScopeThrough(that =>
549+
concurrentlyAux(that).flatMap { case (startBack, fore) => startBack >> fore }
550+
)
551+
551552
private def concurrentlyAux[F2[x] >: F[x], O2](
552553
that: Stream[F2, O2]
553554
)(implicit
@@ -2331,75 +2332,65 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
23312332
channel: F2[Channel[F2, F2[Either[Throwable, O2]]]],
23322333
isOrdered: Boolean,
23332334
f: O => F2[O2]
2334-
)(implicit F: Concurrent[F2]): Stream[F2, O2] = {
2335-
val action =
2336-
(
2337-
Semaphore[F2](concurrency),
2338-
channel,
2339-
Deferred[F2, Unit],
2340-
Deferred[F2, Unit]
2341-
).mapN { (semaphore, channel, stop, end) =>
2342-
def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = {
2343-
def ordered: F2[Either[Throwable, O2] => F2[Unit]] = {
2344-
def send(v: Deferred[F2, Either[Throwable, O2]]) =
2345-
(el: Either[Throwable, O2]) => v.complete(el).void
2346-
2347-
Deferred[F2, Either[Throwable, O2]]
2348-
.flatTap(value => channel.send(release *> value.get))
2349-
.map(send)
2350-
}
2335+
)(implicit F: Concurrent[F2]): Stream[F2, O2] =
2336+
extendScopeThrough { source =>
2337+
Stream.force {
2338+
(
2339+
Semaphore[F2](concurrency),
2340+
channel,
2341+
Deferred[F2, Unit],
2342+
Deferred[F2, Unit]
2343+
).mapN { (semaphore, channel, stop, end) =>
2344+
def initFork(release: F2[Unit]): F2[Either[Throwable, O2] => F2[Unit]] = {
2345+
def ordered: F2[Either[Throwable, O2] => F2[Unit]] = {
2346+
def send(v: Deferred[F2, Either[Throwable, O2]]) =
2347+
(el: Either[Throwable, O2]) => v.complete(el).void
2348+
2349+
Deferred[F2, Either[Throwable, O2]]
2350+
.flatTap(value => channel.send(release *> value.get))
2351+
.map(send)
2352+
}
23512353

2352-
def unordered: Either[Throwable, O2] => F2[Unit] =
2353-
(el: Either[Throwable, O2]) => release <* channel.send(F.pure(el))
2354+
def unordered: Either[Throwable, O2] => F2[Unit] =
2355+
(el: Either[Throwable, O2]) => release <* channel.send(F.pure(el))
23542356

2355-
if (isOrdered) ordered else F.pure(unordered)
2356-
}
2357-
2358-
val releaseAndCheckCompletion =
2359-
semaphore.release *>
2360-
semaphore.available.flatMap {
2361-
case `concurrency` => channel.close *> end.complete(()).void
2362-
case _ => F.unit
2363-
}
2357+
if (isOrdered) ordered else F.pure(unordered)
2358+
}
23642359

2365-
def forkOnElem(el: O): F2[Unit] =
2366-
F.uncancelable { poll =>
2367-
poll(semaphore.acquire) <*
2368-
Deferred[F2, Unit].flatMap { pushed =>
2369-
val init = initFork(pushed.complete(()).void)
2370-
poll(init).onCancel(releaseAndCheckCompletion).flatMap { send =>
2371-
val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get
2372-
F.start(stop.get.race(action) *> releaseAndCheckCompletion)
2373-
}
2360+
val releaseAndCheckCompletion =
2361+
semaphore.release *>
2362+
semaphore.available.flatMap {
2363+
case `concurrency` => channel.close *> end.complete(()).void
2364+
case _ => F.unit
23742365
}
2375-
}
23762366

2377-
underlying.uncons.flatMap {
2378-
case Some((hd, tl)) =>
2379-
for {
2380-
foreground <- Pull.extendScopeTo(
2381-
channel.stream.evalMap(_.rethrow).onFinalize(stop.complete(()) *> end.get)
2382-
)
2383-
background = Stream
2384-
.exec(semaphore.acquire) ++
2385-
Stream
2386-
.chunk(hd)
2387-
.append(tl.stream)
2388-
.interruptWhen(stop.get.map(_.asRight[Throwable]))
2389-
.foreach(forkOnElem)
2390-
.onFinalizeCase {
2391-
case ExitCase.Succeeded => releaseAndCheckCompletion
2392-
case _ => stop.complete(()) *> releaseAndCheckCompletion
2367+
def forkOnElem(el: O): F2[Unit] =
2368+
F.uncancelable { poll =>
2369+
poll(semaphore.acquire) <*
2370+
Deferred[F2, Unit].flatMap { pushed =>
2371+
val init = initFork(pushed.complete(()).void)
2372+
poll(init).onCancel(releaseAndCheckCompletion).flatMap { send =>
2373+
val action = F.catchNonFatal(f(el)).flatten.attempt.flatMap(send) *> pushed.get
2374+
F.start(stop.get.race(action) *> releaseAndCheckCompletion)
23932375
}
2394-
_ <- foreground.concurrently(background).underlying
2395-
} yield ()
2376+
}
2377+
}
23962378

2397-
case None => Pull.done
2398-
}.stream
2399-
}
2379+
val background =
2380+
Stream.exec(semaphore.acquire) ++
2381+
source
2382+
.interruptWhen(stop.get.map(_.asRight[Throwable]))
2383+
.foreach(forkOnElem)
2384+
.onFinalizeCase {
2385+
case ExitCase.Succeeded => releaseAndCheckCompletion
2386+
case _ => stop.complete(()) *> releaseAndCheckCompletion
2387+
}
24002388

2401-
Stream.force(action)
2402-
}
2389+
val foreground = channel.stream.evalMap(_.rethrow)
2390+
foreground.onFinalize(stop.complete(()) *> end.get).concurrently(background)
2391+
}
2392+
}
2393+
}
24032394

24042395
/** Concurrent zip.
24052396
*
@@ -2474,12 +2465,13 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
24742465
*/
24752466
def prefetchN[F2[x] >: F[x]: Concurrent](
24762467
n: Int
2477-
): Stream[F2, O] =
2468+
): Stream[F2, O] = extendScopeThrough { source =>
24782469
Stream.eval(Channel.bounded[F2, Chunk[O]](n)).flatMap { chan =>
24792470
chan.stream.unchunks.concurrently {
2480-
chunks.through(chan.sendAll)
2471+
source.chunks.through(chan.sendAll)
24812472
}
24822473
}
2474+
}
24832475

24842476
/** Prints each element of this stream to standard out, converting each element to a `String` via `Show`. */
24852477
def printlns[F2[x] >: F[x], O2 >: O](implicit
@@ -2940,6 +2932,23 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
29402932
)(f: (Stream[F, O], Stream[F2, O2]) => Stream[F2, O3]): Stream[F2, O3] =
29412933
f(this, s2)
29422934

2935+
/** Transforms this stream, explicitly extending the current scope through the given pipe.
2936+
*
2937+
* Use this when implementing a pipe where the resulting stream is not directly constructed from
2938+
* the source stream, e.g. when sending the source stream through a Channel and returning the
2939+
* channel's stream.
2940+
*/
2941+
def extendScopeThrough[F2[x] >: F[x], O2](
2942+
f: Stream[F, O] => Stream[F2, O2]
2943+
)(implicit F: MonadError[F2, Throwable]): Stream[F2, O2] =
2944+
this.pull.peek
2945+
.flatMap {
2946+
case Some((_, tl)) => Pull.extendScopeTo(f(tl))
2947+
case None => Pull.extendScopeTo(f(Stream.empty))
2948+
}
2949+
.flatMap(_.underlying)
2950+
.stream
2951+
29432952
/** Fails this stream with a `TimeoutException` if it does not complete within given `timeout`. */
29442953
def timeout[F2[x] >: F[x]: Temporal](
29452954
timeout: FiniteDuration

core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala

+14
Original file line numberDiff line numberDiff line change
@@ -1328,6 +1328,20 @@ class StreamCombinatorsSuite extends Fs2Suite {
13281328
)
13291329
.assertEquals(4.seconds)
13301330
}
1331+
1332+
test("scope propagation") {
1333+
Deferred[IO, Unit]
1334+
.flatMap { d =>
1335+
Stream
1336+
.bracket(IO.unit)(_ => d.complete(()).void)
1337+
.prefetch
1338+
.evalMap(_ => IO.sleep(1.second) >> d.complete(()))
1339+
.timeout(5.seconds)
1340+
.compile
1341+
.last
1342+
}
1343+
.assertEquals(Some(true))
1344+
}
13311345
}
13321346

13331347
test("range") {

0 commit comments

Comments
 (0)