Skip to content

Commit 9ed93c5

Browse files
committed
feat: Add flatMapConcat with parallelism.
1 parent 27eee92 commit 9ed93c5

File tree

13 files changed

+727
-24
lines changed

13 files changed

+727
-24
lines changed

akka-bench-jmh/src/main/scala/akka/stream/FlatMapConcatBenchmark.scala

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@ package akka.stream
66

77
import java.util.concurrent.CountDownLatch
88
import java.util.concurrent.TimeUnit
9-
10-
import scala.concurrent.Await
9+
import scala.concurrent.{ Await, Future }
1110
import scala.concurrent.duration._
12-
1311
import com.typesafe.config.ConfigFactory
1412
import org.openjdk.jmh.annotations._
1513

@@ -60,31 +58,91 @@ class FlatMapConcatBenchmark {
6058
@OperationsPerInvocation(OperationsPerInvocation)
6159
def sourceDotSingle(): Unit = {
  // flatMapConcat without a parallelism argument: every element becomes a Source.single.
  val done = new CountDownLatch(1)
  val flattened = testSource.flatMapConcat(Source.single)
  // LatchSink receives the latch; presumably it releases it after OperationsPerInvocation elements.
  flattened.runWith(new LatchSink(OperationsPerInvocation, done))
  awaitLatch(done)
}
6564

65+
@Benchmark
66+
@OperationsPerInvocation(OperationsPerInvocation)
67+
def sourceDotSingleP1(): Unit = {
  // Same pipeline as sourceDotSingle, but through the overload that takes a parallelism (here 1).
  val done = new CountDownLatch(1)
  val flattened = testSource.flatMapConcat(1, Source.single)
  flattened.runWith(new LatchSink(OperationsPerInvocation, done))
  awaitLatch(done)
}
6872

6973
@Benchmark
7074
@OperationsPerInvocation(OperationsPerInvocation)
7175
def internalSingleSource(): Unit = {
  // Each element is wrapped directly in the GraphStages.SingleSource stage.
  val done = new CountDownLatch(1)
  val flattened = testSource.flatMapConcat(elem => new GraphStages.SingleSource(elem))
  flattened.runWith(new LatchSink(OperationsPerInvocation, done))
  awaitLatch(done)
}
7782

83+
@Benchmark
84+
@OperationsPerInvocation(OperationsPerInvocation)
85+
def internalSingleSourceP1(): Unit = {
  // Same pipeline as internalSingleSource, but through the parallelism overload with parallelism 1.
  val done = new CountDownLatch(1)
  val flattened = testSource.flatMapConcat(1, elem => new GraphStages.SingleSource(elem))
  flattened.runWith(new LatchSink(OperationsPerInvocation, done))
  awaitLatch(done)
}
8092

8193
@Benchmark
8294
@OperationsPerInvocation(OperationsPerInvocation)
8395
def oneElementList(): Unit = {
  // Each element becomes a Source built from a one-element List.
  val done = new CountDownLatch(1)
  val flattened = testSource.flatMapConcat(elem => Source(elem :: Nil))
  flattened.runWith(new LatchSink(OperationsPerInvocation, done))
  awaitLatch(done)
}
100+
101+
@Benchmark
102+
@OperationsPerInvocation(OperationsPerInvocation)
103+
def oneElementListP1(): Unit = {
  // Same pipeline as oneElementList, but through the parallelism overload with parallelism 1.
  val done = new CountDownLatch(1)
  val flattened = testSource.flatMapConcat(1, elem => Source(elem :: Nil))
  flattened.runWith(new LatchSink(OperationsPerInvocation, done))
  awaitLatch(done)
}
108+
109+
@Benchmark
110+
@OperationsPerInvocation(OperationsPerInvocation)
111+
def completedFuture(): Unit = {
  // Each element becomes a Source.future over an already-completed Future.
  val done = new CountDownLatch(1)
  val flattened = testSource.flatMapConcat(elem => Source.future(Future.successful(elem)))
  flattened.runWith(new LatchSink(OperationsPerInvocation, done))
  awaitLatch(done)
}
118+
119+
@Benchmark
120+
@OperationsPerInvocation(OperationsPerInvocation)
121+
def completedFutureP1(): Unit = {
  // Same pipeline as completedFuture, but through the parallelism overload with parallelism 1.
  val done = new CountDownLatch(1)
  val flattened = testSource.flatMapConcat(1, elem => Source.future(Future.successful(elem)))
  flattened.runWith(new LatchSink(OperationsPerInvocation, done))
  awaitLatch(done)
}
87128

129+
@Benchmark
130+
@OperationsPerInvocation(OperationsPerInvocation)
131+
def normalFuture(): Unit = {
  // Each element becomes a Source.future over a Future scheduled on system.dispatcher.
  val done = new CountDownLatch(1)
  val flattened = testSource.flatMapConcat(elem => Source.future(Future(elem)(system.dispatcher)))
  flattened.runWith(new LatchSink(OperationsPerInvocation, done))
  awaitLatch(done)
}
138+
139+
@Benchmark
140+
@OperationsPerInvocation(OperationsPerInvocation)
141+
def normalFutureP1(): Unit = {
  // Same pipeline as normalFuture, but through the parallelism overload with parallelism 1.
  val done = new CountDownLatch(1)
  val flattened = testSource.flatMapConcat(1, elem => Source.future(Future(elem)(system.dispatcher)))
  flattened.runWith(new LatchSink(OperationsPerInvocation, done))
  awaitLatch(done)
}
90148

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
/*
2+
* Copyright (C) 2014-2024 Lightbend Inc. <https://www.lightbend.com>
3+
*/
4+
5+
package akka.stream.scaladsl
6+
7+
import akka.NotUsed
8+
import akka.pattern.FutureTimeoutSupport
9+
import akka.stream.OverflowStrategy
10+
import akka.stream.testkit._
11+
import akka.stream.testkit.scaladsl.TestSink
12+
13+
import java.util.concurrent.ThreadLocalRandom
14+
import java.util.concurrent.atomic.AtomicInteger
15+
import scala.annotation.switch
16+
import scala.concurrent.{ ExecutionContext, Future }
17+
import scala.concurrent.duration.DurationInt
18+
import scala.util.control.NoStackTrace
19+
20+
/**
 * Tests for the `flatMapConcat(parallelism, f)` overload.
 *
 * The cases cover: element order and totals over mixes of empty, single-element, list, Java-stream,
 * future-backed and failed sub-sources; behaviour when downstream demand arrives slowly; and how many
 * sub-sources have been started at each step of demand.
 */
class FlowFlatMapConcatParallelismSpec extends StreamSpec("""
    akka.stream.materializer.initial-input-buffer-size = 2
  """) with ScriptedTest with FutureTimeoutSupport {
  // Groups up to 1000 elements and materializes the first group as a Future.
  val toSeq = Flow[Int].grouped(1000).toMat(Sink.head)(Keep.right)

  // Failure injected by the failed-source test; NoStackTrace suppresses the stack trace.
  class BoomException extends RuntimeException("BOOM~~") with NoStackTrace

  "A flatMapConcat" must {

    for (i <- 1 until 129) {
      s"work with value presented sources with parallelism: $i" in {
        Source(
          List(
            Source.empty[Int],
            Source.single(1),
            Source.empty[Int],
            Source(List(2, 3, 4)),
            Source.future(Future.successful(5)),
            Source.lazyFuture(() => Future.successful(6)),
            Source.future(after(1.millis)(Future.successful(7)))))
          .flatMapConcat(i, identity)
          .runWith(toSeq)
          .futureValue should ===(1 to 7)
      }
    }

    /**
     * Builds `nums` randomly chosen sources, each of which is either empty or emits the single value 1.
     *
     * @return the number of non-empty sources (which equals the expected sum of all emitted elements,
     *         at most `nums`, so an Int is sufficient) together with the sources
     */
    def generateRandomValuePresentedSources(nums: Int): (Int, Seq[Source[Int, NotUsed]]) = {
      val seq = Seq.tabulate(nums) { _ =>
        // Draws 1..9; values 6..9 fall through to the empty source.
        val random = ThreadLocalRandom.current().nextInt(1, 10)
        (random: @switch) match {
          case 1 => Source.single(1)
          case 2 => Source(List(1))
          case 3 => Source.fromJavaStream(() => java.util.stream.Stream.of(1))
          case 4 => Source.future(Future.successful(1))
          case 5 => Source.future(after(1.millis)(Future.successful(1)))
          case _ => Source.empty[Int]
        }
      }
      // Empty sources are recognised by reference equality with Source.empty.
      val sum = seq.filterNot(_.eq(Source.empty[Int])).size
      (sum, seq)
    }

    /**
     * Builds `nums` sources where the source at position `index` emits exactly `index`, using a randomly
     * chosen source kind for each position.
     *
     * @return the expected sum of all emitted elements and the sources. The sum is a Long because
     *         0 + 1 + ... + (nums - 1) exceeds Int.MaxValue for the sizes used below (nums = 100000).
     */
    def generateSequencedValuePresentedSources(nums: Int): (Long, Seq[Source[Int, NotUsed]]) = {
      val seq = Seq.tabulate(nums) { index =>
        val random = ThreadLocalRandom.current().nextInt(1, 6)
        (random: @switch) match {
          case 1 => Source.single(index)
          case 2 => Source(List(index))
          case 3 => Source.fromJavaStream(() => java.util.stream.Stream.of(index))
          case 4 => Source.future(Future.successful(index))
          case 5 => Source.future(after(1.millis)(Future.successful(index)))
          case _ => throw new IllegalStateException("unexpected")
        }
      }
      // Sum of 0 until nums, computed in 64 bits so that it cannot wrap around.
      val sum = nums.toLong * (nums - 1) / 2
      (sum, seq)
    }

    for (i <- 1 until 129) {
      s"work with generated value presented sources with parallelism: $i " in {
        val (sum, sources) = generateRandomValuePresentedSources(100000)
        Source(sources)
          .flatMapConcat(i, identity)
          .runWith(Sink.seq)
          .map(_.sum)(ExecutionContext.parasitic)
          .futureValue shouldBe sum
      }
    }

    for (i <- 1 until 129) {
      s"work with generated value sequenced sources with parallelism: $i " in {
        val (sum, sources) = generateSequencedValuePresentedSources(100000)
        Source(sources)
          .flatMapConcat(i, identity)
          // Check the order: every element must be exactly one greater than its predecessor.
          .statefulMap(() => -1)((pre, current) => {
            if (pre + 1 != current) {
              throw new IllegalStateException(s"expected $pre + 1 == $current")
            }
            (current, current)
          }, _ => None)
          .runWith(Sink.seq)
          // Accumulate as Long to match the expected sum, which does not fit in an Int.
          .map(_.foldLeft(0L)(_ + _))(ExecutionContext.parasitic)
          .futureValue shouldBe sum
      }
    }

    "work with value presented failed sources" in {
      val ex = new BoomException
      Source(
        List(
          Source.empty[Int],
          Source.single(1),
          Source.empty[Int],
          Source(List(2, 3, 4)),
          Source.future(Future.failed(ex)),
          Source.lazyFuture(() => Future.successful(5))))
        .flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
        // The failure is turned into completion, so only the elements before it are expected.
        .onErrorComplete[BoomException]()
        .runWith(toSeq)
        .futureValue should ===(1 to 4)
    }

    "work with value presented sources when demands slow" in {
      val prob = Source(
        List(Source.empty[Int], Source.single(1), Source(List(2, 3, 4)), Source.lazyFuture(() => Future.successful(5))))
        .flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
        .runWith(TestSink())

      // Elements must only arrive in response to demand, never ahead of it.
      prob.request(1)
      prob.expectNext(1)
      prob.expectNoMessage(1.seconds)
      prob.request(2)
      prob.expectNext(2, 3)
      prob.expectNoMessage(1.seconds)
      prob.request(2)
      prob.expectNext(4, 5)
      prob.expectComplete()
    }

    "can do pre materialization when parallelism > 1" in {
      // Incremented each time a lazySingle callback below runs.
      val materializationCounter = new AtomicInteger(0)
      val randomParallelism = ThreadLocalRandom.current().nextInt(4, 65)
      val prob = Source(1 to (randomParallelism * 3))
        .flatMapConcat(
          randomParallelism,
          value => {
            Source
              .lazySingle(() => {
                materializationCounter.incrementAndGet()
                value
              })
              .buffer(1, overflowStrategy = OverflowStrategy.backpressure)
          })
        .runWith(TestSink())

      // NOTE(review): the unqualified expectNoMessage calls below are made on the spec itself, not on
      // `prob`; presumably they serve as a fixed wait before reading the counter - confirm.
      expectNoMessage(1.seconds)
      materializationCounter.get() shouldBe 0

      prob.request(1)
      prob.expectNext(1.seconds, 1)
      expectNoMessage(1.seconds)
      materializationCounter.get() shouldBe (randomParallelism + 1)
      materializationCounter.set(0)

      prob.request(2)
      prob.expectNextN(List(2, 3))
      expectNoMessage(1.seconds)
      materializationCounter.get() shouldBe 2
      materializationCounter.set(0)

      prob.request(randomParallelism - 3)
      prob.expectNextN(4 to randomParallelism)
      expectNoMessage(1.seconds)
      materializationCounter.get() shouldBe (randomParallelism - 3)
      materializationCounter.set(0)

      prob.request(randomParallelism)
      prob.expectNextN(randomParallelism + 1 to randomParallelism * 2)
      expectNoMessage(1.seconds)
      materializationCounter.get() shouldBe randomParallelism
      materializationCounter.set(0)

      prob.request(randomParallelism)
      prob.expectNextN(randomParallelism * 2 + 1 to randomParallelism * 3)
      expectNoMessage(1.seconds)
      materializationCounter.get() shouldBe 0
      prob.expectComplete()
    }

  }

}

akka-stream/src/main/scala/akka/stream/impl/FailedSource.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import akka.stream.stage.{ GraphStage, GraphStageLogic, OutHandler }
1212
/**
1313
* INTERNAL API
1414
*/
15-
@InternalApi private[akka] final class FailedSource[T](failure: Throwable) extends GraphStage[SourceShape[T]] {
15+
@InternalApi private[akka] final class FailedSource[T](val failure: Throwable) extends GraphStage[SourceShape[T]] {
1616
val out = Outlet[T]("FailedSource.out")
1717
override val shape = SourceShape(out)
1818

akka-stream/src/main/scala/akka/stream/impl/JavaStreamSource.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import java.util.function.Consumer
1212

1313
/** INTERNAL API */
1414
@InternalApi private[stream] final class JavaStreamSource[T, S <: java.util.stream.BaseStream[T, S]](
15-
open: () => java.util.stream.BaseStream[T, S])
15+
val open: () => java.util.stream.BaseStream[T, S])
1616
extends GraphStage[SourceShape[T]] {
1717

1818
val out: Outlet[T] = Outlet("JavaStreamSource")

akka-stream/src/main/scala/akka/stream/impl/Stages.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ import akka.stream.Attributes._
7979
val mergePreferred = name("mergePreferred")
8080
val mergePrioritized = name("mergePrioritized")
8181
val flattenMerge = name("flattenMerge")
82+
val flattenConcat = name("flattenConcat")
8283
val recoverWith = name("recoverWith")
8384
val onErrorComplete = name("onErrorComplete")
8485
val broadcast = name("broadcast")

akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@ package akka.stream.impl
66

77
import scala.collection.immutable.Map.Map1
88
import scala.language.existentials
9-
109
import akka.annotation.{ DoNotInherit, InternalApi }
1110
import akka.stream._
1211
import akka.stream.impl.StreamLayout.AtomicModule
1312
import akka.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 }
1413
import akka.stream.impl.fusing.GraphStageModule
15-
import akka.stream.impl.fusing.GraphStages.SingleSource
14+
import akka.stream.impl.fusing.GraphStages.{ FutureSource, IterableSource, SingleSource }
1615
import akka.stream.scaladsl.Keep
1716
import akka.util.OptionVal
1817

@@ -369,12 +368,51 @@ import akka.util.OptionVal
369368
}
370369
}
371370

