Skip to content

Commit 9813a33

Browse files
committed
+str Add flatmapConcat with parallelism.
1 parent 45e73c3 commit 9813a33

File tree

11 files changed

+473
-11
lines changed

11 files changed

+473
-11
lines changed

akka-bench-jmh/src/main/scala/akka/stream/FlatMapConcatBenchmark.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import java.util.concurrent.TimeUnit
99

1010
import scala.concurrent.Await
1111
import scala.concurrent.duration._
12+
import scala.concurrent.Future
1213

1314
import com.typesafe.config.ConfigFactory
1415
import org.openjdk.jmh.annotations._
@@ -88,6 +89,18 @@ class FlatMapConcatBenchmark {
8889
awaitLatch(latch)
8990
}
9091

92+
@Benchmark
93+
@OperationsPerInvocation(OperationsPerInvocation)
94+
def completedFuture(): Unit = {
95+
val latch = new CountDownLatch(1)
96+
97+
testSource
98+
.flatMapConcat(n => Source.future(Future.successful(n)))
99+
.runWith(new LatchSink(OperationsPerInvocation, latch))
100+
101+
awaitLatch(latch)
102+
}
103+
91104
@Benchmark
92105
@OperationsPerInvocation(OperationsPerInvocation)
93106
def mapBaseline(): Unit = {
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright (C) 2014-2023 Lightbend Inc. <https://www.lightbend.com>
3+
*/
4+
5+
package akka.stream.scaladsl
6+
7+
import akka.stream.OverflowStrategy
8+
import akka.stream.testkit._
9+
import akka.stream.testkit.scaladsl.TestSink
10+
11+
import java.util.concurrent.ThreadLocalRandom
12+
import java.util.concurrent.atomic.AtomicInteger
13+
import scala.concurrent.Future
14+
import scala.concurrent.duration.DurationInt
15+
import scala.util.control.NoStackTrace
16+
17+
class FlowFlatMapConcatSpec extends StreamSpec("""
18+
akka.stream.materializer.initial-input-buffer-size = 2
19+
""") with ScriptedTest {
20+
val toSeq = Flow[Int].grouped(1000).toMat(Sink.head)(Keep.right)
21+
22+
class BoomException extends RuntimeException("BOOM~~") with NoStackTrace
23+
"A flatMapConcat" must {
24+
25+
"work with value presented sources" in {
26+
Source(
27+
List(
28+
Source.empty[Int],
29+
Source.single(1),
30+
Source.empty[Int],
31+
Source(List(2, 3, 4)),
32+
Source.future(Future.successful(5)),
33+
Source.lazyFuture(() => Future.successful(6))))
34+
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
35+
.runWith(toSeq)
36+
.futureValue should ===(1 to 6)
37+
}
38+
39+
"work with value presented failed sources" in {
40+
val ex = new BoomException
41+
Source(
42+
List(
43+
Source.empty[Int],
44+
Source.single(1),
45+
Source.empty[Int],
46+
Source(List(2, 3, 4)),
47+
Source.future(Future.failed(ex)),
48+
Source.lazyFuture(() => Future.successful(5))))
49+
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
50+
.onErrorComplete[BoomException]()
51+
.runWith(toSeq)
52+
.futureValue should ===(1 to 4)
53+
}
54+
55+
"work with value presented sources when demands slow" in {
56+
val prob = Source(
57+
List(Source.empty[Int], Source.single(1), Source(List(2, 3, 4)), Source.lazyFuture(() => Future.successful(5))))
58+
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
59+
.runWith(TestSink())
60+
61+
prob.request(1)
62+
prob.expectNext(1)
63+
prob.expectNoMessage(1.seconds)
64+
prob.request(2)
65+
prob.expectNext(2, 3)
66+
prob.expectNoMessage(1.seconds)
67+
prob.request(2)
68+
prob.expectNext(4, 5)
69+
prob.expectComplete()
70+
}
71+
72+
"can do pre materialization when parallelism > 1" in {
73+
val materializationCounter = new AtomicInteger(0)
74+
val randomParallelism = ThreadLocalRandom.current().nextInt(4, 65)
75+
val prob = Source(1 to (randomParallelism * 3))
76+
.flatMapConcat(
77+
randomParallelism,
78+
value => {
79+
Source
80+
.lazySingle(() => {
81+
materializationCounter.incrementAndGet()
82+
value
83+
})
84+
.buffer(1, overflowStrategy = OverflowStrategy.backpressure)
85+
})
86+
.runWith(TestSink())
87+
88+
expectNoMessage(1.seconds)
89+
materializationCounter.get() shouldBe 0
90+
91+
prob.request(1)
92+
prob.expectNext(1.seconds, 1)
93+
expectNoMessage(1.seconds)
94+
materializationCounter.get() shouldBe (randomParallelism + 1)
95+
materializationCounter.set(0)
96+
97+
prob.request(2)
98+
prob.expectNextN(List(2, 3))
99+
expectNoMessage(1.seconds)
100+
materializationCounter.get() shouldBe 2
101+
materializationCounter.set(0)
102+
103+
prob.request(randomParallelism - 3)
104+
prob.expectNextN(4 to randomParallelism)
105+
expectNoMessage(1.seconds)
106+
materializationCounter.get() shouldBe (randomParallelism - 3)
107+
materializationCounter.set(0)
108+
109+
prob.request(randomParallelism)
110+
prob.expectNextN(randomParallelism + 1 to randomParallelism * 2)
111+
expectNoMessage(1.seconds)
112+
materializationCounter.get() shouldBe randomParallelism
113+
materializationCounter.set(0)
114+
115+
prob.request(randomParallelism)
116+
prob.expectNextN(randomParallelism * 2 + 1 to randomParallelism * 3)
117+
expectNoMessage(1.seconds)
118+
materializationCounter.get() shouldBe 0
119+
prob.expectComplete()
120+
}
121+
122+
}
123+
124+
}

akka-stream/src/main/scala/akka/stream/impl/Stages.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ import akka.stream.Attributes._
7979
val mergePreferred = name("mergePreferred")
8080
val mergePrioritized = name("mergePrioritized")
8181
val flattenMerge = name("flattenMerge")
82+
val flattenConcat = name("flattenConcat")
8283
val recoverWith = name("recoverWith")
8384
val onErrorComplete = name("onErrorComplete")
8485
val broadcast = name("broadcast")

akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ package akka.stream.impl
66

77
import scala.collection.immutable.Map.Map1
88
import scala.language.existentials
9-
109
import akka.annotation.{ DoNotInherit, InternalApi }
1110
import akka.stream._
1211
import akka.stream.impl.StreamLayout.AtomicModule
1312
import akka.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 }
1413
import akka.stream.impl.fusing.GraphStageModule
14+
import akka.stream.impl.fusing.GraphStages.IterableSource
1515
import akka.stream.impl.fusing.GraphStages.SingleSource
1616
import akka.stream.scaladsl.Keep
1717
import akka.util.OptionVal
@@ -371,6 +371,37 @@ import akka.util.unused
371371
}
372372
}
373373

