diff --git a/streams/src/main/java/org/apache/kafka/streams/internals/ConsumerWrapper.java b/streams/src/main/java/org/apache/kafka/streams/internals/ConsumerWrapper.java index 20cd7cb84d7bd..5088cd4b7dd82 100644 --- a/streams/src/main/java/org/apache/kafka/streams/internals/ConsumerWrapper.java +++ b/streams/src/main/java/org/apache/kafka/streams/internals/ConsumerWrapper.java @@ -27,6 +27,7 @@ import org.apache.kafka.clients.consumer.SubscriptionPattern; import org.apache.kafka.clients.consumer.internals.AsyncKafkaConsumer; import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData; +import org.apache.kafka.clients.consumer.internals.StreamsRebalanceListener; import org.apache.kafka.common.Metric; import org.apache.kafka.common.MetricName; import org.apache.kafka.common.PartitionInfo; @@ -54,10 +55,6 @@ public void wrapConsumer( this.delegate = delegate; } - public AsyncKafkaConsumer consumer() { - return delegate; - } - @Override public Set assignment() { return delegate.assignment(); @@ -78,6 +75,10 @@ public void subscribe(final Collection topics, final ConsumerRebalanceLi delegate.subscribe(topics, callback); } + public void subscribe(final Collection topics, final StreamsRebalanceListener streamsRebalanceListener) { + delegate.subscribe(topics, streamsRebalanceListener); + } + @Override public void assign(final Collection partitions) { delegate.assign(partitions); diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java index fdc5e8df4bc32..46591e557e1e3 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java @@ -1137,19 +1137,29 @@ private void subscribeConsumer() { mainConsumer.subscribe(topologyMetadata.sourceTopicPattern(), rebalanceListener); } else { if (streamsRebalanceData.isPresent()) { - final AsyncKafkaConsumer consumer = mainConsumer instanceof ConsumerWrapper - ? ((ConsumerWrapper) mainConsumer).consumer() - : (AsyncKafkaConsumer) mainConsumer; - consumer.subscribe( - topologyMetadata.allFullSourceTopicNames(), - new DefaultStreamsRebalanceListener( - log, - time, - streamsRebalanceData.get(), - this, - taskManager - ) - ); + if (mainConsumer instanceof ConsumerWrapper) { + ((ConsumerWrapper) mainConsumer).subscribe( + topologyMetadata.allFullSourceTopicNames(), + new DefaultStreamsRebalanceListener( + log, + time, + streamsRebalanceData.get(), + this, + taskManager + ) + ); + } else { + ((AsyncKafkaConsumer) mainConsumer).subscribe( + topologyMetadata.allFullSourceTopicNames(), + new DefaultStreamsRebalanceListener( + log, + time, + streamsRebalanceData.get(), + this, + taskManager + ) + ); + } } else { mainConsumer.subscribe(topologyMetadata.allFullSourceTopicNames(), rebalanceListener); } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java index 54230d11d3be8..96090aa32fafd 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java @@ -69,7 +69,6 @@ import org.apache.kafka.streams.errors.StreamsException; import org.apache.kafka.streams.errors.TaskCorruptedException; import org.apache.kafka.streams.errors.TaskMigratedException; -import org.apache.kafka.streams.internals.ConsumerWrapper; import org.apache.kafka.streams.kstream.Consumed; import org.apache.kafka.streams.kstream.Materialized; import org.apache.kafka.streams.kstream.internals.ConsumedInternal; @@ -3926,38 +3925,6 @@ t2p1, new PartitionInfo(t2p1.topic(), t2p1.partition(), null, new Node[0], new N ); } - @ParameterizedTest - @MethodSource("data") - public void shouldWrapMainConsumerFromClassConfig(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) { - final Properties streamsConfigProps = configProps(false, stateUpdaterEnabled, processingThreadsEnabled); - streamsConfigProps.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, "streams"); - streamsConfigProps.put(InternalConfig.INTERNAL_CONSUMER_WRAPPER, TestWrapper.class); - - thread = createStreamThread("clientId", new StreamsConfig(streamsConfigProps)); - - assertInstanceOf( - AsyncKafkaConsumer.class, - assertInstanceOf(TestWrapper.class, thread.mainConsumer()).consumer() - ); - } - - @ParameterizedTest - @MethodSource("data") - public void shouldWrapMainConsumerFromStringConfig(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) { - final Properties streamsConfigProps = configProps(false, stateUpdaterEnabled, processingThreadsEnabled); - streamsConfigProps.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, "streams"); - streamsConfigProps.put(InternalConfig.INTERNAL_CONSUMER_WRAPPER, TestWrapper.class.getName()); - - thread = createStreamThread("clientId", new StreamsConfig(streamsConfigProps)); - - assertInstanceOf( - AsyncKafkaConsumer.class, - assertInstanceOf(TestWrapper.class, thread.mainConsumer()).consumer() - ); - } - - public static final class TestWrapper extends ConsumerWrapper { } - private StreamThread setUpThread(final Properties streamsConfigProps) { final StreamsConfig config = new StreamsConfig(streamsConfigProps); final ConsumerGroupMetadata consumerGroupMetadata = Mockito.mock(ConsumerGroupMetadata.class);