diff --git a/streams/src/main/java/org/apache/kafka/streams/query/Position.java b/streams/src/main/java/org/apache/kafka/streams/query/Position.java
index 089bca12cdc56..94acd9d8adff4 100644
--- a/streams/src/main/java/org/apache/kafka/streams/query/Position.java
+++ b/streams/src/main/java/org/apache/kafka/streams/query/Position.java
@@ -104,15 +104,11 @@ public Position merge(final Position other) {
         } else {
             for (final Entry<String, ConcurrentHashMap<Integer, Long>> entry : other.position.entrySet()) {
                 final String topic = entry.getKey();
-                final Map<Integer, Long> partitionMap =
-                    position.computeIfAbsent(topic, k -> new ConcurrentHashMap<>());
+
                 for (final Entry<Integer, Long> partitionOffset : entry.getValue().entrySet()) {
                     final Integer partition = partitionOffset.getKey();
                     final Long offset = partitionOffset.getValue();
-                    if (!partitionMap.containsKey(partition)
-                        || partitionMap.get(partition) < offset) {
-                        partitionMap.put(partition, offset);
-                    }
+                    withComponent(topic, partition, offset);
                 }
             }
             return this;
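Note: `merge` now funnels every (topic, partition, offset) update through `withComponent`, so merging and single-component updates share one code path. The body of `withComponent` is not part of this diff; a minimal sketch of what it presumably looks like, assuming the backing `position` field is a `ConcurrentHashMap<String, ConcurrentHashMap<Integer, Long>>` (as the removed `computeIfAbsent` call suggests):

    // Sketch only -- not shown in this diff. The per-key compute() makes the
    // read-compare-write atomic, which the removed containsKey/get/put
    // sequence in merge() was not.
    public Position withComponent(final String topic, final int partition, final long offset) {
        position
            .computeIfAbsent(topic, k -> new ConcurrentHashMap<>())
            .compute(partition, (k, prior) -> prior == null || prior < offset ? offset : prior);
        return this;
    }
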
diff --git a/streams/src/test/java/org/apache/kafka/streams/query/PositionTest.java b/streams/src/test/java/org/apache/kafka/streams/query/PositionTest.java
index c994c691d2fdb..a65c83684eaf9 100644
--- a/streams/src/test/java/org/apache/kafka/streams/query/PositionTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/query/PositionTest.java
@@ -19,11 +19,21 @@
 
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
@@ -32,9 +42,12 @@
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class PositionTest {
 
+    private static final Random RANDOM = new Random();
+
     @Test
     public void shouldCreateFromMap() {
         final Map<String, Map<Integer, Long>> map = mkMap(
@@ -221,4 +234,118 @@ public void shouldNotHash() {
         final HashMap<Position, Integer> map = new HashMap<>();
         assertThrows(UnsupportedOperationException.class, () -> map.put(position, 5));
     }
+
+    @Test
+    public void shouldMonotonicallyIncreasePartitionPosition() throws InterruptedException, ExecutionException, TimeoutException {
+        final int threadCount = 10;
+        final int maxTopics = 50;
+        final int maxPartitions = 50;
+        final int maxOffset = 1000;
+        final CountDownLatch startLatch = new CountDownLatch(threadCount);
+        final Position mergePosition = Position.emptyPosition();
+        final Position withComponentPosition = Position.emptyPosition();
+        final List<Future<?>> futures = new ArrayList<>();
+        ExecutorService executorService = null;
+
+        try {
+            executorService = Executors.newFixedThreadPool(threadCount);
+
+            for (int i = 0; i < threadCount; i++) {
+                futures.add(executorService.submit(() -> {
+                    final Position threadPosition = Position.emptyPosition();
+                    final int topicCount = RANDOM.nextInt(maxTopics) + 1;
+
+                    // build the thread's position
+                    for (int topicNum = 0; topicNum < topicCount; topicNum++) {
+                        final String topic = "topic-" + topicNum;
+                        final int partitionCount = RANDOM.nextInt(maxPartitions) + 1;
+                        for (int partitionNum = 0; partitionNum < partitionCount; partitionNum++) {
+                            final long offset = RANDOM.nextInt(maxOffset) + 1;
+                            threadPosition.withComponent(topic, partitionNum, offset);
+                        }
+                    }
+
+                    startLatch.countDown();
+                    try {
+                        startLatch.await();
+                    } catch (final InterruptedException e) {
+                        // convert to unchecked exception so the future completes exceptionally and fails the test
+                        throw new RuntimeException(e);
+                    }
+
+                    // merge with the shared position
+                    mergePosition.merge(threadPosition);
+                    // duplicate the shared position to get a snapshot of its state
+                    final Position threadMergePositionState = mergePosition.copy();
+
+                    // update the shared position using withComponent
+                    for (final String topic : threadPosition.getTopics()) {
+                        for (final Map.Entry<Integer, Long> partitionOffset : threadPosition
+                                .getPartitionPositions(topic)
+                                .entrySet()) {
+                            withComponentPosition.withComponent(topic, partitionOffset.getKey(), partitionOffset.getValue());
+                        }
+                    }
+                    // duplicate the shared position to get a snapshot of its state
+                    final Position threadWithComponentPositionState = withComponentPosition.copy();
+
+                    // validate that any offsets in the merged position and the withComponent position are >= the thread position
+                    for (final String topic : threadPosition.getTopics()) {
+                        final Map<Integer, Long> threadOffsets = threadPosition.getPartitionPositions(topic);
+                        final Map<Integer, Long> mergedOffsets = threadMergePositionState.getPartitionPositions(topic);
+                        final Map<Integer, Long> withComponentOffsets = threadWithComponentPositionState.getPartitionPositions(topic);
+
+                        for (final Map.Entry<Integer, Long> threadOffset : threadOffsets.entrySet()) {
+                            final int partition = threadOffset.getKey();
+                            final long offsetValue = threadOffset.getValue();
+
+                            // merge checks
+                            assertTrue(
+                                mergedOffsets.containsKey(partition),
+                                "merge method failure. Missing partition " + partition + " for topic " + topic
+                            );
+                            assertTrue(
+                                mergedOffsets.get(partition) >= offsetValue,
+                                "merge method failure. Offset for topic "
+                                    + topic
+                                    + " partition "
+                                    + partition
+                                    + " expected >= "
+                                    + offsetValue
+                                    + " but got "
+                                    + mergedOffsets.get(partition)
+                            );
+
+                            // withComponent checks
+                            assertTrue(
+                                withComponentOffsets.containsKey(partition),
+                                "withComponent method failure. Missing partition " + partition + " for topic " + topic
+                            );
+                            assertTrue(
+                                withComponentOffsets.get(partition) >= offsetValue,
+                                "withComponent method failure. Offset for topic "
+                                    + topic
+                                    + " partition "
+                                    + partition
+                                    + " expected >= "
+                                    + offsetValue
+                                    + " but got "
+                                    + withComponentOffsets.get(partition)
+                            );
+                        }
+                    }
+                }));
+            }
+
+            for (final Future<?> future : futures) {
+                // Wait for all threads to complete
+                future.get(1, TimeUnit.SECONDS); // Check for exceptions
+            }
+        } finally {
+            if (executorService != null) {
+                executorService.shutdown();
+                assertTrue(executorService.awaitTermination(10, TimeUnit.SECONDS));
+            }
+        }
+    }
 }
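
For reference, the monotonicity contract this test exercises can be seen in a minimal two-position example (a sketch using only the API that appears above: emptyPosition, withComponent, merge):

    final Position a = Position.emptyPosition()
        .withComponent("topic-0", 0, 5L);
    final Position b = Position.emptyPosition()
        .withComponent("topic-0", 0, 3L)  // lower offset: must not overwrite 5
        .withComponent("topic-0", 1, 7L); // unseen partition: must be added

    a.merge(b);
    // a.getPartitionPositions("topic-0") is now {0=5, 1=7}:
    // offsets only ever move forward, regardless of merge order.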