374+
def getValuePresentedSource[A >: Null](graph: Graph[SourceShape[A], _]): OptionVal[Graph[SourceShape[A], _]] = {
375+
def isValuePresentedSource(graph: Graph[SourceShape[_ <: A], _]): Boolean = graph match {
376+
case _: SingleSource[_] | _: IterableSource[_] | EmptySource => true
377+
case _ => false
378+
}
379+
graph match {
380+
case _ if isValuePresentedSource(graph) => OptionVal.Some(graph)
381+
case _ =>
382+
graph.traversalBuilder match {
383+
case l: LinearTraversalBuilder =>
384+
l.pendingBuilder match {
385+
case OptionVal.Some(a: AtomicTraversalBuilder) =>
386+
a.module match {
387+
case m: GraphStageModule[_, _] =>
388+
m.stage match {
389+
case _ if isValuePresentedSource(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) =>
390+
// It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize.
391+
if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync)
392+
OptionVal.Some(m.stage.asInstanceOf[Graph[SourceShape[A], _]])
393+
else OptionVal.None
394+
case _ => OptionVal.None
395+
}
396+
case _ => OptionVal.None
397+
}
398+
case _ => OptionVal.None
399+
}
400+
case _ => OptionVal.None
401+
}
402+
}
403+
}
404+
374405
/**
375406
* Test if a Graph is an empty Source.
376407
* */

akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,7 @@ private[stream] object Collect {
12721272
*/
12731273
@InternalApi private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In => Future[Out])
12741274
extends GraphStage[FlowShape[In, Out]] {
1275+
require(parallelism >= 1, "parallelism should >= 1")
12751276

12761277
import MapAsync._
12771278

0 commit comments

Comments
 (0)