Skip to content

Commit e82bfe6

Browse files
authored
Merge pull request #533 from LMnet/532-rebalance
2 parents b19ea0e + 7574755 commit e82bfe6

3 files changed

Lines changed: 119 additions & 85 deletions

File tree

modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ object KafkaConsumer {
157157

158158
def createPartitionStream(
159159
streamId: StreamId,
160-
partitionStreamId: PartitionStreamId,
161-
partition: TopicPartition
160+
partition: TopicPartition,
161+
assignmentRevoked: F[Unit]
162162
): F[Stream[F, CommittableConsumerRecord[F, K, V]]] =
163163
for {
164164
chunks <- chunkQueue
@@ -169,14 +169,17 @@ object KafkaConsumer {
169169
awaitTermination.attempt,
170170
dequeueDone.get
171171
),
172-
stopConsumingDeferred.get
172+
F.race(
173+
stopConsumingDeferred.get,
174+
assignmentRevoked
175+
)
173176
)
174177
.void
175178
stopReqs <- Deferred.tryable[F, Unit]
176179
} yield Stream.eval {
177180
def fetchPartition(deferred: Deferred[F, PartitionRequest]): F[Unit] = {
178181
val request =
179-
Request.Fetch(partition, streamId, partitionStreamId, deferred.complete)
182+
Request.Fetch(partition, streamId, deferred.complete)
180183
val fetch = requests.enqueue1(request) >> deferred.get
181184
F.race(shutdown, fetch).flatMap {
182185
case Left(()) =>
@@ -220,33 +223,20 @@ object KafkaConsumer {
220223

221224
def enqueueAssignment(
222225
streamId: StreamId,
223-
partitionStreamIdRef: Ref[F, PartitionStreamId],
224226
assigned: SortedSet[TopicPartition],
225-
partitionsMapQueue: PartitionsMapQueue
227+
partitionsMapQueue: PartitionsMapQueue,
228+
assignmentRevoked: F[Unit]
226229
): F[Unit] = {
227230
val assignment: F[PartitionsMap] = if (assigned.isEmpty) {
228231
F.pure(Map.empty)
229232
} else {
230-
val indexedAssigned = assigned.toVector.zipWithIndex
231-
partitionStreamIdRef
232-
.modify { id =>
233-
val result = indexedAssigned.map {
234-
case (partition, idx) =>
235-
(partition, idx + id)
233+
assigned.toVector
234+
.traverse { partition =>
235+
createPartitionStream(streamId, partition, assignmentRevoked).map { stream =>
236+
partition -> stream
236237
}
237-
val (_, lastId) = result.last
238-
(lastId + 1, result)
239-
}
240-
.flatMap { (partitions: Vector[(TopicPartition, PartitionStreamId)]) =>
241-
partitions
242-
.traverse {
243-
case (partition, partitionStreamId) =>
244-
createPartitionStream(streamId, partitionStreamId, partition).map { stream =>
245-
partition -> stream
246-
}
247-
}
248-
.map(_.toMap)
249238
}
239+
.map(_.toMap)
250240
}
251241

252242
assignment.flatMap { assignment =>
@@ -261,24 +251,39 @@ object KafkaConsumer {
261251

262252
def onRebalance(
263253
streamId: StreamId,
264-
partitionStreamIdRef: Ref[F, PartitionStreamId],
254+
prevAssignmentFinisherRef: Ref[F, Deferred[F, Unit]],
265255
partitionsMapQueue: PartitionsMapQueue
266-
): OnRebalance[F, K, V] = OnRebalance(
267-
onAssigned = assigned =>
268-
enqueueAssignment(streamId, partitionStreamIdRef, assigned, partitionsMapQueue),
269-
onRevoked = _ => F.unit
270-
)
256+
): OnRebalance[F, K, V] =
257+
OnRebalance(
258+
onRevoked = _ => {
259+
for {
260+
newFinisher <- Deferred[F, Unit]
261+
prevAssignmentFinisher <- prevAssignmentFinisherRef.getAndSet(newFinisher)
262+
_ <- prevAssignmentFinisher.complete(())
263+
} yield ()
264+
},
265+
onAssigned = assigned => {
266+
prevAssignmentFinisherRef.get.flatMap { prevAssignmentFinisher =>
267+
enqueueAssignment(
268+
streamId = streamId,
269+
assigned = assigned,
270+
partitionsMapQueue = partitionsMapQueue,
271+
assignmentRevoked = prevAssignmentFinisher.get
272+
)
273+
}
274+
}
275+
)
271276

272277
def requestAssignment(
273278
streamId: StreamId,
274-
partitionStreamIdRef: Ref[F, PartitionStreamId],
279+
prevAssignmentFinisherRef: Ref[F, Deferred[F, Unit]],
275280
partitionsMapQueue: PartitionsMapQueue
276281
): F[SortedSet[TopicPartition]] =
277282
Deferred[F, Either[Throwable, SortedSet[TopicPartition]]].flatMap { deferred =>
278283
val request =
279284
Request.Assignment[F, K, V](
280285
deferred.complete,
281-
Some(onRebalance(streamId, partitionStreamIdRef, partitionsMapQueue))
286+
Some(onRebalance(streamId, prevAssignmentFinisherRef, partitionsMapQueue))
282287
)
283288
val assignment = requests.enqueue1(request) >> deferred.get.rethrow
284289
F.race(awaitTermination.attempt, assignment).map {
@@ -290,20 +295,24 @@ object KafkaConsumer {
290295
def initialEnqueue(
291296
streamId: StreamId,
292297
partitionsMapQueue: PartitionsMapQueue,
293-
partitionStreamIdRef: Ref[F, PartitionStreamId]
298+
prevAssignmentFinisherRef: Ref[F, Deferred[F, Unit]]
294299
): F[Unit] =
295-
requestAssignment(streamId, partitionStreamIdRef, partitionsMapQueue).flatMap {
296-
assigned =>
297-
enqueueAssignment(streamId, partitionStreamIdRef, assigned, partitionsMapQueue)
298-
}
300+
for {
301+
prevAssignmentFinisher <- prevAssignmentFinisherRef.get
302+
assigned <- requestAssignment(streamId, prevAssignmentFinisherRef, partitionsMapQueue)
303+
assignmentRevoked = prevAssignmentFinisher.get
304+
_ <- enqueueAssignment(streamId, assigned, partitionsMapQueue, assignmentRevoked)
305+
} yield ()
299306

300307
Stream.eval(stopConsumingDeferred.tryGet).flatMap {
301308
case None =>
302309
for {
303310
partitionsMapQueue <- Stream.eval(Queue.noneTerminated[F, PartitionsMap])
304311
streamId <- Stream.eval(streamIdRef.modify(n => (n + 1, n)))
305-
partitionStreamIdRef <- Stream.eval(Ref.of[F, PartitionStreamId](0))
306-
_ <- Stream.eval(initialEnqueue(streamId, partitionsMapQueue, partitionStreamIdRef))
312+
prevAssignmentFinisher <- Stream.eval(Deferred[F, Unit])
313+
prevAssignmentFinisherRef <- Stream.eval(Ref[F].of(prevAssignmentFinisher))
314+
_ <- Stream
315+
.eval(initialEnqueue(streamId, partitionsMapQueue, prevAssignmentFinisherRef))
307316
out <- partitionsMapQueue.dequeue
308317
.interruptWhen(awaitTermination.attempt)
309318
.concurrently(

modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala

Lines changed: 11 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
166166
private[this] def fetch(
167167
partition: TopicPartition,
168168
streamId: StreamId,
169-
partitionStreamId: PartitionStreamId,
170169
callback: ((Chunk[CommittableConsumerRecord[F, K, V]], FetchCompletedReason)) => F[Unit]
171170
): F[Unit] = {
172171
val assigned =
@@ -176,7 +175,7 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
176175
ref
177176
.modify { state =>
178177
val (newState, oldFetch) =
179-
state.withFetch(partition, streamId, partitionStreamId, callback)
178+
state.withFetch(partition, streamId, callback)
180179
(newState, (newState, oldFetch))
181180
}
182181
.flatMap {
@@ -567,8 +566,8 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
567566
case Request.Assign(partitions, callback) => assign(partitions, callback)
568567
case Request.SubscribePattern(pattern, callback) => subscribe(pattern, callback)
569568
case Request.Unsubscribe(callback) => unsubscribe(callback)
570-
case Request.Fetch(partition, streamId, partitionStreamId, callback) =>
571-
fetch(partition, streamId, partitionStreamId, callback)
569+
case Request.Fetch(partition, streamId, callback) =>
570+
fetch(partition, streamId, callback)
572571
case request @ Request.Commit(_, _) => commit(request)
573572
case request @ Request.ManualCommitAsync(_, _) => manualCommitAsync(request)
574573
case request @ Request.ManualCommitSync(_, _) => manualCommitSync(request)
@@ -634,11 +633,9 @@ private[kafka] object KafkaConsumerActor {
634633
}
635634

636635
type StreamId = Int
637-
type PartitionStreamId = Int
638636

639637
final case class State[F[_], K, V](
640638
fetches: Map[TopicPartition, Map[StreamId, FetchRequest[F, K, V]]],
641-
partitionStreamIds: Map[TopicPartition, PartitionStreamId],
642639
records: Map[TopicPartition, NonEmptyVector[CommittableConsumerRecord[F, K, V]]],
643640
pendingCommits: Chain[Request.Commit[F, K, V]],
644641
onRebalances: Chain[OnRebalance[F, K, V]],
@@ -655,48 +652,29 @@ private[kafka] object KafkaConsumerActor {
655652
def withFetch(
656653
partition: TopicPartition,
657654
streamId: StreamId,
658-
partitionStreamId: PartitionStreamId,
659655
callback: ((Chunk[CommittableConsumerRecord[F, K, V]], FetchCompletedReason)) => F[Unit]
660656
): (State[F, K, V], List[FetchRequest[F, K, V]]) = {
661657
val newFetchRequest =
662658
FetchRequest(callback)
663659

664-
val oldPartitionFetches =
660+
val oldPartitionFetches: Map[StreamId, FetchRequest[F, K, V]] =
665661
fetches.getOrElse(partition, Map.empty)
666662

667-
val oldPartitionFetch =
668-
oldPartitionFetches.get(streamId)
663+
val newFetches: Map[TopicPartition, Map[StreamId, FetchRequest[F, K, V]]] =
664+
fetches.updated(partition, oldPartitionFetches.updated(streamId, newFetchRequest))
669665

670-
val oldPartitionStreamId =
671-
partitionStreamIds.getOrElse(partition, 0)
672-
673-
val hasMoreRecentPartitionStreamIds =
674-
oldPartitionStreamId > partitionStreamId
675-
676-
val newFetches =
677-
fetches.updated(partition, {
678-
if (hasMoreRecentPartitionStreamIds) oldPartitionFetches - streamId
679-
else oldPartitionFetches.updated(streamId, newFetchRequest)
680-
})
681-
682-
val newPartitionStreamIds =
683-
partitionStreamIds.updated(partition, oldPartitionStreamId max partitionStreamId)
684-
685-
val fetchesToRevoke =
686-
if (hasMoreRecentPartitionStreamIds)
687-
newFetchRequest :: oldPartitionFetch.toList
688-
else oldPartitionFetch.toList
666+
val fetchesToRevoke: List[FetchRequest[F, K, V]] =
667+
oldPartitionFetches.get(streamId).toList
689668

690669
(
691-
copy(fetches = newFetches, partitionStreamIds = newPartitionStreamIds),
670+
copy(fetches = newFetches),
692671
fetchesToRevoke
693672
)
694673
}
695674

696675
def withoutFetches(partitions: Set[TopicPartition]): State[F, K, V] =
697676
copy(
698-
fetches = fetches.filterKeysStrict(!partitions.contains(_)),
699-
partitionStreamIds = partitionStreamIds.filterKeysStrict(!partitions.contains(_))
677+
fetches = fetches.filterKeysStrict(!partitions.contains(_))
700678
)
701679

702680
def withRecords(
@@ -707,7 +685,6 @@ private[kafka] object KafkaConsumerActor {
707685
def withoutFetchesAndRecords(partitions: Set[TopicPartition]): State[F, K, V] =
708686
copy(
709687
fetches = fetches.filterKeysStrict(!partitions.contains(_)),
710-
partitionStreamIds = partitionStreamIds.filterKeysStrict(!partitions.contains(_)),
711688
records = records.filterKeysStrict(!partitions.contains(_))
712689
)
713690

@@ -743,25 +720,14 @@ private[kafka] object KafkaConsumerActor {
743720
append(fs.mkString("[", ", ", "]"))
744721
}("", ", ", "")
745722

746-
val partitionStreamIdsString =
747-
partitionStreamIds.toList
748-
.sortBy { case (tp, _) => tp }
749-
.mkStringAppend {
750-
case (append, (tp, id)) =>
751-
append(tp.toString)
752-
append(" -> ")
753-
append(id.toString)
754-
}("", ", ", "")
755-
756-
s"State(fetches = Map($fetchesString), partitionStreamIds = Map($partitionStreamIdsString), records = Map(${recordsString(records)}), pendingCommits = $pendingCommits, onRebalances = $onRebalances, rebalancing = $rebalancing, subscribed = $subscribed, streaming = $streaming)"
723+
s"State(fetches = Map($fetchesString), records = Map(${recordsString(records)}), pendingCommits = $pendingCommits, onRebalances = $onRebalances, rebalancing = $rebalancing, subscribed = $subscribed, streaming = $streaming)"
757724
}
758725
}
759726

760727
object State {
761728
def empty[F[_], K, V]: State[F, K, V] =
762729
State(
763730
fetches = Map.empty,
764-
partitionStreamIds = Map.empty,
765731
records = Map.empty,
766732
pendingCommits = Chain.empty,
767733
onRebalances = Chain.empty,
@@ -830,7 +796,6 @@ private[kafka] object KafkaConsumerActor {
830796
final case class Fetch[F[_], K, V](
831797
partition: TopicPartition,
832798
streamId: StreamId,
833-
partitionStreamId: PartitionStreamId,
834799
callback: ((Chunk[CommittableConsumerRecord[F, K, V]], FetchCompletedReason)) => F[Unit]
835800
) extends Request[F, K, V]
836801

modules/core/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,66 @@ final class KafkaConsumerSpec extends BaseKafkaSpec {
585585
}).unsafeRunSync()
586586
}
587587
}
588+
589+
it("should handle multiple rebalances with multiple instances under load #532") {
590+
withTopic { topic =>
591+
val numPartitions = 3
592+
createCustomTopic(topic, partitions = numPartitions)
593+
594+
val produced = (0 until 10000).map(n => s"key-$n" -> s"value->$n")
595+
publishToKafka(topic, produced)
596+
597+
def run(instance: Int, allAssignments: SignallingRef[IO, Map[Int, Set[Int]]]): IO[Unit] =
598+
KafkaConsumer
599+
.stream(consumerSettings[IO].withGroupId("test"))
600+
.evalTap(_.subscribeTo(topic))
601+
.flatMap(_.partitionsMapStream)
602+
.flatMap { assignment =>
603+
Stream.eval(allAssignments.update { current =>
604+
current.updated(instance, assignment.keySet.map(_.partition()))
605+
}) >> Stream
606+
.emits(assignment.map {
607+
case (_, partitionStream) =>
608+
partitionStream.evalMap(_ => IO.sleep(10.millis)) // imitating some work
609+
}.toList)
610+
.parJoinUnbounded
611+
}
612+
.compile
613+
.drain
614+
615+
def checkAssignments(
616+
allAssignments: SignallingRef[IO, Map[Int, Set[Int]]]
617+
)(instances: Set[Int]) =
618+
allAssignments.discrete
619+
.filter { state =>
620+
state.keySet == instances &&
621+
instances.forall { instance =>
622+
state.get(instance).exists(_.nonEmpty)
623+
} && state.values.toList.flatMap(_.toList).sorted == List(0, 1, 2)
624+
}
625+
.take(1)
626+
.compile
627+
.drain
628+
629+
(for {
630+
allAssignments <- SignallingRef[IO, Map[Int, Set[Int]]](Map.empty)
631+
check = checkAssignments(allAssignments)(_)
632+
fiber0 <- run(0, allAssignments).start
633+
_ <- check(Set(0))
634+
fiber1 <- run(1, allAssignments).start
635+
_ <- check(Set(0, 1))
636+
fiber2 <- run(2, allAssignments).start
637+
_ <- check(Set(0, 1, 2))
638+
_ <- fiber2.cancel
639+
_ <- allAssignments.update(_ - 2)
640+
_ <- check(Set(0, 1))
641+
_ <- fiber1.cancel
642+
_ <- allAssignments.update(_ - 1)
643+
_ <- check(Set(0))
644+
_ <- fiber0.cancel
645+
} yield succeed).unsafeRunSync()
646+
}
647+
}
588648
}
589649

590650
describe("KafkaConsumer#assignmentStream") {

0 commit comments

Comments
 (0)