Skip to content

Commit 78bda86

Browse files
committed
chore: Avoid lift SingleSource to InflightSource.
1 parent 05d5b17 commit 78bda86

File tree

1 file changed

+71
-63
lines changed

1 file changed

+71
-63
lines changed

akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala

Lines changed: 71 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,6 @@ private[akka] object FlattenConcat {
159159
def materialize(): Unit = ()
160160
}
161161

162-
private final class InflightSingleSource[T](elem: T) extends InflightSource[T] {
163-
private var _hasNext = true
164-
override def hasNext: Boolean = _hasNext
165-
override def next(): T =
166-
if (_hasNext) {
167-
_hasNext = false
168-
elem
169-
} else throw new NoSuchElementException("next called after completion")
170-
override def tryPull(): Unit = ()
171-
override def cancel(cause: Throwable): Unit = ()
172-
override def isClosed: Boolean = !hasNext
173-
}
174-
175162
private final class InflightIteratorSource[T](iterator: Iterator[T]) extends InflightSource[T] {
176163
override def hasNext: Boolean = iterator.hasNext
177164
override def next(): T = iterator.next()
@@ -232,16 +219,16 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
232219

233220
override def initialAttributes: Attributes = DefaultAttributes.flattenConcat
234221
override val shape: FlowShape[Graph[SourceShape[T], M], T] = FlowShape(in, out)
235-
override def createLogic(enclosingAttributes: Attributes) =
236-
new GraphStageLogic(shape) with InHandler with OutHandler {
222+
override def createLogic(enclosingAttributes: Attributes) = {
223+
final object FlattenConcatLogic extends GraphStageLogic(shape) with InHandler with OutHandler {
237224
import FlattenConcat._
238-
private var queue: BufferImpl[InflightSource[T]] = _
225+
// InflightSource[T] or SingleSource[T]
226+
// AnyRef here to avoid lift the SingleSource[T] to InflightSource[T]
227+
private var queue: BufferImpl[AnyRef] = _
239228
private val invokeCb: InflightSource[T] => Unit =
240229
getAsyncCallback[InflightSource[T]](futureSourceCompleted).invoke
241230

242-
override def preStart(): Unit = {
243-
queue = BufferImpl(parallelism, enclosingAttributes)
244-
}
231+
override def preStart(): Unit = queue = BufferImpl(parallelism, enclosingAttributes)
245232

246233
private def futureSourceCompleted(futureSource: InflightSource[T]): Unit = {
247234
if (queue.peek() eq futureSource) {
@@ -257,8 +244,7 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
257244
}
258245

259246
override def onPush(): Unit = {
260-
val source = grab(in)
261-
addSource(source)
247+
addSource(grab(in))
262248
//must try pull after addSource to avoid queue overflow
263249
if (!queue.isFull) { // try to keep the maximum parallelism
264250
tryPull(in)
@@ -273,23 +259,32 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
273259
}
274260

275261
override def onPull(): Unit = {
276-
if (queue.nonEmpty) {
277-
val currentSource = queue.peek()
278-
//purge if possible
279-
if (currentSource.hasNext) {
280-
push(out, currentSource.next())
281-
if (currentSource.isClosed) {
282-
handleCurrentSourceClosed(currentSource)
262+
//purge if possible
263+
queue.peek() match {
264+
case src: SingleSource[T] @unchecked =>
265+
push(out, src.elem)
266+
removeSource()
267+
case src: InflightSource[T] @unchecked => pushOut(src)
268+
case null => //queue is empty
269+
if (!hasBeenPulled(in)) {
270+
tryPull(in)
271+
} else if (isClosed(in)) {
272+
completeStage()
283273
}
284-
} else if (currentSource.isClosed) {
285-
handleCurrentSourceClosed(currentSource)
286-
} else {
287-
currentSource.tryPull()
274+
case _ => throw new IllegalStateException("Should not reach here.")
275+
}
276+
}
277+
278+
private def pushOut(src: InflightSource[T]): Unit = {
279+
if (src.hasNext) {
280+
push(out, src.next())
281+
if (src.isClosed) {
282+
handleCurrentSourceClosed(src)
288283
}
289-
} else if (!hasBeenPulled(in)) {
290-
tryPull(in)
291-
} else if (isClosed(in)) {
292-
completeStage()
284+
} else if (src.isClosed) {
285+
handleCurrentSourceClosed(src)
286+
} else {
287+
src.tryPull()
293288
}
294289
}
295290

@@ -308,18 +303,18 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
308303
private def cancelInflightSources(cause: Throwable): Unit = {
309304
if (queue.nonEmpty) {
310305
var source = queue.dequeue()
311-
while (source ne null) {
312-
source.cancel(cause)
306+
while ((source ne null) && (source.isInstanceOf[InflightSource[T] @unchecked])) {
307+
source.asInstanceOf[InflightSource[T]].cancel(cause)
313308
source = queue.dequeue()
314309
}
315310
}
316311
}
317312

318-
private def addSourceElem(elem: T): Unit = {
313+
private def addSource(singleSource: SingleSource[T]): Unit = {
319314
if (isAvailable(out) && queue.isEmpty) {
320-
push(out, elem)
315+
push(out, singleSource.elem)
321316
} else {
322-
queue.enqueue(new InflightSingleSource(elem))
317+
queue.enqueue(singleSource)
323318
}
324319
}
325320

@@ -397,7 +392,7 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
397392
(TraversalBuilder.getValuePresentedSource(source): @nowarn("cat=lint-infer-any")) match {
398393
case OptionVal.Some(graph) =>
399394
graph match {
400-
case single: SingleSource[T] @unchecked => addSourceElem(single.elem)
395+
case single: SingleSource[T] @unchecked => addSource(single)
401396
case futureSource: FutureSource[T] @unchecked =>
402397
val future = futureSource.future
403398
future.value match {
@@ -415,37 +410,50 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
415410

416411
}
417412

413+
private def removeSource(): Unit = {
414+
queue.dequeue()
415+
pullIfNeeded()
416+
}
417+
418418
private def removeSource(source: InflightSource[T]): Unit = {
419419
if (queue.nonEmpty && (source eq queue.peek())) {
420420
//only dequeue if it's the current emitting source
421-
val s = queue.dequeue()
422-
if (s != source) {
423-
throw new IllegalStateException("Should not reach here.")
424-
}
425-
if (isClosed(in)) {
426-
if (queue.isEmpty) {
427-
completeStage()
428-
} else {
429-
//pull the new emitting source
430-
val nextSource = queue.peek()
431-
nextSource.tryPull()
432-
}
421+
queue.dequeue()
422+
pullIfNeeded()
423+
} //not the head source, just ignore
424+
}
425+
426+
private def pullIfNeeded(): Unit = {
427+
if (isClosed(in)) {
428+
if (queue.isEmpty) {
429+
completeStage()
433430
} else {
434-
if (queue.nonEmpty) {
435-
//pull the new emitting source
436-
val nextSource = queue.peek()
437-
nextSource.tryPull()
438-
}
439-
if (!hasBeenPulled(in)) {
440-
tryPull(in)
441-
}
431+
tryPullNextSourceInQueue()
442432
}
443-
} //not the head source, just ignore
433+
} else {
434+
if (queue.nonEmpty) {
435+
tryPullNextSourceInQueue()
436+
}
437+
if (!hasBeenPulled(in)) {
438+
tryPull(in)
439+
}
440+
}
441+
}
442+
443+
private def tryPullNextSourceInQueue(): Unit = {
444+
//pull the new emitting source
445+
val nextSource = queue.peek()
446+
if (nextSource.isInstanceOf[InflightSource[T] @unchecked]) {
447+
nextSource.asInstanceOf[InflightSource[T]].tryPull()
448+
}
444449
}
445450

446451
setHandlers(in, out, this)
447452
}
448453

454+
FlattenConcatLogic
455+
}
456+
449457
override def toString: String = s"FlattenConcat($parallelism)"
450458
}
451459

0 commit comments

Comments
 (0)