@@ -13,7 +13,7 @@ import scala.collection.immutable.SortedSet
1313import scala .concurrent .duration .FiniteDuration
1414import scala .util .matching .Regex
1515
16- import cats .{Foldable , Functor , Reducible }
16+ import cats .{Applicative , Foldable , Functor , Reducible }
1717import cats .data .{NonEmptySet , OptionT }
1818import cats .effect .*
1919import cats .effect .implicits .*
@@ -126,6 +126,7 @@ object KafkaConsumer {
126126
127127 private def createKafkaConsumer [F [_], K , V ](
128128 requests : QueueSink [F , Request [F , K , V ]],
129+ settings : ConsumerSettings [F , K , V ],
129130 actor : KafkaConsumerActor [F , K , V ],
130131 fiber : Fiber [F , Throwable , Unit ],
131132 id : Int ,
@@ -141,7 +142,8 @@ object KafkaConsumer {
141142 type PartitionsMapQueue = Queue [F , Option [PartitionsMap ]]
142143
143144 def partitionStream (
144- partition : TopicPartition
145+ partition : TopicPartition ,
146+ signalCompletion : F [Unit ]
145147 ): Stream [F , CommittableConsumerRecord [F , K , V ]] =
146148 Stream .force {
147149 actor
@@ -155,54 +157,91 @@ object KafkaConsumer {
155157 .void
156158 .attempt
157159
158- Stream .fromQueueUnterminated(chunksQueue, 1 ).unchunks.interruptWhen(stopStream)
160+ Stream
161+ .fromQueueUnterminated(chunksQueue, 1 )
162+ .unchunks
163+ .interruptWhen(stopStream)
164+ .onFinalize(signalCompletion)
159165 }
160166 }
161167
162168 def enqueueAssignment (
163- assigned : Set [TopicPartition ],
169+ assigned : Map [TopicPartition , AssignmentSignals [ F ] ],
164170 partitionsMapQueue : PartitionsMapQueue
165171 ): F [Unit ] =
166172 stopConsumingDeferred
167173 .tryGet
168174 .flatMap {
169175 case None =>
170- val assignment : PartitionsMap = assigned
171- .view
172- .map { partition =>
173- partition -> partitionStream(partition)
174- }
175- .toMap
176+ val assignment : PartitionsMap = assigned.map { case (partition, signals) =>
177+ partition -> partitionStream(partition, signals.signalStreamFinished.void)
178+ }
176179
177180 partitionsMapQueue.offer(Some (assignment))
178181
179182 case Some (()) =>
180183 F .unit
181184 }
182185
183- def onRebalance (partitionsMapQueue : PartitionsMapQueue ): OnRebalance [F ] =
186+ def onRebalance (
187+ assignmentRef : Ref [F , Map [TopicPartition , AssignmentSignals [F ]]],
188+ partitionsMapQueue : PartitionsMapQueue
189+ ): OnRebalance [F ] =
184190 OnRebalance (
185- onRevoked = _ => F .unit,
186- onAssigned = assigned => enqueueAssignment(assigned, partitionsMapQueue)
191+ onRevoked = revoked =>
192+ for {
193+ assignment <- assignmentRef.get
194+ _ <- revoked.toVector.flatMap(assignment.get).traverse_(_.awaitStreamFinishedSignal)
195+ } yield (),
196+ onAssigned = assigned =>
197+ for {
198+ assignment <- buildAssignment(assigned)
199+ _ <- assignmentRef.update(_ ++ assignment)
200+ _ <- enqueueAssignment(assignment, partitionsMapQueue)
201+ } yield ()
187202 )
188203
189- def requestAssignment (partitionsMapQueue : PartitionsMapQueue ): F [Set [TopicPartition ]] = {
204+ def buildAssignment (
205+ assignedPartitions : SortedSet [TopicPartition ]
206+ ): F [Map [TopicPartition , AssignmentSignals [F ]]] = {
207+ assignedPartitions
208+ .toVector
209+ .traverse { partition =>
210+ settings.rebalanceRevokeMode match {
211+ case RebalanceRevokeMode .EagerMode =>
212+ (partition -> AssignmentSignals .eager[F ]).pure[F ]
213+ case RebalanceRevokeMode .GracefulMode =>
214+ Deferred [F , Unit ].map(revokeFinisher =>
215+ partition -> AssignmentSignals .graceful(revokeFinisher)
216+ )
217+ }
218+ }
219+ .map(_.toMap)
220+ }
221+
222+ def requestAssignment (
223+ assignmentRef : Ref [F , Map [TopicPartition , AssignmentSignals [F ]]],
224+ partitionsMapQueue : PartitionsMapQueue
225+ ): F [Map [TopicPartition , AssignmentSignals [F ]]] = {
190226 val assignment = this .assignment(
191227 Some (
192- onRebalance(partitionsMapQueue)
228+ onRebalance(assignmentRef, partitionsMapQueue)
193229 )
194230 )
195231
196232 F .race(awaitTermination.attempt, assignment)
197233 .flatMap {
198- case Left (_) => F .pure(Set .empty)
199- case Right (assigned) => F .pure (assigned)
234+ case Left (_) => F .pure(Map .empty)
235+ case Right (assigned) => buildAssignment (assigned).flatTap(assignmentRef.set )
200236 }
201237 }
202238
203- def initialEnqueue (partitionsMapQueue : PartitionsMapQueue ): F [Unit ] =
239+ def initialEnqueue (
240+ assignmentRef : Ref [F , Map [TopicPartition , AssignmentSignals [F ]]],
241+ partitionsMapQueue : PartitionsMapQueue
242+ ): F [Unit ] =
204243 for {
205- assigned <- requestAssignment(partitionsMapQueue)
244+ assigned <- requestAssignment(assignmentRef, partitionsMapQueue)
206245 _ <- enqueueAssignment(assigned, partitionsMapQueue)
207246 } yield ()
208247
@@ -212,7 +251,9 @@ object KafkaConsumer {
212251 case None =>
213252 for {
214253 partitionsMapQueue <- Stream .eval(Queue .unbounded[F , Option [PartitionsMap ]])
215- _ <- Stream .eval(initialEnqueue(partitionsMapQueue))
254+ assignmentRef <-
255+ Stream .eval(Ref [F ].of(Map .empty[TopicPartition , AssignmentSignals [F ]]))
256+ _ <- Stream .eval(initialEnqueue(assignmentRef, partitionsMapQueue))
216257 out <- Stream
217258 .fromQueueNoneTerminated(partitionsMapQueue)
218259 .interruptWhen(awaitTermination.attempt)
@@ -574,6 +615,7 @@ object KafkaConsumer {
574615 fiber <- startBackgroundConsumer(requests, polls, actor, settings.pollInterval)
575616 } yield createKafkaConsumer(
576617 requests,
618+ settings,
577619 actor,
578620 fiber,
579621 id,
@@ -695,6 +737,47 @@ object KafkaConsumer {
695737
696738 }
697739
740+ /**
741+ * Utility class to provide clarity for internals. Goal is to make [[RebalanceRevokeMode ]]
742+ * transparent to the rest of implementation internals.
743+ * @tparam F
744+ * effect used
745+ */
746+ sealed abstract private class AssignmentSignals [F [_]] {
747+
748+ def signalStreamFinished : F [Boolean ]
749+ def awaitStreamFinishedSignal : F [Unit ]
750+
751+ }
752+
753+ private object AssignmentSignals {
754+
755+ def eager [F [_]: Applicative ]: AssignmentSignals [F ] =
756+ EagerSignals ()
757+
758+ def graceful [F [_]](
759+ revokeFinisher : Deferred [F , Unit ]
760+ ): AssignmentSignals [F ] =
761+ GracefulSignals [F ](revokeFinisher)
762+
763+ final private case class EagerSignals [F [_]: Applicative ]() extends AssignmentSignals [F ] {
764+
765+ override def signalStreamFinished : F [Boolean ] = true .pure[F ]
766+ override def awaitStreamFinishedSignal : F [Unit ] = ().pure[F ]
767+
768+ }
769+
770+ final private case class GracefulSignals [F [_]](
771+ revokeFinisher : Deferred [F , Unit ]
772+ ) extends AssignmentSignals [F ] {
773+
774+ override def signalStreamFinished : F [Boolean ] = revokeFinisher.complete(())
775+ override def awaitStreamFinishedSignal : F [Unit ] = revokeFinisher.get
776+
777+ }
778+
779+ }
780+
698781 /*
699782 * Prevents the default `MkConsumer` instance from being implicitly available
700783 * to code defined in this object, ensuring factory methods require an instance
0 commit comments