From d649348a038da9081e3644b005d6f1564a3aaf63 Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Thu, 15 Jan 2026 17:26:32 +0100 Subject: [PATCH 1/4] [hotfix][runtime] Extract RecordFilter as the interface --- .../DemultiplexingRecordDeserializer.java | 12 +- .../io/recovery/PartitionerRecordFilter.java | 60 ++++++++++ .../runtime/io/recovery/RecordFilter.java | 62 +++++----- .../RescalingStreamTaskNetworkInput.java | 9 +- .../DemultiplexingRecordDeserializerTest.java | 8 +- .../runtime/io/recovery/RecordFilterTest.java | 109 ++++++++++++++++++ 6 files changed, 208 insertions(+), 52 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/PartitionerRecordFilter.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java index 548b162d7fe96..c63b762435a99 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java @@ -26,7 +26,6 @@ import org.apache.flink.runtime.plugable.DeserializationDelegate; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamElement; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; import org.apache.flink.util.CloseableIterator; @@ -38,7 +37,6 @@ import java.util.Comparator; import java.util.Map; import java.util.function.Function; -import java.util.function.Predicate; import static org.apache.flink.util.Preconditions.checkNotNull; @@ 
-61,14 +59,14 @@ class DemultiplexingRecordDeserializer static class VirtualChannel { private final RecordDeserializer> deserializer; - private final Predicate> recordFilter; + private final RecordFilter recordFilter; Watermark lastWatermark = Watermark.UNINITIALIZED; WatermarkStatus watermarkStatus = WatermarkStatus.ACTIVE; private DeserializationResult lastResult; VirtualChannel( RecordDeserializer> deserializer, - Predicate> recordFilter) { + RecordFilter recordFilter) { this.deserializer = deserializer; this.recordFilter = recordFilter; } @@ -81,7 +79,7 @@ public DeserializationResult getNextRecord(DeserializationDelegate DemultiplexingRecordDeserializer create( InflightDataRescalingDescriptor rescalingDescriptor, Function>> deserializerFactory, - Function>> recordFilterFactory) { + Function> recordFilterFactory) { int[] oldSubtaskIndexes = rescalingDescriptor.getOldSubtaskIndexes(channelInfo.getGateIdx()); if (oldSubtaskIndexes.length == 0) { @@ -223,7 +221,7 @@ static DemultiplexingRecordDeserializer create( deserializerFactory.apply(totalChannels), rescalingDescriptor.isAmbiguous(channelInfo.getGateIdx(), subtask) ? recordFilterFactory.apply(channelInfo) - : RecordFilter.all())); + : RecordFilter.acceptAll())); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/PartitionerRecordFilter.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/PartitionerRecordFilter.java new file mode 100644 index 0000000000000..a9f56cbdbc69d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/PartitionerRecordFilter.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.io.network.api.writer.ChannelSelector; +import org.apache.flink.runtime.plugable.SerializationDelegate; +import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +/** + * A {@link RecordFilter} implementation that uses a partitioner to determine record ownership. + * + *

This filter checks if a record would have arrived at this subtask if it had been partitioned + * upstream. It is used during recovery for ambiguous channel mappings, such as when the downstream + * node of a keyed exchange is rescaled. + * + * @param The type of the record value. + */ +@Internal +public class PartitionerRecordFilter implements RecordFilter { + private final ChannelSelector>> partitioner; + + private final SerializationDelegate> delegate; + + private final int subtaskIndex; + + @SuppressWarnings({"unchecked", "rawtypes"}) + public PartitionerRecordFilter( + ChannelSelector>> partitioner, + TypeSerializer inputSerializer, + int subtaskIndex) { + this.partitioner = partitioner; + this.delegate = new SerializationDelegate<>(new StreamElementSerializer(inputSerializer)); + this.subtaskIndex = subtaskIndex; + } + + @Override + public boolean filter(StreamRecord streamRecord) { + delegate.setInstance(streamRecord); + // Check if record would have arrived at this subtask if it had been partitioned upstream. 
+ return partitioner.selectChannel(delegate) == subtaskIndex; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilter.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilter.java index a0f805f27b582..796ea84dfd843 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilter.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilter.java @@ -17,48 +17,38 @@ package org.apache.flink.streaming.runtime.io.recovery; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.runtime.io.network.api.writer.ChannelSelector; -import org.apache.flink.runtime.plugable.SerializationDelegate; -import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; +import org.apache.flink.annotation.Internal; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import java.util.function.Predicate; - /** - * Filters records for ambiguous channel mappings. + * A filter interface for determining whether a record should be processed. * - *

For example, when the downstream node of a keyed exchange is scaled from 1 to 2, the state of - * the output side on te upstream node needs to be replicated to both channels. This filter then - * checks the deserialized records on both downstream subtasks and filters out the irrelevant - * records. + *

This interface is used during recovery to filter records for ambiguous channel mappings. For + * example, when the downstream node of a keyed exchange is scaled from 1 to 2, the state of the + * output side on the upstream node needs to be replicated to both channels. The filter then checks + * the deserialized records on both downstream subtasks and filters out the irrelevant records. * - * @param + * @param The type of the record value. */ -class RecordFilter implements Predicate> { - private final ChannelSelector>> partitioner; - - private final SerializationDelegate> delegate; - - private final int subtaskIndex; - - public RecordFilter( - ChannelSelector>> partitioner, - TypeSerializer inputSerializer, - int subtaskIndex) { - this.partitioner = partitioner; - delegate = new SerializationDelegate<>(new StreamElementSerializer(inputSerializer)); - this.subtaskIndex = subtaskIndex; - } - - public static Predicate> all() { +@FunctionalInterface +@Internal +public interface RecordFilter { + + /** + * Tests whether the given record should be accepted. + * + * @param record The stream record to test. + * @return {@code true} if the record should be accepted, {@code false} otherwise. + */ + boolean filter(StreamRecord record); + + /** + * Returns a filter that accepts all records. + * + * @param The type of the record value. + * @return A filter that always returns {@code true}. 
+ */ + static RecordFilter acceptAll() { return record -> true; } - - @Override - public boolean test(StreamRecord streamRecord) { - delegate.setInstance(streamRecord); - // check if record would have arrived at this subtask if it had been partitioned upstream - return partitioner.selectChannel(delegate) == subtaskIndex; - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java index ac8bc4109a595..cc8c75825f325 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java @@ -40,7 +40,6 @@ import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner; import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamElement; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask.CanEmitBatchOfRecordsChecker; import org.apache.flink.streaming.runtime.watermarkstatus.StatusWatermarkValve; import org.apache.flink.util.CollectionUtil; @@ -54,7 +53,6 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.function.Function; -import java.util.function.Predicate; import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY; import static org.apache.flink.util.Preconditions.checkState; @@ -201,8 +199,7 @@ public CompletableFuture prepareSnapshot( *

Filters must not be shared across different virtual channels to ensure that the state is * in-sync across different subtasks. */ - static class RecordFilterFactory - implements Function>> { + static class RecordFilterFactory implements Function> { private final Map> partitionerCache = CollectionUtil.newHashMapWithExpectedSize(1); private final Function> gatePartitioners; @@ -225,7 +222,7 @@ public RecordFilterFactory( } @Override - public Predicate> apply(InputChannelInfo channelInfo) { + public RecordFilter apply(InputChannelInfo channelInfo) { // retrieving the partitioner for one input task is rather costly so cache them all final StreamPartitioner partitioner = partitionerCache.computeIfAbsent( @@ -234,7 +231,7 @@ public Predicate> apply(InputChannelInfo channelInfo) { // have the same state across several subtasks StreamPartitioner partitionerCopy = partitioner.copy(); partitionerCopy.setup(numberOfChannels); - return new RecordFilter<>(partitionerCopy, inputSerializer, subtaskIndex); + return new PartitionerRecordFilter<>(partitionerCopy, inputSerializer, subtaskIndex); } private StreamPartitioner createPartitioner(Integer index) { diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializerTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializerTest.java index 361b78e3f2e34..ed2ea64d9e753 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializerTest.java @@ -95,7 +95,7 @@ void testUpscale() throws IOException { unused -> new SpillingAdaptiveSpanningRecordDeserializer<>( ioManager.getSpillingDirectoriesPaths()), - unused -> RecordFilter.all()); + unused -> RecordFilter.acceptAll()); assertThat(deserializer.getVirtualChannelSelectors()) .containsOnly( @@ 
-136,7 +136,9 @@ void testAmbiguousChannels() throws IOException { unused -> new SpillingAdaptiveSpanningRecordDeserializer<>( ioManager.getSpillingDirectoriesPaths()), - unused -> new RecordFilter(new ModSelector(2), LongSerializer.INSTANCE, 1)); + unused -> + new PartitionerRecordFilter<>( + new ModSelector(2), LongSerializer.INSTANCE, 1)); assertThat(deserializer.getVirtualChannelSelectors()) .containsOnly( @@ -179,7 +181,7 @@ void testWatermarks() throws IOException { unused -> new SpillingAdaptiveSpanningRecordDeserializer<>( ioManager.getSpillingDirectoriesPaths()), - unused -> RecordFilter.all()); + unused -> RecordFilter.acceptAll()); assertThat(deserializer.getVirtualChannelSelectors()).hasSize(4); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterTest.java new file mode 100644 index 0000000000000..10c7e0b9839ef --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.runtime.io.network.api.writer.ChannelSelector; +import org.apache.flink.runtime.plugable.SerializationDelegate; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link RecordFilter} interface and {@link PartitionerRecordFilter}. */ +class RecordFilterTest { + + @Test + void testAcceptAllFilterAcceptsEveryRecord() { + RecordFilter filter = RecordFilter.acceptAll(); + assertThat(filter.filter(new StreamRecord<>(0L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(1L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(42L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(-1L))).isTrue(); + } + + @Test + void testPartitionerRecordFilterAcceptsMatchingSubtask() { + // Mod-based partitioner with 2 channels, subtask index 0 receives even values + PartitionerRecordFilter filter = + new PartitionerRecordFilter<>( + new ModChannelSelector(2), LongSerializer.INSTANCE, 0); + + assertThat(filter.filter(new StreamRecord<>(0L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(2L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(4L))).isTrue(); + } + + @Test + void testPartitionerRecordFilterRejectsNonMatchingSubtask() { + // Mod-based partitioner with 2 channels, subtask index 0 should reject odd values + PartitionerRecordFilter filter = + new PartitionerRecordFilter<>( + new ModChannelSelector(2), LongSerializer.INSTANCE, 0); + + assertThat(filter.filter(new StreamRecord<>(1L))).isFalse(); + assertThat(filter.filter(new StreamRecord<>(3L))).isFalse(); + assertThat(filter.filter(new StreamRecord<>(5L))).isFalse(); + } + + @Test + void testPartitionerRecordFilterWithDifferentSubtaskIndex() { + // Subtask index 1 should accept odd values (mod 2 == 1) 
+ PartitionerRecordFilter filter = + new PartitionerRecordFilter<>( + new ModChannelSelector(2), LongSerializer.INSTANCE, 1); + + assertThat(filter.filter(new StreamRecord<>(1L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(3L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(0L))).isFalse(); + assertThat(filter.filter(new StreamRecord<>(2L))).isFalse(); + } + + @Test + void testRecordFilterAsFunctionalInterface() { + // RecordFilter is a FunctionalInterface and can be implemented as a lambda + RecordFilter onlyPositive = record -> record.getValue() > 0; + assertThat(onlyPositive.filter(new StreamRecord<>(5L))).isTrue(); + assertThat(onlyPositive.filter(new StreamRecord<>(-1L))).isFalse(); + assertThat(onlyPositive.filter(new StreamRecord<>(0L))).isFalse(); + } + + /** A simple mod-based channel selector for testing. */ + private static class ModChannelSelector + implements ChannelSelector>> { + private final int numberOfChannels; + + private ModChannelSelector(int numberOfChannels) { + this.numberOfChannels = numberOfChannels; + } + + @Override + public void setup(int numberOfChannels) {} + + @Override + public int selectChannel(SerializationDelegate> record) { + return (int) (record.getInstance().getValue() % numberOfChannels); + } + + @Override + public boolean isBroadcast() { + return false; + } + } +} From 9f8d09e2b3ff12803dc2dc301a8faad46d86de85 Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Thu, 15 Jan 2026 17:28:49 +0100 Subject: [PATCH 2/4] [hotfix] Extract VirtualChannel as the public class --- .../DemultiplexingRecordDeserializer.java | 55 +------- .../runtime/io/recovery/VirtualChannel.java | 130 ++++++++++++++++++ 2 files changed, 133 insertions(+), 52 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannel.java diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java index c63b762435a99..6f1f3bda8b0e4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java @@ -57,56 +57,6 @@ class DemultiplexingRecordDeserializer private VirtualChannel currentVirtualChannel; - static class VirtualChannel { - private final RecordDeserializer> deserializer; - private final RecordFilter recordFilter; - Watermark lastWatermark = Watermark.UNINITIALIZED; - WatermarkStatus watermarkStatus = WatermarkStatus.ACTIVE; - private DeserializationResult lastResult; - - VirtualChannel( - RecordDeserializer> deserializer, - RecordFilter recordFilter) { - this.deserializer = deserializer; - this.recordFilter = recordFilter; - } - - public DeserializationResult getNextRecord(DeserializationDelegate delegate) - throws IOException { - do { - lastResult = deserializer.getNextRecord(delegate); - - if (lastResult.isFullRecord()) { - final StreamElement element = delegate.getInstance(); - // test if record belongs to this subtask if it comes from ambiguous channel - if (element.isRecord() && recordFilter.filter(element.asRecord())) { - return lastResult; - } else if (element.isWatermark()) { - lastWatermark = element.asWatermark(); - return lastResult; - } else if (element.isWatermarkStatus()) { - watermarkStatus = element.asWatermarkStatus(); - return lastResult; - } - } - // loop is only re-executed for filtered full records - } while (!lastResult.isBufferConsumed()); - return DeserializationResult.PARTIAL_RECORD; - } - - public void setNextBuffer(Buffer buffer) throws IOException { - deserializer.setNextBuffer(buffer); - } - - public void clear() { - deserializer.clear(); - } - - public boolean hasPartialData() { - return lastResult != null && 
!lastResult.isBufferConsumed(); - } - } - public DemultiplexingRecordDeserializer( Map> channels) { this.channels = checkNotNull(channels); @@ -159,7 +109,7 @@ public DeserializationResult getNextRecord(DeserializationDelegate virtualChannel.lastWatermark) + .map(VirtualChannel::getLastWatermark) .min(Comparator.comparing(Watermark::getTimestamp)) .orElseThrow( () -> @@ -174,7 +124,8 @@ public DeserializationResult getNextRecord(DeserializationDelegate d.watermarkStatus.isActive())) { + if (channels.values().stream() + .anyMatch(vc -> vc.getWatermarkStatus().isActive())) { delegate.setInstance(WatermarkStatus.ACTIVE); } return result; diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannel.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannel.java new file mode 100644 index 0000000000000..ddcbba79f250a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannel.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.plugable.DeserializationDelegate; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamElement; +import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; + +import java.io.IOException; + +/** + * Represents a virtual channel for demultiplexing records during recovery. + * + *

A virtual channel wraps a {@link RecordDeserializer} and adds record filtering capability, + * along with tracking watermark and watermark status state. + * + * @param The type of record values. + */ +@Internal +public class VirtualChannel { + private final RecordDeserializer> deserializer; + private final RecordFilter recordFilter; + + private Watermark lastWatermark = Watermark.UNINITIALIZED; + private WatermarkStatus watermarkStatus = WatermarkStatus.ACTIVE; + private DeserializationResult lastResult; + + public VirtualChannel( + RecordDeserializer> deserializer, + RecordFilter recordFilter) { + this.deserializer = deserializer; + this.recordFilter = recordFilter; + } + + /** + * Deserializes the next record from the buffer, applying the record filter. + * + *

This method loops through records until it finds one that passes the filter or the buffer + * is consumed. Watermarks and watermark statuses are always accepted and their state is + * updated. + * + * @param delegate The deserialization delegate to populate with the record. + * @return The deserialization result indicating whether a full record was read. + * @throws IOException If an I/O error occurs during deserialization. + */ + public DeserializationResult getNextRecord(DeserializationDelegate delegate) + throws IOException { + do { + lastResult = deserializer.getNextRecord(delegate); + + if (lastResult.isFullRecord()) { + final StreamElement element = delegate.getInstance(); + // test if record belongs to this subtask if it comes from ambiguous channel + if (element.isRecord() && recordFilter.filter(element.asRecord())) { + return lastResult; + } else if (element.isWatermark()) { + lastWatermark = element.asWatermark(); + return lastResult; + } else if (element.isWatermarkStatus()) { + watermarkStatus = element.asWatermarkStatus(); + return lastResult; + } + } + // loop is only re-executed for filtered full records + } while (!lastResult.isBufferConsumed()); + return DeserializationResult.PARTIAL_RECORD; + } + + /** + * Sets the next buffer to be deserialized. + * + * @param buffer The buffer containing serialized records. + * @throws IOException If an I/O error occurs. + */ + public void setNextBuffer(Buffer buffer) throws IOException { + deserializer.setNextBuffer(buffer); + } + + /** Clears the deserializer state. */ + public void clear() { + deserializer.clear(); + } + + /** + * Checks if there is partial data remaining in the buffer. + * + * @return true if the last result indicates the buffer was not fully consumed. + */ + public boolean hasPartialData() { + return lastResult != null && !lastResult.isBufferConsumed(); + } + + /** + * Gets the last watermark received on this virtual channel. 
+ * + * @return The last watermark, or {@link Watermark#UNINITIALIZED} if none received yet. + */ + public Watermark getLastWatermark() { + return lastWatermark; + } + + /** + * Gets the current watermark status of this virtual channel. + * + * @return The current watermark status. + */ + public WatermarkStatus getWatermarkStatus() { + return watermarkStatus; + } +} From ca52afa867b02ca6df236b5e94241f9831776b7d Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Thu, 15 Jan 2026 17:41:54 +0100 Subject: [PATCH 3/4] [FLINK-38541][checkpoint] Introducing config option: execution.checkpointing.unaligned.during-recovery.enabled --- .../common_checkpointing_section.html | 6 ++ .../configuration/CheckpointingOptions.java | 55 +++++++++++++++++++ .../CheckpointingOptionsTest.java | 46 ++++++++++++++++ 3 files changed, 107 insertions(+) diff --git a/docs/layouts/shortcodes/generated/common_checkpointing_section.html b/docs/layouts/shortcodes/generated/common_checkpointing_section.html index 9db118174a328..f942529b79c71 100644 --- a/docs/layouts/shortcodes/generated/common_checkpointing_section.html +++ b/docs/layouts/shortcodes/generated/common_checkpointing_section.html @@ -56,5 +56,11 @@ Integer The maximum number of completed checkpoints to retain. + +

execution.checkpointing.unaligned.recover-output-on-downstream.enabled
+ false + Boolean + Whether recovering output buffers of upstream task on downstream task directly when job restores from the unaligned checkpoint. + diff --git a/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java index ce205077474db..69ec46ed85fc2 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java @@ -659,6 +659,7 @@ public class CheckpointingOptions { + "Each subtask will create a new channel state file when this is configured to 1."); @Experimental + @Documentation.Section(Documentation.Sections.COMMON_CHECKPOINTING) public static final ConfigOption UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM = ConfigOptions.key( "execution.checkpointing.unaligned.recover-output-on-downstream.enabled") @@ -668,6 +669,40 @@ public class CheckpointingOptions { "Whether recovering output buffers of upstream task on downstream task directly " + "when job restores from the unaligned checkpoint."); + /** + * Whether to enable checkpointing during recovery from an unaligned checkpoint. + * + *

When enabled, the job can take checkpoints while still recovering channel state (inflight + * data) from a previous unaligned checkpoint. This avoids the need to wait for full recovery + * before the first checkpoint can be triggered, which reduces the window of vulnerability to + * failures during recovery. + * + *

This option requires {@link #UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM} to be enabled. It + * does not require unaligned checkpoints to be currently enabled, because a job may restore + * from an unaligned checkpoint while having unaligned checkpoints disabled for the new + * execution. + */ + @Experimental + @Documentation.ExcludeFromDocumentation( + "This option is not yet ready for public use, will be documented in a follow-up commit") + public static final ConfigOption UNALIGNED_DURING_RECOVERY_ENABLED = + ConfigOptions.key("execution.checkpointing.unaligned.during-recovery.enabled") + .booleanType() + .defaultValue(false) + .withDescription( + Description.builder() + .text( + "Whether to enable checkpointing during recovery from an unaligned checkpoint. " + + "When enabled, the job can take checkpoints while still recovering channel state " + + "(inflight data), reducing the window of vulnerability to failures during recovery.") + .linebreak() + .linebreak() + .text( + "This option requires %s to be enabled.", + TextElement.code( + UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM.key())) + .build()); + /** * Determines whether checkpointing is enabled based on the configuration. * @@ -763,4 +798,24 @@ public static boolean isUnalignedCheckpointInterruptibleTimersEnabled(Configurat } return config.get(ENABLE_UNALIGNED_INTERRUPTIBLE_TIMERS); } + + /** + * Determines whether unaligned checkpoint support during recovery is enabled. + * + *

This feature requires {@link #UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM} to be enabled. Note + * that it does not require unaligned checkpoints to be currently enabled, because a job may + * restore from an unaligned checkpoint while having unaligned checkpoints disabled for the new + * execution. + * + * @param config the configuration to check + * @return {@code true} if unaligned checkpoint during recovery is enabled, {@code false} + * otherwise + */ + @Internal + public static boolean isUnalignedDuringRecoveryEnabled(Configuration config) { + if (!config.get(UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM)) { + return false; + } + return config.get(UNALIGNED_DURING_RECOVERY_ENABLED); + } } diff --git a/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java b/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java index 4104dfb762b87..747f789af3599 100644 --- a/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java +++ b/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java @@ -328,4 +328,50 @@ void testIsUnalignedCheckpointInterruptibleTimersEnabled() { .as("Interruptible timers should be disabled when runtime mode is BATCH") .isFalse(); } + + @Test + void testIsUnalignedDuringRecoveryEnabled() { + // Test when both options are disabled (default) - should return false + Configuration defaultConfig = new Configuration(); + assertThat(CheckpointingOptions.isUnalignedDuringRecoveryEnabled(defaultConfig)) + .as("During-recovery should be disabled by default") + .isFalse(); + + // Test when during-recovery is enabled but recover-output-on-downstream is disabled + Configuration onlyDuringRecoveryConfig = new Configuration(); + onlyDuringRecoveryConfig.set(CheckpointingOptions.UNALIGNED_DURING_RECOVERY_ENABLED, true); + assertThat(CheckpointingOptions.isUnalignedDuringRecoveryEnabled(onlyDuringRecoveryConfig)) + .as( + "During-recovery should be disabled 
when recover-output-on-downstream is not enabled") + .isFalse(); + + // Test when recover-output-on-downstream is enabled but during-recovery is disabled + Configuration onlyRecoverOnDownstreamConfig = new Configuration(); + onlyRecoverOnDownstreamConfig.set( + CheckpointingOptions.UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM, true); + assertThat( + CheckpointingOptions.isUnalignedDuringRecoveryEnabled( + onlyRecoverOnDownstreamConfig)) + .as("During-recovery should be disabled when during-recovery option is not enabled") + .isFalse(); + + // Test when both options are enabled - should return true + Configuration bothEnabledConfig = new Configuration(); + bothEnabledConfig.set(CheckpointingOptions.UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM, true); + bothEnabledConfig.set(CheckpointingOptions.UNALIGNED_DURING_RECOVERY_ENABLED, true); + assertThat(CheckpointingOptions.isUnalignedDuringRecoveryEnabled(bothEnabledConfig)) + .as( + "During-recovery should be enabled when both recover-output-on-downstream and during-recovery are enabled") + .isTrue(); + + // Test when recover-output-on-downstream is explicitly false and during-recovery is true + Configuration explicitlyDisabledConfig = new Configuration(); + explicitlyDisabledConfig.set( + CheckpointingOptions.UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM, false); + explicitlyDisabledConfig.set(CheckpointingOptions.UNALIGNED_DURING_RECOVERY_ENABLED, true); + assertThat(CheckpointingOptions.isUnalignedDuringRecoveryEnabled(explicitlyDisabledConfig)) + .as( + "During-recovery should be disabled when recover-output-on-downstream is explicitly false") + .isFalse(); + } } From 997e3a3c42c9b6fe78070fc2fcd35d4a32a617d7 Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Wed, 18 Feb 2026 21:25:46 +0100 Subject: [PATCH 4/4] [FLINK-38930][checkpoint] Filtering record before processing without spilling strategy Core filtering mechanism for recovered channel state buffers: - ChannelStateFilteringHandler with per-gate 
GateFilterHandler - RecordFilterContext with VirtualChannelRecordFilterFactory - Partial data check in SequentialChannelStateReaderImpl - Fix RecordFilterContext for Union downscale scenario --- .../channel/ChannelStateFilteringHandler.java | 418 ++++++++++++++++++ .../channel/RecoveredChannelStateHandler.java | 67 ++- .../channel/SequentialChannelStateReader.java | 13 +- .../SequentialChannelStateReaderImpl.java | 29 +- .../StreamMultipleInputProcessorFactory.java | 8 +- .../io/StreamTaskNetworkInputFactory.java | 9 +- .../io/StreamTwoInputProcessorFactory.java | 11 +- .../io/recovery/RecordFilterContext.java | 227 ++++++++++ .../VirtualChannelRecordFilterFactory.java | 123 ++++++ .../runtime/tasks/OneInputStreamTask.java | 6 +- .../streaming/runtime/tasks/StreamTask.java | 82 +++- ...InputChannelRecoveredStateHandlerTest.java | 6 +- .../SequentialChannelStateReaderImplTest.java | 3 +- .../state/ChannelPersistenceITCase.java | 3 +- .../io/recovery/RecordFilterContextTest.java | 193 ++++++++ ...VirtualChannelRecordFilterFactoryTest.java | 90 ++++ 16 files changed, 1263 insertions(+), 25 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java new file mode 
100644 index 0000000000000..fe4feb7e05e7a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.RescaleMappings; +import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult; +import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.plugable.DeserializationDelegate; +import 
org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilter; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; +import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel; +import org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory; +import org.apache.flink.streaming.runtime.streamrecord.StreamElement; +import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Filters recovered channel state buffers during the channel-state-unspilling phase, removing + * records that do not belong to the current subtask after rescaling. + * + *
<p>
Uses a per-gate architecture: each {@link InputGate} gets its own {@link GateFilterHandler} + * with the correct serializer, so multi-input tasks (e.g., TwoInputStreamTask) correctly + * deserialize different record types on different gates. + */ +@Internal +public class ChannelStateFilteringHandler { + + /** + * Handles record filtering for a single input gate. Each gate has its own serializer and set of + * virtual channels, allowing different gates to handle different record types independently. + */ + static class GateFilterHandler { + + private final Map> virtualChannels; + private final StreamElementSerializer serializer; + private final DeserializationDelegate deserializationDelegate; + private final DataOutputSerializer outputSerializer; + private final byte[] lengthBuffer = new byte[4]; + + GateFilterHandler( + Map> virtualChannels, + StreamElementSerializer serializer) { + this.virtualChannels = checkNotNull(virtualChannels); + this.serializer = checkNotNull(serializer); + this.deserializationDelegate = new NonReusingDeserializationDelegate<>(serializer); + this.outputSerializer = new DataOutputSerializer(128); + } + + /** + * Deserializes records from {@code sourceBuffer}, applies the virtual channel's record + * filter, and re-serializes the surviving records into new buffers. 
+ */ + List filterAndRewrite( + int oldSubtaskIndex, + int oldChannelIndex, + Buffer sourceBuffer, + BufferSupplier bufferSupplier) + throws IOException, InterruptedException { + + SubtaskConnectionDescriptor key = + new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex); + VirtualChannel vc = virtualChannels.get(key); + if (vc == null) { + throw new IllegalStateException( + "No VirtualChannel found for key: " + + key + + "; known channels are " + + virtualChannels.keySet()); + } + + vc.setNextBuffer(sourceBuffer); + + List filteredElements = new ArrayList<>(); + + while (true) { + DeserializationResult result = vc.getNextRecord(deserializationDelegate); + if (result.isFullRecord()) { + filteredElements.add(deserializationDelegate.getInstance()); + } + if (result.isBufferConsumed()) { + break; + } + } + + return serializeToBuffers(filteredElements, bufferSupplier); + } + + /** + * Serializes stream elements into buffers using the length-prefixed format (4-byte + * big-endian length + record bytes) expected by Flink's record deserializers. 
+ */ + private List serializeToBuffers( + List elements, BufferSupplier bufferSupplier) + throws IOException, InterruptedException { + + List resultBuffers = new ArrayList<>(); + + if (elements.isEmpty()) { + return resultBuffers; + } + + Buffer currentBuffer = bufferSupplier.requestBufferBlocking(); + + for (StreamElement element : elements) { + outputSerializer.clear(); + serializer.serialize(element, outputSerializer); + int recordLength = outputSerializer.length(); + + writeLengthToBuffer(recordLength); + currentBuffer = + writeDataToBuffer( + lengthBuffer, 0, 4, currentBuffer, resultBuffers, bufferSupplier); + + byte[] serializedData = outputSerializer.getSharedBuffer(); + currentBuffer = + writeDataToBuffer( + serializedData, + 0, + recordLength, + currentBuffer, + resultBuffers, + bufferSupplier); + } + + if (currentBuffer.readableBytes() > 0) { + resultBuffers.add(currentBuffer.retainBuffer()); + } + currentBuffer.recycleBuffer(); + + return resultBuffers; + } + + private void writeLengthToBuffer(int length) { + lengthBuffer[0] = (byte) (length >> 24); + lengthBuffer[1] = (byte) (length >> 16); + lengthBuffer[2] = (byte) (length >> 8); + lengthBuffer[3] = (byte) length; + } + + /** + * Writes data to the current buffer, spilling into new buffers from {@code bufferSupplier} + * when the current one is full. + * + * @return the buffer to continue writing into (may differ from the input buffer). 
+ */ + private Buffer writeDataToBuffer( + byte[] data, + int dataOffset, + int dataLength, + Buffer currentBuffer, + List resultBuffers, + BufferSupplier bufferSupplier) + throws IOException, InterruptedException { + int offset = dataOffset; + int remaining = dataLength; + + while (remaining > 0) { + int writableBytes = currentBuffer.getMaxCapacity() - currentBuffer.getSize(); + + if (writableBytes == 0) { + if (currentBuffer.readableBytes() > 0) { + resultBuffers.add(currentBuffer.retainBuffer()); + } + currentBuffer.recycleBuffer(); + currentBuffer = bufferSupplier.requestBufferBlocking(); + writableBytes = currentBuffer.getMaxCapacity(); + } + + int bytesToWrite = Math.min(remaining, writableBytes); + currentBuffer + .getMemorySegment() + .put( + currentBuffer.getMemorySegmentOffset() + currentBuffer.getSize(), + data, + offset, + bytesToWrite); + currentBuffer.setSize(currentBuffer.getSize() + bytesToWrite); + + offset += bytesToWrite; + remaining -= bytesToWrite; + } + return currentBuffer; + } + + boolean hasPartialData() { + return virtualChannels.values().stream().anyMatch(VirtualChannel::hasPartialData); + } + + void clear() { + virtualChannels.values().forEach(VirtualChannel::clear); + } + } + + // Wildcard allows heterogeneous record types across gates. + private final GateFilterHandler[] gateHandlers; + + ChannelStateFilteringHandler(GateFilterHandler[] gateHandlers) { + this.gateHandlers = checkNotNull(gateHandlers); + } + + /** + * Creates a handler from the recovery context, building per-gate virtual channels based on + * rescaling descriptors. Returns {@code null} if no filtering is needed (e.g., source tasks or + * no rescaling). 
+ */ + @Nullable + public static ChannelStateFilteringHandler createFromContext( + RecordFilterContext filterContext, InputGate[] inputGates) { + // Source tasks have no network inputs + if (filterContext.getNumberOfGates() == 0) { + return null; + } + + InflightDataRescalingDescriptor rescalingDescriptor = + filterContext.getRescalingDescriptor(); + + GateFilterHandler[] gateHandlers = new GateFilterHandler[inputGates.length]; + boolean hasAnyVirtualChannels = false; + + for (int gateIndex = 0; gateIndex < inputGates.length; gateIndex++) { + gateHandlers[gateIndex] = + createGateHandler(filterContext, inputGates, rescalingDescriptor, gateIndex); + if (gateHandlers[gateIndex] != null) { + hasAnyVirtualChannels = true; + } + } + + if (!hasAnyVirtualChannels) { + return null; + } + + return new ChannelStateFilteringHandler(gateHandlers); + } + + /** + * Creates a {@link GateFilterHandler} for a single gate. The method-level type parameter + * ensures type safety within each gate while allowing different gates to have different types. + */ + @SuppressWarnings("unchecked") + @Nullable + private static GateFilterHandler createGateHandler( + RecordFilterContext filterContext, + InputGate[] inputGates, + InflightDataRescalingDescriptor rescalingDescriptor, + int gateIndex) { + RecordFilterContext.InputFilterConfig inputConfig = filterContext.getInputConfig(gateIndex); + if (inputConfig == null) { + throw new IllegalStateException( + "No InputFilterConfig for gateIndex " + + gateIndex + + ". 
This indicates a bug in RecordFilterContext initialization."); + } + + InputGate gate = inputGates[gateIndex]; + int[] oldSubtaskIndexes = rescalingDescriptor.getOldSubtaskIndexes(gateIndex); + RescaleMappings channelMapping = rescalingDescriptor.getChannelMapping(gateIndex); + + TypeSerializer typeSerializer = (TypeSerializer) inputConfig.getTypeSerializer(); + StreamElementSerializer elementSerializer = + new StreamElementSerializer<>(typeSerializer); + + VirtualChannelRecordFilterFactory filterFactory = + VirtualChannelRecordFilterFactory.fromContext(filterContext, gateIndex); + + Map> gateVirtualChannels = new HashMap<>(); + + for (int oldSubtaskIndex : oldSubtaskIndexes) { + int numChannels = gate.getNumberOfInputChannels(); + int[] oldChannelIndexes = getOldChannelIndexes(channelMapping, numChannels); + + for (int oldChannelIndex : oldChannelIndexes) { + SubtaskConnectionDescriptor key = + new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex); + + if (gateVirtualChannels.containsKey(key)) { + continue; + } + + // Only ambiguous channels need actual filtering; non-ambiguous ones pass through + boolean isAmbiguous = rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex); + + RecordFilter recordFilter = + isAmbiguous + ? filterFactory.createFilter() + : VirtualChannelRecordFilterFactory.createPassThroughFilter(); + + RecordDeserializer> deserializer = + createDeserializer(filterContext.getTmpDirectories()); + + VirtualChannel vc = new VirtualChannel<>(deserializer, recordFilter); + gateVirtualChannels.put(key, vc); + } + } + + if (gateVirtualChannels.isEmpty()) { + return null; + } + + return new GateFilterHandler<>(gateVirtualChannels, elementSerializer); + } + + /** + * Collects all old channel indexes that are mapped from any new channel index in this gate. + * channelMapping is new-to-old, so we iterate new indexes and collect their old counterparts. 
+ */ + private static int[] getOldChannelIndexes(RescaleMappings channelMapping, int numChannels) { + List oldIndexes = new ArrayList<>(); + for (int newIndex = 0; newIndex < numChannels; newIndex++) { + int[] mapped = channelMapping.getMappedIndexes(newIndex); + for (int oldIndex : mapped) { + if (!oldIndexes.contains(oldIndex)) { + oldIndexes.add(oldIndex); + } + } + } + return oldIndexes.stream().mapToInt(Integer::intValue).toArray(); + } + + private static RecordDeserializer> createDeserializer( + String[] tmpDirectories) { + if (tmpDirectories != null && tmpDirectories.length > 0) { + return new SpillingAdaptiveSpanningRecordDeserializer<>(tmpDirectories); + } else { + String[] defaultDirs = new String[] {System.getProperty("java.io.tmpdir")}; + return new SpillingAdaptiveSpanningRecordDeserializer<>(defaultDirs); + } + } + + /** + * Filters a recovered buffer from the specified virtual channel, returning new buffers + * containing only the records that belong to the current subtask. + * + * @return filtered buffers, possibly empty if all records were filtered out. + */ + public List filterAndRewrite( + int gateIndex, + int oldSubtaskIndex, + int oldChannelIndex, + Buffer sourceBuffer, + BufferSupplier bufferSupplier) + throws IOException, InterruptedException { + + if (gateIndex < 0 || gateIndex >= gateHandlers.length) { + throw new IllegalStateException( + "Invalid gateIndex: " + + gateIndex + + ", number of gates: " + + gateHandlers.length); + } + + GateFilterHandler gateHandler = gateHandlers[gateIndex]; + if (gateHandler == null) { + throw new IllegalStateException( + "No handler for gateIndex " + + gateIndex + + ". This gate is not a network input and should not have recovered buffers."); + } + return gateHandler.filterAndRewrite( + oldSubtaskIndex, oldChannelIndex, sourceBuffer, bufferSupplier); + } + + /** Returns {@code true} if any virtual channel has a partial (spanning) record pending. 
*/ + public boolean hasPartialData() { + for (GateFilterHandler handler : gateHandlers) { + if (handler != null && handler.hasPartialData()) { + return true; + } + } + return false; + } + + public void clear() { + for (GateFilterHandler handler : gateHandlers) { + if (handler != null) { + handler.clear(); + } + } + } + + /** Provides buffers for re-serializing filtered records. Implementations may block. */ + @FunctionalInterface + public interface BufferSupplier { + Buffer requestBufferBlocking() throws IOException, InterruptedException; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java index 85fc31db4bc69..1c3b19e33d3ba 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java @@ -31,6 +31,7 @@ import org.apache.flink.runtime.io.network.partition.consumer.RecoveredInputChannel; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.io.IOException; import java.util.HashMap; @@ -63,7 +64,7 @@ public void close() { * case of an error. */ void recover(Info info, int oldSubtaskIndex, BufferWithContext bufferWithContext) - throws IOException; + throws IOException, InterruptedException; } class InputChannelRecoveredStateHandler @@ -75,10 +76,19 @@ class InputChannelRecoveredStateHandler private final Map rescaledChannels = new HashMap<>(); private final Map oldToNewMappings = new HashMap<>(); + /** + * Optional filtering handler for filtering recovered buffers. When non-null, filtering is + * performed during recovery in the channel-state-unspilling thread. 
+ */ + @Nullable private final ChannelStateFilteringHandler filteringHandler; + InputChannelRecoveredStateHandler( - InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping) { + InputGate[] inputGates, + InflightDataRescalingDescriptor channelMapping, + @Nullable ChannelStateFilteringHandler filteringHandler) { this.inputGates = inputGates; this.channelMapping = channelMapping; + this.filteringHandler = filteringHandler; } @Override @@ -95,23 +105,60 @@ public void recover( InputChannelInfo channelInfo, int oldSubtaskIndex, BufferWithContext bufferWithContext) - throws IOException { + throws IOException, InterruptedException { Buffer buffer = bufferWithContext.context; try { if (buffer.readableBytes() > 0) { RecoveredInputChannel channel = getMappedChannels(channelInfo); - channel.onRecoveredStateBuffer( - EventSerializer.toBuffer( - new SubtaskConnectionDescriptor( - oldSubtaskIndex, channelInfo.getInputChannelIdx()), - false)); - channel.onRecoveredStateBuffer(buffer.retainBuffer()); + + if (filteringHandler != null) { + // Filtering mode: filter records and rewrite to new buffers + recoverWithFiltering(channel, channelInfo, oldSubtaskIndex, buffer); + } else { + // Non-filtering mode: pass through original buffer with descriptor + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer( + new SubtaskConnectionDescriptor( + oldSubtaskIndex, channelInfo.getInputChannelIdx()), + false)); + channel.onRecoveredStateBuffer(buffer.retainBuffer()); + } } } finally { buffer.recycleBuffer(); } } + private void recoverWithFiltering( + RecoveredInputChannel channel, + InputChannelInfo channelInfo, + int oldSubtaskIndex, + Buffer buffer) + throws IOException, InterruptedException { + checkState(filteringHandler != null, "filtering handler not set."); + // Extra retain: filterAndRewrite consumes one ref, caller's finally releases another. 
+ buffer.retainBuffer(); + + List filteredBuffers; + try { + filteredBuffers = + filteringHandler.filterAndRewrite( + channelInfo.getGateIdx(), + oldSubtaskIndex, + channelInfo.getInputChannelIdx(), + buffer, + channel::requestBufferBlocking); + } catch (Throwable t) { + // filterAndRewrite didn't consume the buffer, release the extra ref. + buffer.recycleBuffer(); + throw t; + } + + for (Buffer filteredBuffer : filteredBuffers) { + channel.onRecoveredStateBuffer(filteredBuffer); + } + } + @Override public void close() throws IOException { // note that we need to finish all RecoveredInputChannels, not just those with state @@ -191,7 +238,7 @@ public void recover( ResultSubpartitionInfo subpartitionInfo, int oldSubtaskIndex, BufferWithContext bufferWithContext) - throws IOException { + throws IOException, InterruptedException { try (BufferBuilder bufferBuilder = bufferWithContext.context; BufferConsumer bufferConsumer = bufferBuilder.createBufferConsumerFromBeginning()) { bufferBuilder.finish(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java index 7adf6d6294671..547b60ef93aee 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java @@ -20,6 +20,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import java.io.IOException; @@ -27,7 +28,14 @@ @Internal public interface SequentialChannelStateReader extends AutoCloseable { - void readInputData(InputGate[] inputGates) throws IOException, 
InterruptedException; + /** + * Reads input channel state with filtering support. + * + * @param inputGates The input gates to recover state for. + * @param filterContext The filter context containing input configs and rescaling info. + */ + void readInputData(InputGate[] inputGates, RecordFilterContext filterContext) + throws IOException, InterruptedException; void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion) throws IOException, InterruptedException; @@ -39,7 +47,8 @@ void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCom new SequentialChannelStateReader() { @Override - public void readInputData(InputGate[] inputGates) {} + public void readInputData( + InputGate[] inputGates, RecordFilterContext filterContext) {} @Override public void readOutputData( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java index 3daa4b4947a61..ca0ae8a8385a5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java @@ -28,6 +28,7 @@ import org.apache.flink.runtime.state.AbstractChannelStateHandle; import org.apache.flink.runtime.state.ChannelStateHelper; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import java.io.Closeable; import java.io.IOException; @@ -43,6 +44,7 @@ import static java.util.Comparator.comparingLong; import static java.util.stream.Collectors.groupingBy; import static java.util.stream.Collectors.toList; +import static org.apache.flink.util.Preconditions.checkState; /** {@link SequentialChannelStateReader} implementation. 
*/ public class SequentialChannelStateReaderImpl implements SequentialChannelStateReader { @@ -58,10 +60,21 @@ public SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot) { } @Override - public void readInputData(InputGate[] inputGates) throws IOException, InterruptedException { + public void readInputData(InputGate[] inputGates, RecordFilterContext filterContext) + throws IOException, InterruptedException { + + // Create filtering handler if filtering is needed + ChannelStateFilteringHandler filteringHandler = null; + if (filterContext.isUnalignedDuringRecoveryEnabled()) { + filteringHandler = + ChannelStateFilteringHandler.createFromContext(filterContext, inputGates); + } + try (InputChannelRecoveredStateHandler stateHandler = new InputChannelRecoveredStateHandler( - inputGates, taskStateSnapshot.getInputRescalingDescriptor())) { + inputGates, + taskStateSnapshot.getInputRescalingDescriptor(), + filteringHandler)) { read( stateHandler, groupByDelegate( @@ -72,6 +85,18 @@ public void readInputData(InputGate[] inputGates) throws IOException, Interrupte groupByDelegate( streamSubtaskStates(), OperatorSubtaskState::getUpstreamOutputBufferState)); + + if (filteringHandler != null) { + checkState( + !filteringHandler.hasPartialData(), + "Not all data has been fully consumed during filtering"); + } + } finally { + // Clean up filtering handler resources (e.g., temp files from + // SpillingAdaptiveSpanningRecordDeserializer) on both success and error paths + if (filteringHandler != null) { + filteringHandler.clear(); + } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java index 78877a3d62ec6..db92afd2f4294 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java +++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.TaskInfo; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.CheckpointingOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.memory.ManagedMemoryUseCase; import org.apache.flink.metrics.Counter; @@ -103,6 +104,10 @@ public static StreamMultipleInputProcessor create( "Number of configured inputs in StreamConfig [%s] doesn't match the main operator's number of inputs [%s]", configuredInputs.length, inputsCount); + + boolean unalignedDuringRecoveryEnabled = + CheckpointingOptions.isUnalignedDuringRecoveryEnabled(jobConfig); + StreamTaskInput[] inputs = new StreamTaskInput[inputsCount]; for (int i = 0; i < inputsCount; i++) { StreamConfig.InputConfig configuredInput = configuredInputs[i]; @@ -121,7 +126,8 @@ public static StreamMultipleInputProcessor create( gatePartitioners, taskInfo, canEmitBatchOfRecords, - streamConfig.getWatermarkDeclarations(userClassloader)); + streamConfig.getWatermarkDeclarations(userClassloader), + unalignedDuringRecoveryEnabled); } else if (configuredInput instanceof StreamConfig.SourceInputConfig) { StreamConfig.SourceInputConfig sourceInput = (StreamConfig.SourceInputConfig) configuredInput; diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java index 46c9cd96936e3..7718419f6bb55 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java @@ -47,9 +47,14 @@ public static StreamTaskInput create( Function> gatePartitioners, 
TaskInfo taskInfo, CanEmitBatchOfRecordsChecker canEmitBatchOfRecords, - Set> watermarkDeclarationSet) { + Set> watermarkDeclarationSet, + boolean unalignedDuringRecoveryEnabled) { return rescalingDescriptor.equals( - InflightDataRescalingDescriptor.NO_RESCALE) + InflightDataRescalingDescriptor.NO_RESCALE) + // When filter during recovery is enabled, records are already filtered in + // the channel-state-unspilling thread. Use StreamTaskNetworkInput to avoid + // redundant demultiplexing/filtering in the Task thread. + || unalignedDuringRecoveryEnabled ? new StreamTaskNetworkInput<>( checkpointedInputGate, inputSerializer, diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java index 2a0c675710b34..03d2236325d0f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.TaskInfo; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.CheckpointingOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.memory.ManagedMemoryUseCase; import org.apache.flink.metrics.Counter; @@ -84,6 +85,10 @@ public static StreamMultipleInputProcessor create( checkNotNull(operatorChain); taskIOMetricGroup.reuseRecordsInputCounter(numRecordsIn); + + boolean unalignedDuringRecoveryEnabled = + CheckpointingOptions.isUnalignedDuringRecoveryEnabled(jobConfig); + TypeSerializer typeSerializer1 = streamConfig.getTypeSerializerIn(0, userClassloader); StreamTaskInput input1 = StreamTaskNetworkInputFactory.create( @@ -96,7 +101,8 @@ public static
StreamMultipleInputProcessor create( gatePartitioners, taskInfo, canEmitBatchOfRecords, - streamConfig.getWatermarkDeclarations(userClassloader)); + streamConfig.getWatermarkDeclarations(userClassloader), + unalignedDuringRecoveryEnabled); TypeSerializer typeSerializer2 = streamConfig.getTypeSerializerIn(1, userClassloader); StreamTaskInput input2 = StreamTaskNetworkInputFactory.create( @@ -109,7 +115,8 @@ public static StreamMultipleInputProcessor create( gatePartitioners, taskInfo, canEmitBatchOfRecords, - streamConfig.getWatermarkDeclarations(userClassloader)); + streamConfig.getWatermarkDeclarations(userClassloader), + unalignedDuringRecoveryEnabled); InputSelectable inputSelectable = streamOperator instanceof InputSelectable ? (InputSelectable) streamOperator : null; diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java new file mode 100644 index 0000000000000..e7edcb22ce6fc --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Context containing all information needed for filtering recovered channel state buffers. + * + *

This context encapsulates the input configurations, rescaling descriptor, and subtask + * information required by the channel-state-unspilling thread to perform record filtering during + * recovery. + * + *

Supports multiple inputs (e.g., TwoInputStreamTask, MultipleInputStreamTask) by storing a list + * of {@link InputFilterConfig} instances indexed by input index. + * + *

Use the constructor with empty inputConfigs or enabled=false when filtering is not needed. + */ +@Internal +public class RecordFilterContext { + + /** Configuration for filtering records on a specific input. */ + public static class InputFilterConfig { + private final TypeSerializer typeSerializer; + private final StreamPartitioner partitioner; + private final int numberOfChannels; + + /** + * Creates a new InputFilterConfig. + * + * @param typeSerializer Serializer for the record type. + * @param partitioner Partitioner used to determine record ownership. + * @param numberOfChannels The parallelism of the current operator. + */ + public InputFilterConfig( + TypeSerializer typeSerializer, + StreamPartitioner partitioner, + int numberOfChannels) { + this.typeSerializer = checkNotNull(typeSerializer); + this.partitioner = checkNotNull(partitioner); + this.numberOfChannels = numberOfChannels; + } + + public TypeSerializer getTypeSerializer() { + return typeSerializer; + } + + public StreamPartitioner getPartitioner() { + return partitioner; + } + + public int getNumberOfChannels() { + return numberOfChannels; + } + } + + /** + * Input configurations indexed by gate index. Array elements may be null for non-network inputs + * (e.g., SourceInputConfig). The array length equals the total number of input gates. + */ + private final InputFilterConfig[] inputConfigs; + + /** Descriptor containing rescaling information. Never null. */ + private final InflightDataRescalingDescriptor rescalingDescriptor; + + /** Current subtask index. */ + private final int subtaskIndex; + + /** Maximum parallelism for configuring partitioners. */ + private final int maxParallelism; + + /** Temporary directories for spilling spanning records. Can be empty but never null. */ + private final String[] tmpDirectories; + + /** Whether unaligned checkpoint during recovery is enabled. */ + private final boolean unalignedDuringRecoveryEnabled; + + /** + * Creates a new RecordFilterContext. 
+ * + * @param inputConfigs Input configurations indexed by gate index. Array elements may be null + * for non-network inputs. Not null itself. + * @param rescalingDescriptor Descriptor containing rescaling information. Not null. + * @param subtaskIndex Current subtask index. + * @param maxParallelism Maximum parallelism. + * @param tmpDirectories Temporary directories for spilling spanning records. Can be null + * (converted to empty array). + * @param unalignedDuringRecoveryEnabled Whether unaligned checkpoint during recovery is + * enabled. + */ + public RecordFilterContext( + InputFilterConfig[] inputConfigs, + InflightDataRescalingDescriptor rescalingDescriptor, + int subtaskIndex, + int maxParallelism, + String[] tmpDirectories, + boolean unalignedDuringRecoveryEnabled) { + this.inputConfigs = checkNotNull(inputConfigs).clone(); + this.rescalingDescriptor = checkNotNull(rescalingDescriptor); + this.subtaskIndex = subtaskIndex; + this.maxParallelism = maxParallelism; + this.tmpDirectories = tmpDirectories != null ? tmpDirectories : new String[0]; + this.unalignedDuringRecoveryEnabled = unalignedDuringRecoveryEnabled; + } + + /** + * Gets the input configuration for a specific gate. + * + * @param gateIndex The gate index (0-based). + * @return The input configuration for the specified gate, or null if the gate is not a network + * input (e.g., SourceInputConfig). + * @throws IllegalArgumentException if gateIndex is out of bounds. + */ + public InputFilterConfig getInputConfig(int gateIndex) { + checkArgument( + gateIndex >= 0 && gateIndex < inputConfigs.length, + "Invalid gate index: %s, number of gates: %s", + gateIndex, + inputConfigs.length); + return inputConfigs[gateIndex]; + } + + /** + * Gets the number of input gates. + * + * @return The number of input gates. + */ + public int getNumberOfGates() { + return inputConfigs.length; + } + + /** + * Checks whether unaligned checkpoint during recovery is enabled. 
+ * + * @return {@code true} if enabled, {@code false} otherwise. + */ + public boolean isUnalignedDuringRecoveryEnabled() { + return unalignedDuringRecoveryEnabled; + } + + /** + * Gets the rescaling descriptor. + * + * @return The descriptor containing rescaling information. + */ + public InflightDataRescalingDescriptor getRescalingDescriptor() { + return rescalingDescriptor; + } + + /** + * Gets the current subtask index. + * + * @return The subtask index. + */ + public int getSubtaskIndex() { + return subtaskIndex; + } + + /** + * Gets the maximum parallelism. + * + * @return The maximum parallelism value. + */ + public int getMaxParallelism() { + return maxParallelism; + } + + /** + * Gets the temporary directories for spilling spanning records. + * + * @return The temporary directories, never null (may be empty array). + */ + public String[] getTmpDirectories() { + return tmpDirectories; + } + + /** + * Checks if a specific gate and subtask combination is ambiguous (requires filtering). + * + * @param gateIndex The gate index. + * @param oldSubtaskIndex The old subtask index. + * @return true if enabled and the channel is ambiguous and records need filtering. + */ + public boolean isAmbiguous(int gateIndex, int oldSubtaskIndex) { + return unalignedDuringRecoveryEnabled + && rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex); + } + + /** + * Creates a disabled RecordFilterContext for testing or when filtering is not needed. + * + *

The returned context has empty inputConfigs and enabled=false, so {@link + * #isUnalignedDuringRecoveryEnabled()} will always return false. + * + * @return A disabled RecordFilterContext. + */ + public static RecordFilterContext disabled() { + return new RecordFilterContext( + new InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 0, + new String[0], + false); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java new file mode 100644 index 0000000000000..a2093f1143707 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.io.network.api.writer.ChannelSelector; +import org.apache.flink.runtime.plugable.SerializationDelegate; +import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +/** + * Factory for creating record filters used in Virtual Channels during channel state recovery. + * + *

This factory provides methods to create {@link RecordFilter} instances that determine whether + * a record belongs to the current subtask based on the partitioner logic. + * + * @param The type of record values. + */ +@Internal +public class VirtualChannelRecordFilterFactory { + + private final TypeSerializer typeSerializer; + private final StreamPartitioner partitioner; + private final int subtaskIndex; + private final int numberOfChannels; + private final int maxParallelism; + + /** + * Creates a new VirtualChannelRecordFilterFactory. + * + * @param typeSerializer Serializer for the record type. + * @param partitioner Partitioner used to determine record ownership. + * @param subtaskIndex Current subtask index. + * @param numberOfChannels Number of parallel subtasks. + * @param maxParallelism Maximum parallelism for configuring partitioners. + */ + public VirtualChannelRecordFilterFactory( + TypeSerializer typeSerializer, + StreamPartitioner partitioner, + int subtaskIndex, + int numberOfChannels, + int maxParallelism) { + this.typeSerializer = typeSerializer; + this.partitioner = partitioner; + this.subtaskIndex = subtaskIndex; + this.numberOfChannels = numberOfChannels; + this.maxParallelism = maxParallelism; + } + + /** + * Creates a new VirtualChannelRecordFilterFactory from a RecordFilterContext and input index. + * + * @param context The record filter context. + * @param inputIndex The input index to get configuration from. + * @param The type of record values. + * @return A new factory instance. 
+ */ + @SuppressWarnings("unchecked") + public static VirtualChannelRecordFilterFactory fromContext( + RecordFilterContext context, int inputIndex) { + RecordFilterContext.InputFilterConfig inputConfig = context.getInputConfig(inputIndex); + return new VirtualChannelRecordFilterFactory<>( + (TypeSerializer) inputConfig.getTypeSerializer(), + (StreamPartitioner) inputConfig.getPartitioner(), + context.getSubtaskIndex(), + inputConfig.getNumberOfChannels(), + context.getMaxParallelism()); + } + + /** + * Creates a record filter for ambiguous channels that requires actual filtering. + * + * @return A RecordFilter that tests if a record belongs to this subtask. + */ + public RecordFilter createFilter() { + StreamPartitioner configuredPartitioner = configurePartitioner(); + @SuppressWarnings("unchecked") + ChannelSelector>> channelSelector = + configuredPartitioner; + return new PartitionerRecordFilter<>(channelSelector, typeSerializer, subtaskIndex); + } + + /** + * Creates a pass-through filter that accepts all records. + * + * @param The type of record values. + * @return A RecordFilter that always returns true. + */ + public static RecordFilter createPassThroughFilter() { + return RecordFilter.acceptAll(); + } + + /** + * Configures the partitioner with the correct number of channels and max parallelism. + * + * @return A configured copy of the partitioner. 
+ */ + private StreamPartitioner configurePartitioner() { + StreamPartitioner copy = partitioner.copy(); + copy.setup(numberOfChannels); + if (copy instanceof ConfigurableStreamPartitioner) { + ((ConfigurableStreamPartitioner) copy).configure(maxParallelism); + } + return copy; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java index 353582674e01a..676175514d922 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java @@ -203,6 +203,9 @@ private StreamTaskInput createTaskInput(CheckpointedInputGate inputGate) { Set> watermarkDeclarationSet = configuration.getWatermarkDeclarations(getUserCodeClassLoader()); + boolean unalignedDuringRecoveryEnabled = + CheckpointingOptions.isUnalignedDuringRecoveryEnabled(getJobConfiguration()); + return StreamTaskNetworkInputFactory.create( inputGate, inSerializer, @@ -217,7 +220,8 @@ private StreamTaskInput createTaskInput(CheckpointedInputGate inputGate) { .getPartitioner(), getEnvironment().getTaskInfo(), getCanEmitBatchOfRecords(), - watermarkDeclarationSet); + watermarkDeclarationSet, + unalignedDuringRecoveryEnabled); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index a07d5ee3915aa..5eb6e0ebea4ed 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -21,6 +21,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.operators.MailboxExecutor; import 
org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.CheckpointingOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.NettyShuffleEnvironmentOptions; @@ -84,6 +85,7 @@ import org.apache.flink.runtime.util.ConfigurationParserUtils; import org.apache.flink.streaming.api.graph.NonChainedOutput; import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.graph.StreamEdge; import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; import org.apache.flink.streaming.api.operators.InternalTimeServiceManagerImpl; import org.apache.flink.streaming.api.operators.StreamOperator; @@ -94,6 +96,7 @@ import org.apache.flink.streaming.runtime.io.StreamInputProcessor; import org.apache.flink.streaming.runtime.io.checkpointing.BarrierAlignmentUtil; import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHandler; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner; import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner; @@ -881,7 +884,7 @@ private CompletableFuture restoreStateAndGates( channelIOExecutor.execute( () -> { try { - reader.readInputData(inputGates); + reader.readInputData(inputGates, createRecordFilterContext()); } catch (Exception e) { asyncExceptionHandler.handleAsyncException( "Unable to read channel state", e); @@ -1956,6 +1959,83 @@ public final Environment getEnvironment() { return environment; } + /** + * Creates a RecordFilterContext for filtering recovered channel state buffers. + * + *

This method builds the complete context using information available in StreamTask, + * including input configurations for all network inputs. + * + * @return A RecordFilterContext with input configurations. The context may have empty + * inputConfigs (e.g., for source tasks) or enabled=false when filtering is not needed. + */ + protected RecordFilterContext createRecordFilterContext() { + boolean unalignedDuringRecoveryEnabled = + CheckpointingOptions.isUnalignedDuringRecoveryEnabled(getJobConfiguration()); + if (!unalignedDuringRecoveryEnabled) { + return RecordFilterContext.disabled(); + } + + ClassLoader cl = getUserCodeClassLoader(); + StreamConfig.InputConfig[] inputs = configuration.getInputs(cl); + List inEdges = configuration.getInPhysicalEdges(cl); + + // Create array sized to match the number of physical input gates. + // For source tasks, this will be 0. For tasks with network inputs, each physical gate + // must have a corresponding config entry. + int numGates = getEnvironment().getAllInputGates().length; + RecordFilterContext.InputFilterConfig[] inputConfigs = + new RecordFilterContext.InputFilterConfig[numGates]; + + Preconditions.checkState( + numGates == inEdges.size(), + "Number of input gates (%s) does not match number of physical edges (%s)", + numGates, + inEdges.size()); + + // Iterate through all physical edges (inEdges) instead of logical inputs. + // This is critical for Union scenarios where multiple physical gates map to one logical + // input. The order of inEdges matches the order of physical input gates. + int numberOfChannels = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(); + for (int gateIndex = 0; gateIndex < inEdges.size(); gateIndex++) { + StreamEdge edge = inEdges.get(gateIndex); + // Calculate logical input index from typeNumber + // typeNumber = 0 means single input, typeNumber >= 1 means multi-input (1-indexed) + int inputIndex = edge.getTypeNumber() == 0 ? 
0 : edge.getTypeNumber() - 1; + + Preconditions.checkState( + inputIndex < inputs.length + && inputs[inputIndex] instanceof StreamConfig.NetworkInputConfig, + "Physical edge at gateIndex %s has invalid inputIndex %s or non-network input", + gateIndex, + inputIndex); + + StreamConfig.NetworkInputConfig networkInput = + (StreamConfig.NetworkInputConfig) inputs[inputIndex]; + TypeSerializer typeSerializer = networkInput.getTypeSerializer(); + StreamPartitioner partitioner = edge.getPartitioner(); + + inputConfigs[gateIndex] = + new RecordFilterContext.InputFilterConfig( + typeSerializer, partitioner, numberOfChannels); + } + + for (int i = 0; i < inputConfigs.length; i++) { + Preconditions.checkState( + inputConfigs[i] != null, + "InputFilterConfig at index %s is null. " + + "All physical gates must have corresponding configurations.", + i); + } + + return new RecordFilterContext( + inputConfigs, + getEnvironment().getTaskStateManager().getInputRescalingDescriptor(), + getEnvironment().getTaskInfo().getIndexOfThisSubtask(), + getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks(), + getEnvironment().getIOManager().getSpillingDirectoriesPaths(), + true); + } + /** Check whether records can be emitted in batch. 
*/ @FunctionalInterface public interface CanEmitBatchOfRecordsChecker { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java index e2b1d69c56ab1..39ce6c7d4bf7d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java @@ -77,7 +77,8 @@ private InputChannelRecoveredStateHandler buildInputChannelStateHandler( InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor .MappingType.IDENTITY) - })); + }), + null); } private InputChannelRecoveredStateHandler buildMultiChannelHandler() { @@ -103,7 +104,8 @@ private InputChannelRecoveredStateHandler buildMultiChannelHandler() { InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor .MappingType.RESCALING) - })); + }), + null); } @Test diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java index 5f05af0d92e59..d80442b8a06f8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java @@ -40,6 +40,7 @@ import org.apache.flink.runtime.state.InputChannelStateHandle; import org.apache.flink.runtime.state.ResultSubpartitionStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import 
org.apache.flink.testutils.junit.extensions.parameterized.Parameter; import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension; import org.apache.flink.testutils.junit.extensions.parameterized.Parameters; @@ -143,7 +144,7 @@ void testReadPermutedState() throws Exception { withInputGates( gates -> { - reader.readInputData(gates); + reader.readInputData(gates, RecordFilterContext.disabled()); assertBuffersEquals(inputChannelsData, collectBuffers(gates)); assertConsumed(gates); }); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java index 53c05b1961309..f64a4d9fb9cac 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java @@ -49,6 +49,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import org.apache.flink.util.function.SupplierWithException; import org.junit.jupiter.api.Test; @@ -119,7 +120,7 @@ void testReadWritten() throws Exception { try { int numChannels = 1; InputGate gate = buildGate(networkBufferPool, numChannels); - reader.readInputData(new InputGate[] {gate}); + reader.readInputData(new InputGate[] {gate}, RecordFilterContext.disabled()); assertThat(collectBytes(gate::pollNext, BufferOrEvent::getBuffer)) .isEqualTo(inputChannelInfoData); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java new file mode 100644 index 0000000000000..27832b4c3c528 --- /dev/null +++ 
b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.RescaleMappings; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; + +import org.junit.jupiter.api.Test; + +import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.mappings; +import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.rescalingDescriptor; +import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.set; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link RecordFilterContext}. 
*/ +class RecordFilterContextTest { + + @Test + void testDisabledContextHasNoGates() { + RecordFilterContext disabled = RecordFilterContext.disabled(); + assertThat(disabled.getNumberOfGates()).isEqualTo(0); + assertThat(disabled.isUnalignedDuringRecoveryEnabled()).isFalse(); + } + + @Test + void testGetInputConfigReturnsCorrectConfig() { + RecordFilterContext.InputFilterConfig config = + new RecordFilterContext.InputFilterConfig( + LongSerializer.INSTANCE, new ForwardPartitioner<>(), 4); + + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[] {config}, + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + new String[] {"/tmp"}, + true); + + assertThat(context.getNumberOfGates()).isEqualTo(1); + assertThat(context.getInputConfig(0)).isSameAs(config); + assertThat(context.getSubtaskIndex()).isEqualTo(0); + assertThat(context.getMaxParallelism()).isEqualTo(128); + assertThat(context.isUnalignedDuringRecoveryEnabled()).isTrue(); + } + + @Test + void testGetInputConfigThrowsForInvalidIndex() { + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + null, + false); + + assertThatThrownBy(() -> context.getInputConfig(0)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> context.getInputConfig(-1)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void testNullTmpDirectoriesConvertedToEmptyArray() { + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + null, + false); + + assertThat(context.getTmpDirectories()).isNotNull().isEmpty(); + } + + @Test + void testIsAmbiguousWhenDisabled() { + // Create a rescaling descriptor with an ambiguous subtask (oldSubtask 0 is ambiguous) + RescaleMappings mapping = mappings(new int[] {0}); + 
InflightDataRescalingDescriptor descriptor = + rescalingDescriptor(new int[] {0}, new RescaleMappings[] {mapping}, set(0)); + + // When unalignedDuringRecoveryEnabled is false, isAmbiguous should always return false + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + descriptor, + 0, + 128, + null, + false); + + assertThat(context.isAmbiguous(0, 0)).isFalse(); + } + + @Test + void testIsAmbiguousWhenEnabled() { + // Create a rescaling descriptor with an ambiguous subtask (oldSubtask 0 is ambiguous) + RescaleMappings mapping = mappings(new int[] {0}); + InflightDataRescalingDescriptor descriptor = + rescalingDescriptor(new int[] {0}, new RescaleMappings[] {mapping}, set(0)); + + // When unalignedDuringRecoveryEnabled is true, isAmbiguous follows rescalingDescriptor + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + descriptor, + 0, + 128, + null, + true); + + assertThat(context.isAmbiguous(0, 0)).isTrue(); + } + + @Test + void testIsAmbiguousForNonAmbiguousSubtask() { + // Create a rescaling descriptor where oldSubtask 0 is ambiguous but oldSubtask 1 is not + RescaleMappings mapping = mappings(new int[] {0}); + InflightDataRescalingDescriptor descriptor = + rescalingDescriptor(new int[] {0, 1}, new RescaleMappings[] {mapping}, set(0)); + + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + descriptor, + 0, + 128, + null, + true); + + // oldSubtask 0 is ambiguous + assertThat(context.isAmbiguous(0, 0)).isTrue(); + // oldSubtask 1 is NOT in the ambiguous set + assertThat(context.isAmbiguous(0, 1)).isFalse(); + } + + @Test + void testInputFilterConfigGetters() { + ForwardPartitioner partitioner = new ForwardPartitioner<>(); + RecordFilterContext.InputFilterConfig config = + new RecordFilterContext.InputFilterConfig(LongSerializer.INSTANCE, partitioner, 4); + + 
assertThat(config.getTypeSerializer()).isSameAs(LongSerializer.INSTANCE); + assertThat(config.getPartitioner()).isSameAs(partitioner); + assertThat(config.getNumberOfChannels()).isEqualTo(4); + } + + @Test + void testMultipleGateConfigs() { + RecordFilterContext.InputFilterConfig config0 = + new RecordFilterContext.InputFilterConfig( + LongSerializer.INSTANCE, new ForwardPartitioner<>(), 2); + RecordFilterContext.InputFilterConfig config1 = + new RecordFilterContext.InputFilterConfig( + LongSerializer.INSTANCE, new ForwardPartitioner<>(), 4); + + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[] {config0, config1}, + InflightDataRescalingDescriptor.NO_RESCALE, + 1, + 256, + new String[] {"/tmp"}, + false); + + assertThat(context.getNumberOfGates()).isEqualTo(2); + assertThat(context.getInputConfig(0)).isSameAs(config0); + assertThat(context.getInputConfig(1)).isSameAs(config1); + assertThat(context.getInputConfig(0).getNumberOfChannels()).isEqualTo(2); + assertThat(context.getInputConfig(1).getNumberOfChannels()).isEqualTo(4); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java new file mode 100644 index 0000000000000..bffc42e33294d --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link VirtualChannelRecordFilterFactory}. */ +class VirtualChannelRecordFilterFactoryTest { + + @Test + void testCreatePassThroughFilter() { + RecordFilter filter = VirtualChannelRecordFilterFactory.createPassThroughFilter(); + assertThat(filter.filter(new StreamRecord<>(0L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(1L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(42L))).isTrue(); + } + + @Test + void testCreateFilterProducesPartitionerBasedFilter() { + RebalancePartitioner partitioner = new RebalancePartitioner<>(); + + VirtualChannelRecordFilterFactory factory = + new VirtualChannelRecordFilterFactory<>( + LongSerializer.INSTANCE, partitioner, 0, 2, 128); + + RecordFilter filter = factory.createFilter(); + // The filter should be a PartitionerRecordFilter that filters based on partitioner + assertThat(filter).isInstanceOf(PartitionerRecordFilter.class); + } + + @Test + void testFromContextCreatesFactory() { + RebalancePartitioner partitioner = new RebalancePartitioner<>(); + RecordFilterContext.InputFilterConfig config = + new 
RecordFilterContext.InputFilterConfig(LongSerializer.INSTANCE, partitioner, 4); + + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[] {config}, + InflightDataRescalingDescriptor.NO_RESCALE, + 1, + 128, + new String[] {"/tmp"}, + true); + + VirtualChannelRecordFilterFactory factory = + VirtualChannelRecordFilterFactory.fromContext(context, 0); + RecordFilter filter = factory.createFilter(); + + // The filter should be a functional PartitionerRecordFilter + assertThat(filter).isInstanceOf(PartitionerRecordFilter.class); + } + + @Test + void testEachFilterCallCreatesIndependentFilter() { + RebalancePartitioner partitioner = new RebalancePartitioner<>(); + + VirtualChannelRecordFilterFactory factory = + new VirtualChannelRecordFilterFactory<>( + LongSerializer.INSTANCE, partitioner, 0, 2, 128); + + RecordFilter filter1 = factory.createFilter(); + RecordFilter filter2 = factory.createFilter(); + + // Each call should produce a distinct filter instance (using a copy of the partitioner) + assertThat(filter1).isNotSameAs(filter2); + } +}