Commit 67fa365

MINOR: Fix Streams Position thread-safety (#19480)
* Fixes a thread-safety bug in the Kafka Streams Position class
* Adds a multithreaded test to validate the fix and prevent regressions

Reviewers: John Roesler <[email protected]>
1 parent c465abc commit 67fa365

2 files changed: +129, -6 lines changed

Diff for: streams/src/main/java/org/apache/kafka/streams/query/Position.java (+2, -6)
@@ -104,15 +104,11 @@ public Position merge(final Position other) {
         } else {
             for (final Entry<String, ConcurrentHashMap<Integer, Long>> entry : other.position.entrySet()) {
                 final String topic = entry.getKey();
-                final Map<Integer, Long> partitionMap =
-                    position.computeIfAbsent(topic, k -> new ConcurrentHashMap<>());
+
                 for (final Entry<Integer, Long> partitionOffset : entry.getValue().entrySet()) {
                     final Integer partition = partitionOffset.getKey();
                     final Long offset = partitionOffset.getValue();
-                    if (!partitionMap.containsKey(partition)
-                        || partitionMap.get(partition) < offset) {
-                        partitionMap.put(partition, offset);
-                    }
+                    withComponent(topic, partition, offset);
                 }
             }
             return this;

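Why this is a fix: the removed lines perform a check-then-put on the inner partition map. Even though that map is a ConcurrentHashMap, the containsKey/get check and the subsequent put are separate operations, so two threads calling merge concurrently can interleave between them and a lower offset can overwrite a higher one. Routing each entry through withComponent instead applies the whole "keep the larger offset" update as a single per-key operation. Below is a minimal sketch of that pattern on a plain ConcurrentHashMap; the class and method names are hypothetical and this is not the actual Position internals.

import java.util.concurrent.ConcurrentHashMap;

// Illustrative sketch only: the atomic "keep the maximum" update that avoids
// the check-then-put race removed in this commit. Names are hypothetical and
// not part of Kafka Streams.
public class MonotonicOffsets {
    private final ConcurrentHashMap<Integer, Long> offsets = new ConcurrentHashMap<>();

    public void advance(final int partition, final long offset) {
        // compute() runs the remapping function atomically for the key, so a
        // concurrent writer with a lower offset can never overwrite a higher one.
        offsets.compute(partition, (k, prev) -> (prev == null || prev < offset) ? offset : prev);
    }

    public Long get(final int partition) {
        return offsets.get(partition);
    }
}

ConcurrentHashMap.compute guarantees the remapping function is invoked atomically for its key, which is what makes a monotonic update like this safe without external locking.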
Diff for: streams/src/test/java/org/apache/kafka/streams/query/PositionTest.java (+127)
@@ -19,11 +19,21 @@
 
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
@@ -32,9 +42,12 @@
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class PositionTest {
 
+    private static final Random RANDOM = new Random();
+
     @Test
     public void shouldCreateFromMap() {
         final Map<String, Map<Integer, Long>> map = mkMap(
@@ -221,4 +234,118 @@ public void shouldNotHash() {
         final HashMap<Position, Integer> map = new HashMap<>();
         assertThrows(UnsupportedOperationException.class, () -> map.put(position, 5));
     }
+
+    @Test
+    public void shouldMonotonicallyIncreasePartitionPosition() throws InterruptedException, ExecutionException, TimeoutException {
+        final int threadCount = 10;
+        final int maxTopics = 50;
+        final int maxPartitions = 50;
+        final int maxOffset = 1000;
+        final CountDownLatch startLatch = new CountDownLatch(threadCount);
+        final Position mergePosition = Position.emptyPosition();
+        final Position withComponentPosition = Position.emptyPosition();
+        final List<Future<?>> futures = new ArrayList<>();
+        ExecutorService executorService = null;
+
+        try {
+            executorService = Executors.newFixedThreadPool(threadCount);
+
+            for (int i = 0; i < threadCount; i++) {
+                futures.add(executorService.submit(() -> {
+                    final Position threadPosition = Position.emptyPosition();
+                    final int topicCount = RANDOM.nextInt(maxTopics) + 1;
+
+                    // build the thread's position
+                    for (int topicNum = 0; topicNum < topicCount; topicNum++) {
+                        final String topic = "topic-" + topicNum;
+                        final int partitionCount = RANDOM.nextInt(maxPartitions) + 1;
+                        for (int partitionNum = 0; partitionNum < partitionCount; partitionNum++) {
+                            final long offset = RANDOM.nextInt(maxOffset) + 1;
+                            threadPosition.withComponent(topic, partitionNum, offset);
+                        }
+                    }
+
+                    startLatch.countDown();
+                    try {
+                        startLatch.await();
+                    } catch (final InterruptedException e) {
+                        // convert to unchecked exception so the future completes exceptionally and fails the test
+                        throw new RuntimeException(e);
+                    }
+
+                    // merge with the shared position
+                    mergePosition.merge(threadPosition);
+                    // duplicate the shared position to get a snapshot of its state
+                    final Position threadMergePositionState = mergePosition.copy();
+
+                    // update the shared position using withComponent
+                    for (final String topic : threadPosition.getTopics()) {
+                        for (final Map.Entry<Integer, Long> partitionOffset : threadPosition
+                                .getPartitionPositions(topic)
+                                .entrySet()) {
+                            withComponentPosition.withComponent(topic, partitionOffset.getKey(), partitionOffset.getValue());
+                        }
+                    }
+                    // duplicate the shared position to get a snapshot of its state
+                    final Position threadWithComponentPositionState = withComponentPosition.copy();
+
+                    // validate that any offsets in the merged position and the withComponent position are >= the thread position
+                    for (final String topic : threadPosition.getTopics()) {
+                        final Map<Integer, Long> threadOffsets = threadPosition.getPartitionPositions(topic);
+                        final Map<Integer, Long> mergedOffsets = threadMergePositionState.getPartitionPositions(topic);
+                        final Map<Integer, Long> withComponentOffsets = threadWithComponentPositionState.getPartitionPositions(topic);
+
+                        for (final Map.Entry<Integer, Long> threadOffset : threadOffsets.entrySet()) {
+                            final int partition = threadOffset.getKey();
+                            final long offsetValue = threadOffset.getValue();
+
+                            // merge checks
+                            assertTrue(
+                                mergedOffsets.containsKey(partition),
+                                "merge method failure. Missing partition " + partition + " for topic " + topic
+                            );
+                            assertTrue(
+                                mergedOffsets.get(partition) >= offsetValue,
+                                "merge method failure. Offset for topic " +
+                                    topic +
+                                    " partition " +
+                                    partition +
+                                    " expected >= " +
+                                    offsetValue +
+                                    " but got " +
+                                    mergedOffsets.get(partition)
+                            );
+
+                            // withComponent checks
+                            assertTrue(
+                                withComponentOffsets.containsKey(partition),
+                                "withComponent method failure. Missing partition " + partition + " for topic " + topic
+                            );
+                            assertTrue(
+                                withComponentOffsets.get(partition) >= offsetValue,
+                                "withComponent method failure. Offset for topic " +
+                                    topic +
+                                    " partition " +
+                                    partition +
+                                    " expected >= " +
+                                    offsetValue +
+                                    " but got " +
+                                    withComponentOffsets.get(partition)
+                            );
+                        }
+                    }
+                }));
+            }
+
+            for (final Future<?> future : futures) {
+                // Wait for all threads to complete
+                future.get(1, TimeUnit.SECONDS); // Check for exceptions
+            }
+        } finally {
+            if (executorService != null) {
+                executorService.shutdown();
+                assertTrue(executorService.awaitTermination(10, TimeUnit.SECONDS));
+            }
+        }
+    }
 }
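For readers skimming the test, here is a brief single-threaded sketch of the invariant it checks under contention, using only the Position calls that appear above (emptyPosition, withComponent, merge, getPartitionPositions). It is an illustration, not part of the committed test, and assumes the keep-the-maximum merge semantics shown in the Position.java diff.

import org.apache.kafka.streams.query.Position;

// Illustrative sketch only: merge keeps the maximum offset per (topic, partition),
// so a lower offset never regresses one already recorded.
public class PositionMergeSketch {
    public static void main(final String[] args) {
        final Position a = Position.emptyPosition().withComponent("topic-0", 0, 5L);
        final Position b = Position.emptyPosition().withComponent("topic-0", 0, 3L);

        // the lower offset from b does not overwrite the higher offset in a
        a.merge(b);
        System.out.println(a.getPartitionPositions("topic-0")); // expected: {0=5}
    }
}

The multithreaded test asserts the same property under concurrency: once a thread's offsets have been applied to the shared Position, no concurrent merge or withComponent call may have regressed them.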
