Skip to content

Commit 5a628e5

Browse files
authored
KAFKA-20321: Mark lost partitions before callbacks for Consumer and Streams (apache#21781)
Mark lost partitions as pending revocation before invoking the callbacks, to avoid returning buffered records in the case of a race after the callback completes (the app thread moves on to polling, while the background thread moves on to clearing the lost assignment). The same fix is applied for the Consumer and Streams managers (Share has no callbacks, so this does not apply). Reviewers: Nilesh Kumar <nileshkumar3@gmail.com>, Lan Ding <isDing_L@163.com>, Ken Huang <s7133700@gmail.com>, PoAn Yang <payang@apache.org>, Lucas Brutschy <lbrutschy@confluent.io>
1 parent 93918c5 commit 5a628e5

5 files changed

Lines changed: 133 additions & 4 deletions

File tree

clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractMembershipManager.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,8 +773,9 @@ public void maybeRejoinStaleMember() {
773773
* expired poll timer. This will trigger the onPartitionsLost callback. Once the callback
774774
* completes, the member will remain stale until the poll timer is reset by an application
775775
* poll event. See {@link #maybeRejoinStaleMember()}.
776+
* Visible for testing.
776777
*/
777-
private void transitionToStale() {
778+
void transitionToStale() {
778779
transitionTo(MemberState.STALE);
779780

780781
// Release assignment
@@ -1246,7 +1247,7 @@ public CompletableFuture<Void> signalPartitionsRevoked(Set<TopicPartition> parti
12461247
* <li>Previous in-flight fetch requests that may complete while the partitions are being revoked won't be processed.</li>
12471248
* </ul>
12481249
*/
1249-
private void markPendingRevocationToPauseFetching(Set<TopicPartition> partitionsToRevoke) {
1250+
protected void markPendingRevocationToPauseFetching(Set<TopicPartition> partitionsToRevoke) {
12501251
// When asynchronously committing offsets prior to the revocation of a set of partitions, there will be a
12511252
// window of time between when the offset commit is sent and when it returns and revocation completes. It is
12521253
// possible for pending fetches for these partitions to return during this time, which means the application's

clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerMembershipManager.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,9 @@ protected CompletableFuture<Void> signalMemberLeavingGroup() {
301301
*/
302302
@Override
303303
protected CompletableFuture<Void> signalPartitionsLost(Set<TopicPartition> partitionsLost) {
304+
// Mark partitions as pending revocation to stop fetching from the partitions (no new
305+
// fetches sent out, and no in-flight fetch responses processed).
306+
markPendingRevocationToPauseFetching(partitionsLost);
304307
return invokeOnPartitionsLostCallback(partitionsLost);
305308
}
306309

clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsMembershipManager.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,9 @@ private void finalizeLeaving() {
462462
private void transitionToStale() {
463463
transitionTo(MemberState.STALE);
464464

465+
// Mark partitions as pending revocation to stop fetching before callback
466+
subscriptionState.markPendingRevocation(subscriptionState.assignedPartitions());
467+
465468
final CompletableFuture<Void> onAllTasksLostCallbackExecuted = requestOnAllTasksLostCallbackInvocation();
466469
staleMemberAssignmentRelease = onAllTasksLostCallbackExecuted.whenComplete((result, error) -> {
467470
if (error != null) {
@@ -500,6 +503,9 @@ public void transitionToFatal() {
500503
return;
501504
}
502505

506+
// Mark partitions as pending revocation to stop fetching before callback
507+
subscriptionState.markPendingRevocation(subscriptionState.assignedPartitions());
508+
503509
CompletableFuture<Void> onAllTasksLostCallbackExecuted = requestOnAllTasksLostCallbackInvocation();
504510
onAllTasksLostCallbackExecuted.whenComplete((result, error) -> {
505511
if (error != null) {
@@ -805,6 +811,9 @@ public void onFenced() {
805811
log.debug("Member {} with epoch {} transitioned to {} state. It will release its " +
806812
"assignment and rejoin the group.", memberId, memberEpoch, MemberState.FENCED);
807813

814+
// Mark partitions as pending revocation to stop fetching before callback
815+
subscriptionState.markPendingRevocation(subscriptionState.assignedPartitions());
816+
808817
CompletableFuture<Void> onAllTasksLostCallbackExecuted = requestOnAllTasksLostCallbackInvocation();
809818
onAllTasksLostCallbackExecuted.whenComplete((result, error) -> {
810819
if (error != null) {

clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMembershipManagerTest.java

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void setup() {
124124
commitRequestManager = mock(CommitRequestManager.class);
125125
backgroundEventQueue = new LinkedBlockingQueue<>();
126126
time = new MockTime(0);
127-
backgroundEventHandler = new BackgroundEventHandler(backgroundEventQueue, time, mock(AsyncConsumerMetrics.class));
127+
backgroundEventHandler = spy(new BackgroundEventHandler(backgroundEventQueue, time, mock(AsyncConsumerMetrics.class)));
128128
metrics = new Metrics(time);
129129
rebalanceMetricsManager = new ConsumerRebalanceMetricsManager(metrics, subscriptionState);
130130

@@ -271,6 +271,72 @@ public void testFencingWhenStateIsStable() {
271271
verify(subscriptionState).assignFromSubscribed(Collections.emptySet());
272272
}
273273

274+
@Test
275+
public void testTransitionToFencedMarksPendingRevocationBeforeSignalingPartitionsLost() {
276+
ConsumerMembershipManager membershipManager = createMemberInStableState();
277+
String topicName = "topic1";
278+
TopicPartition ownedPartition = new TopicPartition(topicName, 0);
279+
Set<TopicPartition> ownedPartitions = Collections.singleton(ownedPartition);
280+
281+
CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener();
282+
when(subscriptionState.assignedPartitions()).thenReturn(ownedPartitions);
283+
when(subscriptionState.hasAutoAssignedPartitions()).thenReturn(true);
284+
when(subscriptionState.rebalanceListener()).thenReturn(Optional.of(listener));
285+
286+
membershipManager.transitionToFenced();
287+
288+
// Verify markPendingRevocation is called before enqueueing the callback event
289+
InOrder inOrder = inOrder(subscriptionState, backgroundEventHandler);
290+
inOrder.verify(subscriptionState).markPendingRevocation(ownedPartitions);
291+
inOrder.verify(backgroundEventHandler).add(any(PartitionsRemovedEvent.class));
292+
}
293+
294+
@Test
295+
public void testTransitionToFatalMarksPendingRevocationBeforeSignalingPartitionsLost() {
296+
ConsumerMembershipManager membershipManager = createMemberInStableState();
297+
String topicName = "topic1";
298+
TopicPartition ownedPartition = new TopicPartition(topicName, 0);
299+
Set<TopicPartition> ownedPartitions = Collections.singleton(ownedPartition);
300+
301+
CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener();
302+
when(subscriptionState.assignedPartitions()).thenReturn(ownedPartitions);
303+
when(subscriptionState.hasAutoAssignedPartitions()).thenReturn(true);
304+
when(subscriptionState.rebalanceListener()).thenReturn(Optional.of(listener));
305+
306+
membershipManager.transitionToFatal();
307+
308+
// Verify markPendingRevocation is called before enqueueing the callback event
309+
InOrder inOrder = inOrder(subscriptionState, backgroundEventHandler);
310+
inOrder.verify(subscriptionState).markPendingRevocation(ownedPartitions);
311+
inOrder.verify(backgroundEventHandler).add(any(PartitionsRemovedEvent.class));
312+
}
313+
314+
@Test
315+
public void testTransitionToStaleMarksPendingRevocationBeforeSignalingPartitionsLost() {
316+
ConsumerMembershipManager membershipManager = createMemberInStableState();
317+
String topicName = "topic1";
318+
TopicPartition ownedPartition = new TopicPartition(topicName, 0);
319+
Set<TopicPartition> ownedPartitions = Collections.singleton(ownedPartition);
320+
321+
CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener();
322+
when(subscriptionState.assignedPartitions()).thenReturn(ownedPartitions);
323+
when(subscriptionState.hasAutoAssignedPartitions()).thenReturn(true);
324+
when(subscriptionState.rebalanceListener()).thenReturn(Optional.of(listener));
325+
326+
// First transition to LEAVING (required before transitioning to STALE)
327+
membershipManager.transitionToSendingLeaveGroup(true);
328+
clearInvocations(subscriptionState, backgroundEventHandler);
329+
when(subscriptionState.assignedPartitions()).thenReturn(ownedPartitions);
330+
when(subscriptionState.rebalanceListener()).thenReturn(Optional.of(listener));
331+
332+
membershipManager.transitionToStale();
333+
334+
// Verify markPendingRevocation is called before enqueueing the callback event
335+
InOrder inOrder = inOrder(subscriptionState, backgroundEventHandler);
336+
inOrder.verify(subscriptionState).markPendingRevocation(ownedPartitions);
337+
inOrder.verify(backgroundEventHandler).add(any(PartitionsRemovedEvent.class));
338+
}
339+
274340
@Test
275341
public void testListenersGetNotifiedOnTransitionsToFatal() {
276342
when(subscriptionState.rebalanceListener()).thenReturn(Optional.empty());
@@ -2841,7 +2907,7 @@ private ConsumerMembershipManager createMemberInStableState(String groupInstance
28412907
membershipManager.onHeartbeatRequestGenerated();
28422908
assertEquals(MemberState.STABLE, membershipManager.state());
28432909

2844-
clearInvocations(subscriptionState, membershipManager, commitRequestManager);
2910+
clearInvocations(subscriptionState, membershipManager, commitRequestManager, backgroundEventHandler);
28452911
return membershipManager;
28462912
}
28472913

clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsMembershipManagerTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,13 @@
4242
import org.junit.jupiter.params.provider.ValueSource;
4343
import org.mockito.ArgumentCaptor;
4444
import org.mockito.Captor;
45+
import org.mockito.InOrder;
4546
import org.mockito.Mock;
4647
import org.mockito.Mockito;
4748
import org.mockito.junit.jupiter.MockitoExtension;
4849

4950
import java.util.Collection;
51+
import java.util.Collections;
5052
import java.util.List;
5153
import java.util.Map;
5254
import java.util.Optional;
@@ -70,6 +72,7 @@
7072
import static org.junit.jupiter.api.Assertions.assertTrue;
7173
import static org.mockito.ArgumentMatchers.any;
7274
import static org.mockito.ArgumentMatchers.argThat;
75+
import static org.mockito.Mockito.inOrder;
7376
import static org.mockito.Mockito.lenient;
7477
import static org.mockito.Mockito.never;
7578
import static org.mockito.Mockito.times;
@@ -1770,6 +1773,53 @@ public void testTransitionToFatalWhenInUnsubscribe() {
17701773
verify(subscriptionState, never()).assignFromSubscribed(Set.of());
17711774
}
17721775

1776+
@Test
1777+
public void testOnFencedMarksPendingRevocationBeforeCallback() {
1778+
TopicPartition ownedPartition = new TopicPartition(TOPIC_0, PARTITION_0);
1779+
Set<TopicPartition> ownedPartitions = Collections.singleton(ownedPartition);
1780+
when(subscriptionState.assignedPartitions()).thenReturn(ownedPartitions);
1781+
joining();
1782+
1783+
membershipManager.onFenced();
1784+
1785+
// Verify markPendingRevocation is called before the callback event is enqueued
1786+
InOrder inOrder = inOrder(subscriptionState, backgroundEventHandler);
1787+
inOrder.verify(subscriptionState).markPendingRevocation(ownedPartitions);
1788+
inOrder.verify(backgroundEventHandler).add(any(StreamsOnAllTasksLostCallbackNeededEvent.class));
1789+
}
1790+
1791+
@Test
1792+
public void testTransitionToFatalMarksPendingRevocationBeforeCallback() {
1793+
TopicPartition ownedPartition = new TopicPartition(TOPIC_0, PARTITION_0);
1794+
Set<TopicPartition> ownedPartitions = Collections.singleton(ownedPartition);
1795+
when(subscriptionState.assignedPartitions()).thenReturn(ownedPartitions);
1796+
joining();
1797+
1798+
membershipManager.transitionToFatal();
1799+
1800+
// Verify markPendingRevocation is called before the callback event is enqueued
1801+
InOrder inOrder = inOrder(subscriptionState, backgroundEventHandler);
1802+
inOrder.verify(subscriptionState).markPendingRevocation(ownedPartitions);
1803+
inOrder.verify(backgroundEventHandler).add(any(StreamsOnAllTasksLostCallbackNeededEvent.class));
1804+
}
1805+
1806+
@Test
1807+
public void testTransitionToStaleMarksPendingRevocationBeforeCallback() {
1808+
TopicPartition ownedPartition = new TopicPartition(TOPIC_0, PARTITION_0);
1809+
Set<TopicPartition> ownedPartitions = Collections.singleton(ownedPartition);
1810+
when(subscriptionState.assignedPartitions()).thenReturn(ownedPartitions);
1811+
joining();
1812+
1813+
// Trigger poll timer expiry to transition to LEAVING, then STALE on heartbeat generated
1814+
membershipManager.onPollTimerExpired();
1815+
membershipManager.onHeartbeatRequestGenerated();
1816+
1817+
// Verify markPendingRevocation is called before the callback event is enqueued
1818+
InOrder inOrder = inOrder(subscriptionState, backgroundEventHandler);
1819+
inOrder.verify(subscriptionState).markPendingRevocation(ownedPartitions);
1820+
inOrder.verify(backgroundEventHandler).add(any(StreamsOnAllTasksLostCallbackNeededEvent.class));
1821+
}
1822+
17731823
@Test
17741824
public void testOnTasksAssignedCallbackCompleted() {
17751825
final CompletableFuture<Void> future = new CompletableFuture<>();

0 commit comments

Comments
 (0)