371+
/**
 * Try to find a "value presented" source: a `SingleSource`, `FutureSource`, `IterableSource`,
 * `JavaStreamSource`, `FailedSource` or an empty source, either given directly or wrapped as the
 * only stage of a linear graph. This is used as a performance optimization in FlattenConcat and
 * possibly other places.
 *
 * @param graph the source graph to inspect
 * @return `OptionVal.Some` holding the graph itself or the stage it wraps when one of the kinds
 *         above is found, otherwise `OptionVal.None`
 */
def getValuePresentedSource[A >: Null](graph: Graph[SourceShape[A], _]): OptionVal[Graph[SourceShape[A], _]] = {
  def isValuePresentedSource(candidate: Graph[SourceShape[_ <: A], _]): Boolean = candidate match {
    case _: SingleSource[_] | _: FutureSource[_] | _: IterableSource[_] | _: JavaStreamSource[_, _] |
        _: FailedSource[_] =>
      true
    case other => isEmptySource(other)
  }

  if (isValuePresentedSource(graph)) OptionVal.Some(graph)
  else
    graph.traversalBuilder match {
      case linear: LinearTraversalBuilder =>
        linear.pendingBuilder match {
          case OptionVal.Some(atomic: AtomicTraversalBuilder) =>
            atomic.module match {
              case stageModule: GraphStageModule[_, _] =>
                val stage = stageModule.stage.asInstanceOf[Graph[SourceShape[A], _]]
                // traversalSoFar would be != EmptyTraversal if mapMaterializedValue was used, and an
                // async boundary must be kept as well; in both cases we can't optimize.
                val unwrappable =
                  isValuePresentedSource(stage) &&
                  (linear.traversalSoFar eq EmptyTraversal) &&
                  !linear.attributes.isAsync
                if (unwrappable) OptionVal.Some(stage) else OptionVal.None
              case _ => OptionVal.None
            }
          case _ => OptionVal.None
        }
      case _ => OptionVal.None
    }
}
408+
372409
/**
 * Test if a Graph is an empty Source, i.e. the shared `scaladsl.Source.empty` instance, the shared
 * `javadsl.Source.empty()` instance, or the `EmptySource` stage itself.
 */
def isEmptySource(graph: Graph[SourceShape[_], _]): Boolean = graph match {
  case scalaSource: scaladsl.Source[_, _] => scalaSource eq scaladsl.Source.empty
  case javaSource: javadsl.Source[_, _]   => javaSource eq javadsl.Source.empty()
  case other                              => EmptySource == other
}
380418

akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,7 @@ private[stream] object Collect {
12681268
*/
12691269
@InternalApi private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In => Future[Out])
12701270
extends GraphStage[FlowShape[In, Out]] {
1271+
require(parallelism >= 1, "parallelism should >= 1")
12711272

12721273
import MapAsync._
12731274

0 commit comments

Comments
 (0)