@@ -159,19 +159,6 @@ private[akka] object FlattenConcat {
159159 def materialize (): Unit = ()
160160 }
161161
162- private final class InflightSingleSource [T ](elem : T ) extends InflightSource [T ] {
163- private var _hasNext = true
164- override def hasNext : Boolean = _hasNext
165- override def next (): T =
166- if (_hasNext) {
167- _hasNext = false
168- elem
169- } else throw new NoSuchElementException (" next called after completion" )
170- override def tryPull (): Unit = ()
171- override def cancel (cause : Throwable ): Unit = ()
172- override def isClosed : Boolean = ! hasNext
173- }
174-
175162 private final class InflightIteratorSource [T ](iterator : Iterator [T ]) extends InflightSource [T ] {
176163 override def hasNext : Boolean = iterator.hasNext
177164 override def next (): T = iterator.next()
@@ -232,16 +219,16 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
232219
233220 override def initialAttributes : Attributes = DefaultAttributes .flattenConcat
234221 override val shape : FlowShape [Graph [SourceShape [T ], M ], T ] = FlowShape (in, out)
235- override def createLogic (enclosingAttributes : Attributes ) =
236- new GraphStageLogic (shape) with InHandler with OutHandler {
222+ override def createLogic (enclosingAttributes : Attributes ) = {
223+ final object FlattenConcatLogic extends GraphStageLogic (shape) with InHandler with OutHandler {
237224 import FlattenConcat ._
238- private var queue : BufferImpl [InflightSource [T ]] = _
225+ // InflightSource[T] or SingleSource[T]
226+ // AnyRef here to avoid lift the SingleSource[T] to InflightSource[T]
227+ private var queue : BufferImpl [AnyRef ] = _
239228 private val invokeCb : InflightSource [T ] => Unit =
240229 getAsyncCallback[InflightSource [T ]](futureSourceCompleted).invoke
241230
242- override def preStart (): Unit = {
243- queue = BufferImpl (parallelism, enclosingAttributes)
244- }
231+ override def preStart (): Unit = queue = BufferImpl (parallelism, enclosingAttributes)
245232
246233 private def futureSourceCompleted (futureSource : InflightSource [T ]): Unit = {
247234 if (queue.peek() eq futureSource) {
@@ -257,8 +244,7 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
257244 }
258245
259246 override def onPush (): Unit = {
260- val source = grab(in)
261- addSource(source)
247+ addSource(grab(in))
262248 // must try pull after addSource to avoid queue overflow
263249 if (! queue.isFull) { // try to keep the maximum parallelism
264250 tryPull(in)
@@ -273,23 +259,32 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
273259 }
274260
275261 override def onPull (): Unit = {
276- if (queue.nonEmpty) {
277- val currentSource = queue.peek()
278- // purge if possible
279- if (currentSource.hasNext) {
280- push(out, currentSource.next())
281- if (currentSource.isClosed) {
282- handleCurrentSourceClosed(currentSource)
262+ // purge if possible
263+ queue.peek() match {
264+ case src : SingleSource [T ] @ unchecked =>
265+ push(out, src.elem)
266+ removeSource()
267+ case src : InflightSource [T ] @ unchecked => pushOut(src)
268+ case null => // queue is empty
269+ if (! hasBeenPulled(in)) {
270+ tryPull(in)
271+ } else if (isClosed(in)) {
272+ completeStage()
283273 }
284- } else if (currentSource.isClosed) {
285- handleCurrentSourceClosed(currentSource)
286- } else {
287- currentSource.tryPull()
274+ case _ => throw new IllegalStateException (" Should not reach here." )
275+ }
276+ }
277+
278+ private def pushOut (src : InflightSource [T ]): Unit = {
279+ if (src.hasNext) {
280+ push(out, src.next())
281+ if (src.isClosed) {
282+ handleCurrentSourceClosed(src)
288283 }
289- } else if (! hasBeenPulled(in) ) {
290- tryPull(in )
291- } else if (isClosed(in)) {
292- completeStage ()
284+ } else if (src.isClosed ) {
285+ handleCurrentSourceClosed(src )
286+ } else {
287+ src.tryPull ()
293288 }
294289 }
295290
@@ -308,18 +303,18 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
308303 private def cancelInflightSources (cause : Throwable ): Unit = {
309304 if (queue.nonEmpty) {
310305 var source = queue.dequeue()
311- while (source ne null ) {
312- source.cancel(cause)
306+ while (( source ne null ) && (source. isInstanceOf [ InflightSource [ T ] @ unchecked]) ) {
307+ source.asInstanceOf [ InflightSource [ T ]]. cancel(cause)
313308 source = queue.dequeue()
314309 }
315310 }
316311 }
317312
318- private def addSourceElem ( elem : T ): Unit = {
313+ private def addSource ( singleSource : SingleSource [ T ] ): Unit = {
319314 if (isAvailable(out) && queue.isEmpty) {
320- push(out, elem)
315+ push(out, singleSource. elem)
321316 } else {
322- queue.enqueue(new InflightSingleSource (elem) )
317+ queue.enqueue(singleSource )
323318 }
324319 }
325320
@@ -397,7 +392,7 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
397392 (TraversalBuilder .getValuePresentedSource(source): @ nowarn(" cat=lint-infer-any" )) match {
398393 case OptionVal .Some (graph) =>
399394 graph match {
400- case single : SingleSource [T ] @ unchecked => addSourceElem (single.elem )
395+ case single : SingleSource [T ] @ unchecked => addSource (single)
401396 case futureSource : FutureSource [T ] @ unchecked =>
402397 val future = futureSource.future
403398 future.value match {
@@ -415,37 +410,50 @@ private[akka] final class FlattenConcat[T, M](parallelism: Int)
415410
416411 }
417412
413+ private def removeSource (): Unit = {
414+ queue.dequeue()
415+ pullIfNeeded()
416+ }
417+
418418 private def removeSource (source : InflightSource [T ]): Unit = {
419419 if (queue.nonEmpty && (source eq queue.peek())) {
420420 // only dequeue if it's the current emitting source
421- val s = queue.dequeue()
422- if (s != source) {
423- throw new IllegalStateException (" Should not reach here." )
424- }
425- if (isClosed(in)) {
426- if (queue.isEmpty) {
427- completeStage()
428- } else {
429- // pull the new emitting source
430- val nextSource = queue.peek()
431- nextSource.tryPull()
432- }
421+ queue.dequeue()
422+ pullIfNeeded()
423+ } // not the head source, just ignore
424+ }
425+
426+ private def pullIfNeeded (): Unit = {
427+ if (isClosed(in)) {
428+ if (queue.isEmpty) {
429+ completeStage()
433430 } else {
434- if (queue.nonEmpty) {
435- // pull the new emitting source
436- val nextSource = queue.peek()
437- nextSource.tryPull()
438- }
439- if (! hasBeenPulled(in)) {
440- tryPull(in)
441- }
431+ tryPullNextSourceInQueue()
442432 }
443- } // not the head source, just ignore
433+ } else {
434+ if (queue.nonEmpty) {
435+ tryPullNextSourceInQueue()
436+ }
437+ if (! hasBeenPulled(in)) {
438+ tryPull(in)
439+ }
440+ }
441+ }
442+
443+ private def tryPullNextSourceInQueue (): Unit = {
444+ // pull the new emitting source
445+ val nextSource = queue.peek()
446+ if (nextSource.isInstanceOf [InflightSource [T ] @ unchecked]) {
447+ nextSource.asInstanceOf [InflightSource [T ]].tryPull()
448+ }
444449 }
445450
446451 setHandlers(in, out, this )
447452 }
448453
454+ FlattenConcatLogic
455+ }
456+
449457 override def toString : String = s " FlattenConcat( $parallelism) "
450458}
451459
0 commit comments