Skip to content

Commit 94e7771

Browse files
committed
feat: Add flatMapConcat with parallelism support.
1 parent eb5dc14 commit 94e7771

File tree

9 files changed

+731
-1
lines changed

9 files changed

+731
-1
lines changed

bench-jmh/src/main/scala/org/apache/pekko/stream/FlatMapConcatBenchmark.scala

+81-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ package org.apache.pekko.stream
1616
import java.util.concurrent.CountDownLatch
1717
import java.util.concurrent.TimeUnit
1818

19-
import scala.concurrent.Await
19+
import scala.concurrent.{ Await, Future }
2020
import scala.concurrent.duration._
2121

2222
import com.typesafe.config.ConfigFactory
@@ -76,6 +76,16 @@ class FlatMapConcatBenchmark {
7676
awaitLatch(latch)
7777
}
7878

79+
@Benchmark
80+
@OperationsPerInvocation(OperationsPerInvocation)
81+
def sourceDotSingleP1(): Unit = {
82+
val latch = new CountDownLatch(1)
83+
84+
testSource.flatMapConcat(1, Source.single).runWith(new LatchSink(OperationsPerInvocation, latch))
85+
86+
awaitLatch(latch)
87+
}
88+
7989
@Benchmark
8090
@OperationsPerInvocation(OperationsPerInvocation)
8191
def internalSingleSource(): Unit = {
@@ -88,6 +98,18 @@ class FlatMapConcatBenchmark {
8898
awaitLatch(latch)
8999
}
90100

101+
@Benchmark
102+
@OperationsPerInvocation(OperationsPerInvocation)
103+
def internalSingleSourceP1(): Unit = {
104+
val latch = new CountDownLatch(1)
105+
106+
testSource
107+
.flatMapConcat(1, elem => new GraphStages.SingleSource(elem))
108+
.runWith(new LatchSink(OperationsPerInvocation, latch))
109+
110+
awaitLatch(latch)
111+
}
112+
91113
@Benchmark
92114
@OperationsPerInvocation(OperationsPerInvocation)
93115
def oneElementList(): Unit = {
@@ -98,6 +120,64 @@ class FlatMapConcatBenchmark {
98120
awaitLatch(latch)
99121
}
100122

123+
@Benchmark
124+
@OperationsPerInvocation(OperationsPerInvocation)
125+
def oneElementListP1(): Unit = {
126+
val latch = new CountDownLatch(1)
127+
128+
testSource.flatMapConcat(1, n => Source(n :: Nil)).runWith(new LatchSink(OperationsPerInvocation, latch))
129+
130+
awaitLatch(latch)
131+
}
132+
133+
@Benchmark
134+
@OperationsPerInvocation(OperationsPerInvocation)
135+
def completedFuture(): Unit = {
136+
val latch = new CountDownLatch(1)
137+
138+
testSource
139+
.flatMapConcat(n => Source.future(Future.successful(n)))
140+
.runWith(new LatchSink(OperationsPerInvocation, latch))
141+
142+
awaitLatch(latch)
143+
}
144+
145+
@Benchmark
146+
@OperationsPerInvocation(OperationsPerInvocation)
147+
def completedFutureP1(): Unit = {
148+
val latch = new CountDownLatch(1)
149+
150+
testSource
151+
.flatMapConcat(1, n => Source.future(Future.successful(n)))
152+
.runWith(new LatchSink(OperationsPerInvocation, latch))
153+
154+
awaitLatch(latch)
155+
}
156+
157+
@Benchmark
158+
@OperationsPerInvocation(OperationsPerInvocation)
159+
def normalFuture(): Unit = {
160+
val latch = new CountDownLatch(1)
161+
162+
testSource
163+
.flatMapConcat(n => Source.future(Future(n)(system.dispatcher)))
164+
.runWith(new LatchSink(OperationsPerInvocation, latch))
165+
166+
awaitLatch(latch)
167+
}
168+
169+
@Benchmark
170+
@OperationsPerInvocation(OperationsPerInvocation)
171+
def normalFutureP1(): Unit = {
172+
val latch = new CountDownLatch(1)
173+
174+
testSource
175+
.flatMapConcat(1, n => Source.future(Future(n)(system.dispatcher)))
176+
.runWith(new LatchSink(OperationsPerInvocation, latch))
177+
178+
awaitLatch(latch)
179+
}
180+
101181
@Benchmark
102182
@OperationsPerInvocation(OperationsPerInvocation)
103183
def mapBaseline(): Unit = {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.pekko.stream.scaladsl
19+
20+
21+
import org.apache.pekko
22+
import pekko.pattern.FutureTimeoutSupport
23+
import pekko.NotUsed
24+
import pekko.stream._
25+
import pekko.stream.testkit.{ ScriptedTest, StreamSpec }
26+
import pekko.stream.testkit.scaladsl.TestSink
27+
28+
import java.util.concurrent.ThreadLocalRandom
29+
import java.util.concurrent.atomic.AtomicInteger
30+
import java.util.Collections
31+
32+
import scala.annotation.switch
33+
import scala.concurrent.duration.DurationInt
34+
import scala.concurrent.Future
35+
import scala.util.control.NoStackTrace
36+
37+
class FlowFlatMapConcatParallelismSpec extends StreamSpec("""
38+
pekko.stream.materializer.initial-input-buffer-size = 2
39+
""") with ScriptedTest with FutureTimeoutSupport {
40+
val toSeq = Flow[Int].grouped(1000).toMat(Sink.head)(Keep.right)
41+
42+
class BoomException extends RuntimeException("BOOM~~") with NoStackTrace
43+
"A flatMapConcat" must {
44+
45+
for (i <- 1 until 129) {
46+
s"work with value presented sources with parallelism: $i" in {
47+
Source(
48+
List(
49+
Source.empty[Int],
50+
Source.single(1),
51+
Source.empty[Int],
52+
Source(List(2, 3, 4)),
53+
Source.future(Future.successful(5)),
54+
Source.lazyFuture(() => Future.successful(6)),
55+
Source.future(after(1.millis)(Future.successful(7)))))
56+
.flatMapConcat(i, identity)
57+
.runWith(toSeq)
58+
.futureValue should ===(1 to 7)
59+
}
60+
}
61+
62+
def generateRandomValuePresentedSources(nums: Int): (Int, List[Source[Int, NotUsed]]) = {
63+
val seq = List.tabulate(nums) { _ =>
64+
val random = ThreadLocalRandom.current().nextInt(1, 10)
65+
(random: @switch) match {
66+
case 1 => Source.single(1)
67+
case 2 => Source(List(1))
68+
case 3 => Source.fromJavaStream(() => Collections.singleton(1).stream())
69+
case 4 => Source.future(Future.successful(1))
70+
case 5 => Source.future(after(1.millis)(Future.successful(1)))
71+
case _ => Source.empty[Int]
72+
}
73+
}
74+
val sum = seq.filterNot(_.eq(Source.empty[Int])).size
75+
(sum, seq)
76+
}
77+
78+
def generateSequencedValuePresentedSources(nums: Int): (Int, List[Source[Int, NotUsed]]) = {
79+
val seq = List.tabulate(nums) { index =>
80+
val random = ThreadLocalRandom.current().nextInt(1, 6)
81+
(random: @switch) match {
82+
case 1 => Source.single(index)
83+
case 2 => Source(List(index))
84+
case 3 => Source.fromJavaStream(() => Collections.singleton(index).stream())
85+
case 4 => Source.future(Future.successful(index))
86+
case 5 => Source.future(after(1.millis)(Future.successful(index)))
87+
case _ => throw new IllegalStateException("unexpected")
88+
}
89+
}
90+
val sum = (0 until nums).sum
91+
(sum, seq)
92+
}
93+
94+
for (i <- 1 until 129) {
95+
s"work with generated value presented sources with parallelism: $i " in {
96+
val (sum, sources @ _) = generateRandomValuePresentedSources(100000)
97+
Source(sources)
98+
.flatMapConcat(i, identity(_)) // scala 2.12 can't infer the type of identity
99+
.runWith(Sink.seq)
100+
.map(_.sum)(pekko.dispatch.ExecutionContexts.parasitic)
101+
.futureValue shouldBe sum
102+
}
103+
}
104+
105+
for (i <- 1 until 129) {
106+
s"work with generated value sequenced sources with parallelism: $i " in {
107+
val (sum, sources @ _) = generateSequencedValuePresentedSources(100000)
108+
Source(sources)
109+
.flatMapConcat(i, identity(_)) // scala 2.12 can't infer the type of identity
110+
// check the order
111+
.statefulMap(() => -1)((pre, current) => {
112+
if (pre + 1 != current) {
113+
throw new IllegalStateException(s"expected $pre + 1 == $current")
114+
}
115+
(current, current)
116+
}, _ => None)
117+
.runWith(Sink.seq)
118+
.map(_.sum)(pekko.dispatch.ExecutionContexts.parasitic)
119+
.futureValue shouldBe sum
120+
}
121+
}
122+
123+
"work with value presented failed sources" in {
124+
val ex = new BoomException
125+
Source(
126+
List(
127+
Source.empty[Int],
128+
Source.single(1),
129+
Source.empty[Int],
130+
Source(List(2, 3, 4)),
131+
Source.future(Future.failed(ex)),
132+
Source.lazyFuture(() => Future.successful(5))))
133+
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
134+
.onErrorComplete[BoomException]()
135+
.runWith(toSeq)
136+
.futureValue should ===(1 to 4)
137+
}
138+
139+
"work with value presented sources when demands slow" in {
140+
val prob = Source(
141+
List(Source.empty[Int], Source.single(1), Source(List(2, 3, 4)), Source.lazyFuture(() => Future.successful(5))))
142+
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
143+
.runWith(TestSink())
144+
145+
prob.request(1)
146+
prob.expectNext(1)
147+
prob.expectNoMessage(1.seconds)
148+
prob.request(2)
149+
prob.expectNext(2, 3)
150+
prob.expectNoMessage(1.seconds)
151+
prob.request(2)
152+
prob.expectNext(4, 5)
153+
prob.expectComplete()
154+
}
155+
156+
val parallelism = ThreadLocalRandom.current().nextInt(4, 65)
157+
s"can do pre materialization when parallelism > 1, parallelism is $parallelism" in {
158+
val materializationCounter = new AtomicInteger(0)
159+
val prob = Source(1 to (parallelism * 3))
160+
.flatMapConcat(
161+
parallelism,
162+
value => {
163+
Source
164+
.lazySingle(() => {
165+
materializationCounter.incrementAndGet()
166+
value
167+
})
168+
.buffer(1, overflowStrategy = OverflowStrategy.backpressure)
169+
})
170+
.runWith(TestSink())
171+
172+
expectNoMessage(1.seconds)
173+
materializationCounter.get() shouldBe 0
174+
175+
prob.request(1)
176+
prob.expectNext(1.seconds, 1)
177+
expectNoMessage(1.seconds)
178+
materializationCounter.get() shouldBe (parallelism + 1)
179+
materializationCounter.set(0)
180+
181+
prob.request(2)
182+
prob.expectNextN(List(2, 3))
183+
expectNoMessage(1.seconds)
184+
materializationCounter.get() shouldBe 2
185+
materializationCounter.set(0)
186+
187+
prob.request(parallelism - 3)
188+
prob.expectNextN(4 to parallelism)
189+
expectNoMessage(1.seconds)
190+
materializationCounter.get() shouldBe (parallelism - 3)
191+
materializationCounter.set(0)
192+
193+
prob.request(parallelism)
194+
prob.expectNextN(parallelism + 1 to parallelism * 2)
195+
expectNoMessage(1.seconds)
196+
materializationCounter.get() shouldBe parallelism
197+
materializationCounter.set(0)
198+
199+
prob.request(parallelism)
200+
prob.expectNextN(parallelism * 2 + 1 to parallelism * 3)
201+
expectNoMessage(1.seconds)
202+
materializationCounter.get() shouldBe 0
203+
prob.expectComplete()
204+
}
205+
}
206+
}

stream/src/main/scala/org/apache/pekko/stream/impl/Stages.scala

+1
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ import pekko.stream.Attributes._
9494
val mergePreferred = name("mergePreferred")
9595
val mergePrioritized = name("mergePrioritized")
9696
val flattenMerge = name("flattenMerge")
97+
val flattenConcat = name("flattenConcat")
9798
val recoverWith = name("recoverWith")
9899
val onErrorComplete = name("onErrorComplete")
99100
val broadcast = name("broadcast")

0 commit comments

Comments
 (0)