diff --git a/docs/layouts/shortcodes/generated/common_checkpointing_section.html b/docs/layouts/shortcodes/generated/common_checkpointing_section.html
index 9db118174a328..f942529b79c71 100644
--- a/docs/layouts/shortcodes/generated/common_checkpointing_section.html
+++ b/docs/layouts/shortcodes/generated/common_checkpointing_section.html
@@ -56,5 +56,11 @@
UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM =
ConfigOptions.key(
"execution.checkpointing.unaligned.recover-output-on-downstream.enabled")
@@ -668,6 +669,40 @@ public class CheckpointingOptions {
"Whether recovering output buffers of upstream task on downstream task directly "
+ "when job restores from the unaligned checkpoint.");
+ /**
+ * Whether to enable checkpointing during recovery from an unaligned checkpoint.
+ *
+ * When enabled, the job can take checkpoints while still recovering channel state (inflight
+ * data) from a previous unaligned checkpoint. This avoids the need to wait for full recovery
+ * before the first checkpoint can be triggered, which reduces the window of vulnerability to
+ * failures during recovery.
+ *
+ *
This option requires {@link #UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM} to be enabled. It
+ * does not require unaligned checkpoints to be currently enabled, because a job may restore
+ * from an unaligned checkpoint while having unaligned checkpoints disabled for the new
+ * execution.
+ */
+ @Experimental
+ @Documentation.ExcludeFromDocumentation(
+ "This option is not yet ready for public use, will be documented in a follow-up commit")
+ public static final ConfigOption UNALIGNED_DURING_RECOVERY_ENABLED =
+ ConfigOptions.key("execution.checkpointing.unaligned.during-recovery.enabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription(
+ Description.builder()
+ .text(
+ "Whether to enable checkpointing during recovery from an unaligned checkpoint. "
+ + "When enabled, the job can take checkpoints while still recovering channel state "
+ + "(inflight data), reducing the window of vulnerability to failures during recovery.")
+ .linebreak()
+ .linebreak()
+ .text(
+ "This option requires %s to be enabled.",
+ TextElement.code(
+ UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM.key()))
+ .build());
+
/**
* Determines whether checkpointing is enabled based on the configuration.
*
@@ -763,4 +798,24 @@ public static boolean isUnalignedCheckpointInterruptibleTimersEnabled(Configurat
}
return config.get(ENABLE_UNALIGNED_INTERRUPTIBLE_TIMERS);
}
+
+ /**
+ * Determines whether unaligned checkpoint support during recovery is enabled.
+ *
+ * This feature requires {@link #UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM} to be enabled. Note
+ * that it does not require unaligned checkpoints to be currently enabled, because a job may
+ * restore from an unaligned checkpoint while having unaligned checkpoints disabled for the new
+ * execution.
+ *
+ * @param config the configuration to check
+ * @return {@code true} if unaligned checkpoint during recovery is enabled, {@code false}
+ * otherwise
+ */
+ @Internal
+ public static boolean isUnalignedDuringRecoveryEnabled(Configuration config) {
+ if (!config.get(UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM)) {
+ return false;
+ }
+ return config.get(UNALIGNED_DURING_RECOVERY_ENABLED);
+ }
}
diff --git a/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java b/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java
index 4104dfb762b87..747f789af3599 100644
--- a/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java
+++ b/flink-core/src/test/java/org/apache/flink/configuration/CheckpointingOptionsTest.java
@@ -328,4 +328,50 @@ void testIsUnalignedCheckpointInterruptibleTimersEnabled() {
.as("Interruptible timers should be disabled when runtime mode is BATCH")
.isFalse();
}
+
+ @Test
+ void testIsUnalignedDuringRecoveryEnabled() {
+ // Test when both options are disabled (default) - should return false
+ Configuration defaultConfig = new Configuration();
+ assertThat(CheckpointingOptions.isUnalignedDuringRecoveryEnabled(defaultConfig))
+ .as("During-recovery should be disabled by default")
+ .isFalse();
+
+ // Test when during-recovery is enabled but recover-output-on-downstream is disabled
+ Configuration onlyDuringRecoveryConfig = new Configuration();
+ onlyDuringRecoveryConfig.set(CheckpointingOptions.UNALIGNED_DURING_RECOVERY_ENABLED, true);
+ assertThat(CheckpointingOptions.isUnalignedDuringRecoveryEnabled(onlyDuringRecoveryConfig))
+ .as(
+ "During-recovery should be disabled when recover-output-on-downstream is not enabled")
+ .isFalse();
+
+ // Test when recover-output-on-downstream is enabled but during-recovery is disabled
+ Configuration onlyRecoverOnDownstreamConfig = new Configuration();
+ onlyRecoverOnDownstreamConfig.set(
+ CheckpointingOptions.UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM, true);
+ assertThat(
+ CheckpointingOptions.isUnalignedDuringRecoveryEnabled(
+ onlyRecoverOnDownstreamConfig))
+ .as("During-recovery should be disabled when during-recovery option is not enabled")
+ .isFalse();
+
+ // Test when both options are enabled - should return true
+ Configuration bothEnabledConfig = new Configuration();
+ bothEnabledConfig.set(CheckpointingOptions.UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM, true);
+ bothEnabledConfig.set(CheckpointingOptions.UNALIGNED_DURING_RECOVERY_ENABLED, true);
+ assertThat(CheckpointingOptions.isUnalignedDuringRecoveryEnabled(bothEnabledConfig))
+ .as(
+ "During-recovery should be enabled when both recover-output-on-downstream and during-recovery are enabled")
+ .isTrue();
+
+ // Test when recover-output-on-downstream is explicitly false and during-recovery is true
+ Configuration explicitlyDisabledConfig = new Configuration();
+ explicitlyDisabledConfig.set(
+ CheckpointingOptions.UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM, false);
+ explicitlyDisabledConfig.set(CheckpointingOptions.UNALIGNED_DURING_RECOVERY_ENABLED, true);
+ assertThat(CheckpointingOptions.isUnalignedDuringRecoveryEnabled(explicitlyDisabledConfig))
+ .as(
+ "During-recovery should be disabled when recover-output-on-downstream is explicitly false")
+ .isFalse();
+ }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java
new file mode 100644
index 0000000000000..fe4feb7e05e7a
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
+import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Filters recovered channel state buffers during the channel-state-unspilling phase, removing
+ * records that do not belong to the current subtask after rescaling.
+ *
+ *
Uses a per-gate architecture: each {@link InputGate} gets its own {@link GateFilterHandler}
+ * with the correct serializer, so multi-input tasks (e.g., TwoInputStreamTask) correctly
+ * deserialize different record types on different gates.
+ */
+@Internal
+public class ChannelStateFilteringHandler {
+
+ /**
+ * Handles record filtering for a single input gate. Each gate has its own serializer and set of
+ * virtual channels, allowing different gates to handle different record types independently.
+ */
+ static class GateFilterHandler {
+
+ private final Map> virtualChannels;
+ private final StreamElementSerializer serializer;
+ private final DeserializationDelegate deserializationDelegate;
+ private final DataOutputSerializer outputSerializer;
+ private final byte[] lengthBuffer = new byte[4];
+
+ GateFilterHandler(
+ Map> virtualChannels,
+ StreamElementSerializer serializer) {
+ this.virtualChannels = checkNotNull(virtualChannels);
+ this.serializer = checkNotNull(serializer);
+ this.deserializationDelegate = new NonReusingDeserializationDelegate<>(serializer);
+ this.outputSerializer = new DataOutputSerializer(128);
+ }
+
+ /**
+ * Deserializes records from {@code sourceBuffer}, applies the virtual channel's record
+ * filter, and re-serializes the surviving records into new buffers.
+ */
+ List filterAndRewrite(
+ int oldSubtaskIndex,
+ int oldChannelIndex,
+ Buffer sourceBuffer,
+ BufferSupplier bufferSupplier)
+ throws IOException, InterruptedException {
+
+ SubtaskConnectionDescriptor key =
+ new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex);
+ VirtualChannel vc = virtualChannels.get(key);
+ if (vc == null) {
+ throw new IllegalStateException(
+ "No VirtualChannel found for key: "
+ + key
+ + "; known channels are "
+ + virtualChannels.keySet());
+ }
+
+ vc.setNextBuffer(sourceBuffer);
+
+ List filteredElements = new ArrayList<>();
+
+ while (true) {
+ DeserializationResult result = vc.getNextRecord(deserializationDelegate);
+ if (result.isFullRecord()) {
+ filteredElements.add(deserializationDelegate.getInstance());
+ }
+ if (result.isBufferConsumed()) {
+ break;
+ }
+ }
+
+ return serializeToBuffers(filteredElements, bufferSupplier);
+ }
+
+ /**
+ * Serializes stream elements into buffers using the length-prefixed format (4-byte
+ * big-endian length + record bytes) expected by Flink's record deserializers.
+ */
+ private List serializeToBuffers(
+ List elements, BufferSupplier bufferSupplier)
+ throws IOException, InterruptedException {
+
+ List resultBuffers = new ArrayList<>();
+
+ if (elements.isEmpty()) {
+ return resultBuffers;
+ }
+
+ Buffer currentBuffer = bufferSupplier.requestBufferBlocking();
+
+ for (StreamElement element : elements) {
+ outputSerializer.clear();
+ serializer.serialize(element, outputSerializer);
+ int recordLength = outputSerializer.length();
+
+ writeLengthToBuffer(recordLength);
+ currentBuffer =
+ writeDataToBuffer(
+ lengthBuffer, 0, 4, currentBuffer, resultBuffers, bufferSupplier);
+
+ byte[] serializedData = outputSerializer.getSharedBuffer();
+ currentBuffer =
+ writeDataToBuffer(
+ serializedData,
+ 0,
+ recordLength,
+ currentBuffer,
+ resultBuffers,
+ bufferSupplier);
+ }
+
+ if (currentBuffer.readableBytes() > 0) {
+ resultBuffers.add(currentBuffer.retainBuffer());
+ }
+ currentBuffer.recycleBuffer();
+
+ return resultBuffers;
+ }
+
+ private void writeLengthToBuffer(int length) {
+ lengthBuffer[0] = (byte) (length >> 24);
+ lengthBuffer[1] = (byte) (length >> 16);
+ lengthBuffer[2] = (byte) (length >> 8);
+ lengthBuffer[3] = (byte) length;
+ }
+
+ /**
+ * Writes data to the current buffer, spilling into new buffers from {@code bufferSupplier}
+ * when the current one is full.
+ *
+ * @return the buffer to continue writing into (may differ from the input buffer).
+ */
+ private Buffer writeDataToBuffer(
+ byte[] data,
+ int dataOffset,
+ int dataLength,
+ Buffer currentBuffer,
+ List resultBuffers,
+ BufferSupplier bufferSupplier)
+ throws IOException, InterruptedException {
+ int offset = dataOffset;
+ int remaining = dataLength;
+
+ while (remaining > 0) {
+ int writableBytes = currentBuffer.getMaxCapacity() - currentBuffer.getSize();
+
+ if (writableBytes == 0) {
+ if (currentBuffer.readableBytes() > 0) {
+ resultBuffers.add(currentBuffer.retainBuffer());
+ }
+ currentBuffer.recycleBuffer();
+ currentBuffer = bufferSupplier.requestBufferBlocking();
+ writableBytes = currentBuffer.getMaxCapacity();
+ }
+
+ int bytesToWrite = Math.min(remaining, writableBytes);
+ currentBuffer
+ .getMemorySegment()
+ .put(
+ currentBuffer.getMemorySegmentOffset() + currentBuffer.getSize(),
+ data,
+ offset,
+ bytesToWrite);
+ currentBuffer.setSize(currentBuffer.getSize() + bytesToWrite);
+
+ offset += bytesToWrite;
+ remaining -= bytesToWrite;
+ }
+ return currentBuffer;
+ }
+
+ boolean hasPartialData() {
+ return virtualChannels.values().stream().anyMatch(VirtualChannel::hasPartialData);
+ }
+
+ void clear() {
+ virtualChannels.values().forEach(VirtualChannel::clear);
+ }
+ }
+
+ // Wildcard allows heterogeneous record types across gates.
+ private final GateFilterHandler>[] gateHandlers;
+
+ ChannelStateFilteringHandler(GateFilterHandler>[] gateHandlers) {
+ this.gateHandlers = checkNotNull(gateHandlers);
+ }
+
+ /**
+ * Creates a handler from the recovery context, building per-gate virtual channels based on
+ * rescaling descriptors. Returns {@code null} if no filtering is needed (e.g., source tasks or
+ * no rescaling).
+ */
+ @Nullable
+ public static ChannelStateFilteringHandler createFromContext(
+ RecordFilterContext filterContext, InputGate[] inputGates) {
+ // Source tasks have no network inputs
+ if (filterContext.getNumberOfGates() == 0) {
+ return null;
+ }
+
+ InflightDataRescalingDescriptor rescalingDescriptor =
+ filterContext.getRescalingDescriptor();
+
+ GateFilterHandler>[] gateHandlers = new GateFilterHandler>[inputGates.length];
+ boolean hasAnyVirtualChannels = false;
+
+ for (int gateIndex = 0; gateIndex < inputGates.length; gateIndex++) {
+ gateHandlers[gateIndex] =
+ createGateHandler(filterContext, inputGates, rescalingDescriptor, gateIndex);
+ if (gateHandlers[gateIndex] != null) {
+ hasAnyVirtualChannels = true;
+ }
+ }
+
+ if (!hasAnyVirtualChannels) {
+ return null;
+ }
+
+ return new ChannelStateFilteringHandler(gateHandlers);
+ }
+
+ /**
+ * Creates a {@link GateFilterHandler} for a single gate. The method-level type parameter
+ * ensures type safety within each gate while allowing different gates to have different types.
+ */
+ @SuppressWarnings("unchecked")
+ @Nullable
+ private static GateFilterHandler createGateHandler(
+ RecordFilterContext filterContext,
+ InputGate[] inputGates,
+ InflightDataRescalingDescriptor rescalingDescriptor,
+ int gateIndex) {
+ RecordFilterContext.InputFilterConfig inputConfig = filterContext.getInputConfig(gateIndex);
+ if (inputConfig == null) {
+ throw new IllegalStateException(
+ "No InputFilterConfig for gateIndex "
+ + gateIndex
+ + ". This indicates a bug in RecordFilterContext initialization.");
+ }
+
+ InputGate gate = inputGates[gateIndex];
+ int[] oldSubtaskIndexes = rescalingDescriptor.getOldSubtaskIndexes(gateIndex);
+ RescaleMappings channelMapping = rescalingDescriptor.getChannelMapping(gateIndex);
+
+ TypeSerializer typeSerializer = (TypeSerializer) inputConfig.getTypeSerializer();
+ StreamElementSerializer elementSerializer =
+ new StreamElementSerializer<>(typeSerializer);
+
+ VirtualChannelRecordFilterFactory filterFactory =
+ VirtualChannelRecordFilterFactory.fromContext(filterContext, gateIndex);
+
+ Map> gateVirtualChannels = new HashMap<>();
+
+ for (int oldSubtaskIndex : oldSubtaskIndexes) {
+ int numChannels = gate.getNumberOfInputChannels();
+ int[] oldChannelIndexes = getOldChannelIndexes(channelMapping, numChannels);
+
+ for (int oldChannelIndex : oldChannelIndexes) {
+ SubtaskConnectionDescriptor key =
+ new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex);
+
+ if (gateVirtualChannels.containsKey(key)) {
+ continue;
+ }
+
+ // Only ambiguous channels need actual filtering; non-ambiguous ones pass through
+ boolean isAmbiguous = rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex);
+
+ RecordFilter recordFilter =
+ isAmbiguous
+ ? filterFactory.createFilter()
+ : VirtualChannelRecordFilterFactory.createPassThroughFilter();
+
+ RecordDeserializer> deserializer =
+ createDeserializer(filterContext.getTmpDirectories());
+
+ VirtualChannel vc = new VirtualChannel<>(deserializer, recordFilter);
+ gateVirtualChannels.put(key, vc);
+ }
+ }
+
+ if (gateVirtualChannels.isEmpty()) {
+ return null;
+ }
+
+ return new GateFilterHandler<>(gateVirtualChannels, elementSerializer);
+ }
+
+ /**
+ * Collects all old channel indexes that are mapped from any new channel index in this gate.
+ * channelMapping is new-to-old, so we iterate new indexes and collect their old counterparts.
+ */
+ private static int[] getOldChannelIndexes(RescaleMappings channelMapping, int numChannels) {
+ List oldIndexes = new ArrayList<>();
+ for (int newIndex = 0; newIndex < numChannels; newIndex++) {
+ int[] mapped = channelMapping.getMappedIndexes(newIndex);
+ for (int oldIndex : mapped) {
+ if (!oldIndexes.contains(oldIndex)) {
+ oldIndexes.add(oldIndex);
+ }
+ }
+ }
+ return oldIndexes.stream().mapToInt(Integer::intValue).toArray();
+ }
+
+ private static RecordDeserializer> createDeserializer(
+ String[] tmpDirectories) {
+ if (tmpDirectories != null && tmpDirectories.length > 0) {
+ return new SpillingAdaptiveSpanningRecordDeserializer<>(tmpDirectories);
+ } else {
+ String[] defaultDirs = new String[] {System.getProperty("java.io.tmpdir")};
+ return new SpillingAdaptiveSpanningRecordDeserializer<>(defaultDirs);
+ }
+ }
+
+ /**
+ * Filters a recovered buffer from the specified virtual channel, returning new buffers
+ * containing only the records that belong to the current subtask.
+ *
+ * @return filtered buffers, possibly empty if all records were filtered out.
+ */
+ public List filterAndRewrite(
+ int gateIndex,
+ int oldSubtaskIndex,
+ int oldChannelIndex,
+ Buffer sourceBuffer,
+ BufferSupplier bufferSupplier)
+ throws IOException, InterruptedException {
+
+ if (gateIndex < 0 || gateIndex >= gateHandlers.length) {
+ throw new IllegalStateException(
+ "Invalid gateIndex: "
+ + gateIndex
+ + ", number of gates: "
+ + gateHandlers.length);
+ }
+
+ GateFilterHandler> gateHandler = gateHandlers[gateIndex];
+ if (gateHandler == null) {
+ throw new IllegalStateException(
+ "No handler for gateIndex "
+ + gateIndex
+ + ". This gate is not a network input and should not have recovered buffers.");
+ }
+ return gateHandler.filterAndRewrite(
+ oldSubtaskIndex, oldChannelIndex, sourceBuffer, bufferSupplier);
+ }
+
+ /** Returns {@code true} if any virtual channel has a partial (spanning) record pending. */
+ public boolean hasPartialData() {
+ for (GateFilterHandler> handler : gateHandlers) {
+ if (handler != null && handler.hasPartialData()) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ public void clear() {
+ for (GateFilterHandler> handler : gateHandlers) {
+ if (handler != null) {
+ handler.clear();
+ }
+ }
+ }
+
+ /** Provides buffers for re-serializing filtered records. Implementations may block. */
+ @FunctionalInterface
+ public interface BufferSupplier {
+ Buffer requestBufferBlocking() throws IOException, InterruptedException;
+ }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
index 85fc31db4bc69..1c3b19e33d3ba 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
@@ -31,6 +31,7 @@
import org.apache.flink.runtime.io.network.partition.consumer.RecoveredInputChannel;
import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
import java.io.IOException;
import java.util.HashMap;
@@ -63,7 +64,7 @@ public void close() {
* case of an error.
*/
void recover(Info info, int oldSubtaskIndex, BufferWithContext bufferWithContext)
- throws IOException;
+ throws IOException, InterruptedException;
}
class InputChannelRecoveredStateHandler
@@ -75,10 +76,19 @@ class InputChannelRecoveredStateHandler
private final Map rescaledChannels = new HashMap<>();
private final Map oldToNewMappings = new HashMap<>();
+ /**
+ * Optional filtering handler for filtering recovered buffers. When non-null, filtering is
+ * performed during recovery in the channel-state-unspilling thread.
+ */
+ @Nullable private final ChannelStateFilteringHandler filteringHandler;
+
InputChannelRecoveredStateHandler(
- InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping) {
+ InputGate[] inputGates,
+ InflightDataRescalingDescriptor channelMapping,
+ @Nullable ChannelStateFilteringHandler filteringHandler) {
this.inputGates = inputGates;
this.channelMapping = channelMapping;
+ this.filteringHandler = filteringHandler;
}
@Override
@@ -95,23 +105,60 @@ public void recover(
InputChannelInfo channelInfo,
int oldSubtaskIndex,
BufferWithContext bufferWithContext)
- throws IOException {
+ throws IOException, InterruptedException {
Buffer buffer = bufferWithContext.context;
try {
if (buffer.readableBytes() > 0) {
RecoveredInputChannel channel = getMappedChannels(channelInfo);
- channel.onRecoveredStateBuffer(
- EventSerializer.toBuffer(
- new SubtaskConnectionDescriptor(
- oldSubtaskIndex, channelInfo.getInputChannelIdx()),
- false));
- channel.onRecoveredStateBuffer(buffer.retainBuffer());
+
+ if (filteringHandler != null) {
+ // Filtering mode: filter records and rewrite to new buffers
+ recoverWithFiltering(channel, channelInfo, oldSubtaskIndex, buffer);
+ } else {
+ // Non-filtering mode: pass through original buffer with descriptor
+ channel.onRecoveredStateBuffer(
+ EventSerializer.toBuffer(
+ new SubtaskConnectionDescriptor(
+ oldSubtaskIndex, channelInfo.getInputChannelIdx()),
+ false));
+ channel.onRecoveredStateBuffer(buffer.retainBuffer());
+ }
}
} finally {
buffer.recycleBuffer();
}
}
+ private void recoverWithFiltering(
+ RecoveredInputChannel channel,
+ InputChannelInfo channelInfo,
+ int oldSubtaskIndex,
+ Buffer buffer)
+ throws IOException, InterruptedException {
+ checkState(filteringHandler != null, "filtering handler not set.");
+ // Extra retain: filterAndRewrite consumes one ref, caller's finally releases another.
+ buffer.retainBuffer();
+
+ List filteredBuffers;
+ try {
+ filteredBuffers =
+ filteringHandler.filterAndRewrite(
+ channelInfo.getGateIdx(),
+ oldSubtaskIndex,
+ channelInfo.getInputChannelIdx(),
+ buffer,
+ channel::requestBufferBlocking);
+ } catch (Throwable t) {
+ // filterAndRewrite didn't consume the buffer, release the extra ref.
+ buffer.recycleBuffer();
+ throw t;
+ }
+
+ for (Buffer filteredBuffer : filteredBuffers) {
+ channel.onRecoveredStateBuffer(filteredBuffer);
+ }
+ }
+
@Override
public void close() throws IOException {
// note that we need to finish all RecoveredInputChannels, not just those with state
@@ -191,7 +238,7 @@ public void recover(
ResultSubpartitionInfo subpartitionInfo,
int oldSubtaskIndex,
BufferWithContext bufferWithContext)
- throws IOException {
+ throws IOException, InterruptedException {
try (BufferBuilder bufferBuilder = bufferWithContext.context;
BufferConsumer bufferConsumer = bufferBuilder.createBufferConsumerFromBeginning()) {
bufferBuilder.finish();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java
index 7adf6d6294671..547b60ef93aee 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java
@@ -20,6 +20,7 @@
import org.apache.flink.annotation.Internal;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
import java.io.IOException;
@@ -27,7 +28,14 @@
@Internal
public interface SequentialChannelStateReader extends AutoCloseable {
- void readInputData(InputGate[] inputGates) throws IOException, InterruptedException;
+ /**
+ * Reads input channel state with filtering support.
+ *
+ * @param inputGates The input gates to recover state for.
+ * @param filterContext The filter context containing input configs and rescaling info.
+ */
+ void readInputData(InputGate[] inputGates, RecordFilterContext filterContext)
+ throws IOException, InterruptedException;
void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion)
throws IOException, InterruptedException;
@@ -39,7 +47,8 @@ void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCom
new SequentialChannelStateReader() {
@Override
- public void readInputData(InputGate[] inputGates) {}
+ public void readInputData(
+ InputGate[] inputGates, RecordFilterContext filterContext) {}
@Override
public void readOutputData(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
index 3daa4b4947a61..ca0ae8a8385a5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
@@ -28,6 +28,7 @@
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.ChannelStateHelper;
import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
import java.io.Closeable;
import java.io.IOException;
@@ -43,6 +44,7 @@
import static java.util.Comparator.comparingLong;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toList;
+import static org.apache.flink.util.Preconditions.checkState;
/** {@link SequentialChannelStateReader} implementation. */
public class SequentialChannelStateReaderImpl implements SequentialChannelStateReader {
@@ -58,10 +60,21 @@ public SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot) {
}
@Override
- public void readInputData(InputGate[] inputGates) throws IOException, InterruptedException {
+ public void readInputData(InputGate[] inputGates, RecordFilterContext filterContext)
+ throws IOException, InterruptedException {
+
+ // Create filtering handler if filtering is needed
+ ChannelStateFilteringHandler filteringHandler = null;
+ if (filterContext.isUnalignedDuringRecoveryEnabled()) {
+ filteringHandler =
+ ChannelStateFilteringHandler.createFromContext(filterContext, inputGates);
+ }
+
try (InputChannelRecoveredStateHandler stateHandler =
new InputChannelRecoveredStateHandler(
- inputGates, taskStateSnapshot.getInputRescalingDescriptor())) {
+ inputGates,
+ taskStateSnapshot.getInputRescalingDescriptor(),
+ filteringHandler)) {
read(
stateHandler,
groupByDelegate(
@@ -72,6 +85,18 @@ public void readInputData(InputGate[] inputGates) throws IOException, Interrupte
groupByDelegate(
streamSubtaskStates(),
OperatorSubtaskState::getUpstreamOutputBufferState));
+
+ if (filteringHandler != null) {
+ checkState(
+ !filteringHandler.hasPartialData(),
+ "Not all data has been fully consumed during filtering");
+ }
+ } finally {
+ // Clean up filtering handler resources (e.g., temp files from
+ // SpillingAdaptiveSpanningRecordDeserializer) on both success and error paths
+ if (filteringHandler != null) {
+ filteringHandler.clear();
+ }
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java
index 78877a3d62ec6..db92afd2f4294 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java
@@ -23,6 +23,7 @@
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.metrics.Counter;
@@ -103,6 +104,10 @@ public static StreamMultipleInputProcessor create(
"Number of configured inputs in StreamConfig [%s] doesn't match the main operator's number of inputs [%s]",
configuredInputs.length,
inputsCount);
+
+ boolean unalignedDuringRecoveryEnabled =
+ CheckpointingOptions.isUnalignedDuringRecoveryEnabled(jobConfig);
+
StreamTaskInput[] inputs = new StreamTaskInput[inputsCount];
for (int i = 0; i < inputsCount; i++) {
StreamConfig.InputConfig configuredInput = configuredInputs[i];
@@ -121,7 +126,8 @@ public static StreamMultipleInputProcessor create(
gatePartitioners,
taskInfo,
canEmitBatchOfRecords,
- streamConfig.getWatermarkDeclarations(userClassloader));
+ streamConfig.getWatermarkDeclarations(userClassloader),
+ unalignedDuringRecoveryEnabled);
} else if (configuredInput instanceof StreamConfig.SourceInputConfig) {
StreamConfig.SourceInputConfig sourceInput =
(StreamConfig.SourceInputConfig) configuredInput;
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java
index 46c9cd96936e3..7718419f6bb55 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java
@@ -47,9 +47,14 @@ public static StreamTaskInput create(
Function> gatePartitioners,
TaskInfo taskInfo,
CanEmitBatchOfRecordsChecker canEmitBatchOfRecords,
- Set> watermarkDeclarationSet) {
+ Set> watermarkDeclarationSet,
+ boolean unalignedDuringRecoveryEnabled) {
return rescalingDescriptorinflightDataRescalingDescriptor.equals(
- InflightDataRescalingDescriptor.NO_RESCALE)
+ InflightDataRescalingDescriptor.NO_RESCALE)
+ // When filter during recovery is enabled, records are already filtered in
+ // the channel-state-unspilling thread. Use StreamTaskNetworkInput to avoid
+ // redundant demultiplexing/filtering in the Task thread.
+ || unalignedDuringRecoveryEnabled
? new StreamTaskNetworkInput<>(
checkpointedInputGate,
inputSerializer,
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java
index 2a0c675710b34..03d2236325d0f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java
@@ -22,6 +22,7 @@
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.metrics.Counter;
@@ -84,6 +85,10 @@ public static StreamMultipleInputProcessor create(
checkNotNull(operatorChain);
taskIOMetricGroup.reuseRecordsInputCounter(numRecordsIn);
+
+ boolean unalignedDuringRecoveryEnabled =
+ CheckpointingOptions.isUnalignedDuringRecoveryEnabled(jobConfig);
+
TypeSerializer typeSerializer1 = streamConfig.getTypeSerializerIn(0, userClassloader);
StreamTaskInput input1 =
StreamTaskNetworkInputFactory.create(
@@ -96,7 +101,8 @@ public static StreamMultipleInputProcessor create(
gatePartitioners,
taskInfo,
canEmitBatchOfRecords,
- streamConfig.getWatermarkDeclarations(userClassloader));
+ streamConfig.getWatermarkDeclarations(userClassloader),
+ unalignedDuringRecoveryEnabled);
TypeSerializer typeSerializer2 = streamConfig.getTypeSerializerIn(1, userClassloader);
StreamTaskInput input2 =
StreamTaskNetworkInputFactory.create(
@@ -109,7 +115,8 @@ public static StreamMultipleInputProcessor create(
gatePartitioners,
taskInfo,
canEmitBatchOfRecords,
- streamConfig.getWatermarkDeclarations(userClassloader));
+ streamConfig.getWatermarkDeclarations(userClassloader),
+ unalignedDuringRecoveryEnabled);
InputSelectable inputSelectable =
streamOperator instanceof InputSelectable ? (InputSelectable) streamOperator : null;
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java
index 548b162d7fe96..6f1f3bda8b0e4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java
@@ -26,7 +26,6 @@
import org.apache.flink.runtime.plugable.DeserializationDelegate;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
import org.apache.flink.util.CloseableIterator;
@@ -38,7 +37,6 @@
import java.util.Comparator;
import java.util.Map;
import java.util.function.Function;
-import java.util.function.Predicate;
import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -59,56 +57,6 @@ class DemultiplexingRecordDeserializer
private VirtualChannel currentVirtualChannel;
- static class VirtualChannel {
- private final RecordDeserializer> deserializer;
- private final Predicate> recordFilter;
- Watermark lastWatermark = Watermark.UNINITIALIZED;
- WatermarkStatus watermarkStatus = WatermarkStatus.ACTIVE;
- private DeserializationResult lastResult;
-
- VirtualChannel(
- RecordDeserializer> deserializer,
- Predicate> recordFilter) {
- this.deserializer = deserializer;
- this.recordFilter = recordFilter;
- }
-
- public DeserializationResult getNextRecord(DeserializationDelegate delegate)
- throws IOException {
- do {
- lastResult = deserializer.getNextRecord(delegate);
-
- if (lastResult.isFullRecord()) {
- final StreamElement element = delegate.getInstance();
- // test if record belongs to this subtask if it comes from ambiguous channel
- if (element.isRecord() && recordFilter.test(element.asRecord())) {
- return lastResult;
- } else if (element.isWatermark()) {
- lastWatermark = element.asWatermark();
- return lastResult;
- } else if (element.isWatermarkStatus()) {
- watermarkStatus = element.asWatermarkStatus();
- return lastResult;
- }
- }
- // loop is only re-executed for filtered full records
- } while (!lastResult.isBufferConsumed());
- return DeserializationResult.PARTIAL_RECORD;
- }
-
- public void setNextBuffer(Buffer buffer) throws IOException {
- deserializer.setNextBuffer(buffer);
- }
-
- public void clear() {
- deserializer.clear();
- }
-
- public boolean hasPartialData() {
- return lastResult != null && !lastResult.isBufferConsumed();
- }
- }
-
public DemultiplexingRecordDeserializer(
Map> channels) {
this.channels = checkNotNull(channels);
@@ -161,7 +109,7 @@ public DeserializationResult getNextRecord(DeserializationDelegate virtualChannel.lastWatermark)
+ .map(VirtualChannel::getLastWatermark)
.min(Comparator.comparing(Watermark::getTimestamp))
.orElseThrow(
() ->
@@ -176,7 +124,8 @@ public DeserializationResult getNextRecord(DeserializationDelegate d.watermarkStatus.isActive())) {
+ if (channels.values().stream()
+ .anyMatch(vc -> vc.getWatermarkStatus().isActive())) {
delegate.setInstance(WatermarkStatus.ACTIVE);
}
return result;
@@ -197,7 +146,7 @@ static DemultiplexingRecordDeserializer create(
InflightDataRescalingDescriptor rescalingDescriptor,
Function>>
deserializerFactory,
- Function>> recordFilterFactory) {
+ Function<InputChannelInfo, RecordFilter<T>> recordFilterFactory) {
int[] oldSubtaskIndexes =
rescalingDescriptor.getOldSubtaskIndexes(channelInfo.getGateIdx());
if (oldSubtaskIndexes.length == 0) {
@@ -223,7 +172,7 @@ static DemultiplexingRecordDeserializer create(
deserializerFactory.apply(totalChannels),
rescalingDescriptor.isAmbiguous(channelInfo.getGateIdx(), subtask)
? recordFilterFactory.apply(channelInfo)
- : RecordFilter.all()));
+ : RecordFilter.acceptAll()));
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/PartitionerRecordFilter.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/PartitionerRecordFilter.java
new file mode 100644
index 0000000000000..a9f56cbdbc69d
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/PartitionerRecordFilter.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
+import org.apache.flink.runtime.plugable.SerializationDelegate;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+/**
+ * A {@link RecordFilter} implementation that uses a partitioner to determine record ownership.
+ *
+ * <p>This filter checks if a record would have arrived at this subtask if it had been partitioned
+ * upstream. It is used during recovery for ambiguous channel mappings, such as when the downstream
+ * node of a keyed exchange is rescaled.
+ *
+ * @param <T> The type of the record value.
+ */
+@Internal
+public class PartitionerRecordFilter<T> implements RecordFilter<T> {
+ private final ChannelSelector<SerializationDelegate<StreamRecord<T>>> partitioner;
+
+ private final SerializationDelegate<StreamRecord<T>> delegate;
+
+ private final int subtaskIndex;
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public PartitionerRecordFilter(
+ ChannelSelector<SerializationDelegate<StreamRecord<T>>> partitioner,
+ TypeSerializer<T> inputSerializer,
+ int subtaskIndex) {
+ this.partitioner = partitioner;
+ this.delegate = new SerializationDelegate<>(new StreamElementSerializer(inputSerializer));
+ this.subtaskIndex = subtaskIndex;
+ }
+
+ @Override
+ public boolean filter(StreamRecord<T> streamRecord) {
+ delegate.setInstance(streamRecord);
+ // Check if record would have arrived at this subtask if it had been partitioned upstream.
+ return partitioner.selectChannel(delegate) == subtaskIndex;
+ }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilter.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilter.java
index a0f805f27b582..796ea84dfd843 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilter.java
@@ -17,48 +17,38 @@
package org.apache.flink.streaming.runtime.io.recovery;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
-import org.apache.flink.runtime.plugable.SerializationDelegate;
-import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.annotation.Internal;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import java.util.function.Predicate;
-
/**
- * Filters records for ambiguous channel mappings.
+ * A filter interface for determining whether a record should be processed.
*
- * For example, when the downstream node of a keyed exchange is scaled from 1 to 2, the state of
- * the output side on te upstream node needs to be replicated to both channels. This filter then
- * checks the deserialized records on both downstream subtasks and filters out the irrelevant
- * records.
+ * <p>This interface is used during recovery to filter records for ambiguous channel mappings. For
+ * example, when the downstream node of a keyed exchange is scaled from 1 to 2, the state of the
+ * output side on the upstream node needs to be replicated to both channels. The filter then checks
+ * the deserialized records on both downstream subtasks and filters out the irrelevant records.
*
- * @param
+ * @param <T> The type of the record value.
*/
-class RecordFilter implements Predicate> {
- private final ChannelSelector>> partitioner;
-
- private final SerializationDelegate> delegate;
-
- private final int subtaskIndex;
-
- public RecordFilter(
- ChannelSelector>> partitioner,
- TypeSerializer inputSerializer,
- int subtaskIndex) {
- this.partitioner = partitioner;
- delegate = new SerializationDelegate<>(new StreamElementSerializer(inputSerializer));
- this.subtaskIndex = subtaskIndex;
- }
-
- public static Predicate> all() {
+@FunctionalInterface
+@Internal
+public interface RecordFilter<T> {
+
+ /**
+ * Tests whether the given record should be accepted.
+ *
+ * @param record The stream record to test.
+ * @return {@code true} if the record should be accepted, {@code false} otherwise.
+ */
+ boolean filter(StreamRecord<T> record);
+
+ /**
+ * Returns a filter that accepts all records.
+ *
+ * @param <T> The type of the record value.
+ * @return A filter that always returns {@code true}.
+ */
+ static <T> RecordFilter<T> acceptAll() {
+ return record -> true;
+ }
-
- @Override
- public boolean test(StreamRecord streamRecord) {
- delegate.setInstance(streamRecord);
- // check if record would have arrived at this subtask if it had been partitioned upstream
- return partitioner.selectChannel(delegate) == subtaskIndex;
- }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java
new file mode 100644
index 0000000000000..e7edcb22ce6fc
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Context containing all information needed for filtering recovered channel state buffers.
+ *
+ * <p>This context encapsulates the input configurations, rescaling descriptor, and subtask
+ * information required by the channel-state-unspilling thread to perform record filtering during
+ * recovery.
+ *
+ * <p>Supports multiple inputs (e.g., TwoInputStreamTask, MultipleInputStreamTask) by storing a list
+ * of {@link InputFilterConfig} instances indexed by input index.
+ *
+ * <p>Use the constructor with empty inputConfigs or enabled=false when filtering is not needed.
+ */
+@Internal
+public class RecordFilterContext {
+
+ /** Configuration for filtering records on a specific input. */
+ public static class InputFilterConfig {
+ private final TypeSerializer<?> typeSerializer;
+ private final StreamPartitioner<?> partitioner;
+ private final int numberOfChannels;
+
+ /**
+ * Creates a new InputFilterConfig.
+ *
+ * @param typeSerializer Serializer for the record type.
+ * @param partitioner Partitioner used to determine record ownership.
+ * @param numberOfChannels The parallelism of the current operator.
+ */
+ public InputFilterConfig(
+ TypeSerializer<?> typeSerializer,
+ StreamPartitioner<?> partitioner,
+ int numberOfChannels) {
+ this.typeSerializer = checkNotNull(typeSerializer);
+ this.partitioner = checkNotNull(partitioner);
+ this.numberOfChannels = numberOfChannels;
+ }
+
+ public TypeSerializer<?> getTypeSerializer() {
+ return typeSerializer;
+ }
+
+ public StreamPartitioner<?> getPartitioner() {
+ return partitioner;
+ }
+
+ public int getNumberOfChannels() {
+ return numberOfChannels;
+ }
+ }
+
+ /**
+ * Input configurations indexed by gate index. Array elements may be null for non-network inputs
+ * (e.g., SourceInputConfig). The array length equals the total number of input gates.
+ */
+ private final InputFilterConfig[] inputConfigs;
+
+ /** Descriptor containing rescaling information. Never null. */
+ private final InflightDataRescalingDescriptor rescalingDescriptor;
+
+ /** Current subtask index. */
+ private final int subtaskIndex;
+
+ /** Maximum parallelism for configuring partitioners. */
+ private final int maxParallelism;
+
+ /** Temporary directories for spilling spanning records. Can be empty but never null. */
+ private final String[] tmpDirectories;
+
+ /** Whether unaligned checkpoint during recovery is enabled. */
+ private final boolean unalignedDuringRecoveryEnabled;
+
+ /**
+ * Creates a new RecordFilterContext.
+ *
+ * @param inputConfigs Input configurations indexed by gate index. Array elements may be null
+ * for non-network inputs. Not null itself.
+ * @param rescalingDescriptor Descriptor containing rescaling information. Not null.
+ * @param subtaskIndex Current subtask index.
+ * @param maxParallelism Maximum parallelism.
+ * @param tmpDirectories Temporary directories for spilling spanning records. Can be null
+ * (converted to empty array).
+ * @param unalignedDuringRecoveryEnabled Whether unaligned checkpoint during recovery is
+ * enabled.
+ */
+ public RecordFilterContext(
+ InputFilterConfig[] inputConfigs,
+ InflightDataRescalingDescriptor rescalingDescriptor,
+ int subtaskIndex,
+ int maxParallelism,
+ String[] tmpDirectories,
+ boolean unalignedDuringRecoveryEnabled) {
+ this.inputConfigs = checkNotNull(inputConfigs).clone();
+ this.rescalingDescriptor = checkNotNull(rescalingDescriptor);
+ this.subtaskIndex = subtaskIndex;
+ this.maxParallelism = maxParallelism;
+ this.tmpDirectories = tmpDirectories != null ? tmpDirectories : new String[0];
+ this.unalignedDuringRecoveryEnabled = unalignedDuringRecoveryEnabled;
+ }
+
+ /**
+ * Gets the input configuration for a specific gate.
+ *
+ * @param gateIndex The gate index (0-based).
+ * @return The input configuration for the specified gate, or null if the gate is not a network
+ * input (e.g., SourceInputConfig).
+ * @throws IllegalArgumentException if gateIndex is out of bounds.
+ */
+ public InputFilterConfig getInputConfig(int gateIndex) {
+ checkArgument(
+ gateIndex >= 0 && gateIndex < inputConfigs.length,
+ "Invalid gate index: %s, number of gates: %s",
+ gateIndex,
+ inputConfigs.length);
+ return inputConfigs[gateIndex];
+ }
+
+ /**
+ * Gets the number of input gates.
+ *
+ * @return The number of input gates.
+ */
+ public int getNumberOfGates() {
+ return inputConfigs.length;
+ }
+
+ /**
+ * Checks whether unaligned checkpoint during recovery is enabled.
+ *
+ * @return {@code true} if enabled, {@code false} otherwise.
+ */
+ public boolean isUnalignedDuringRecoveryEnabled() {
+ return unalignedDuringRecoveryEnabled;
+ }
+
+ /**
+ * Gets the rescaling descriptor.
+ *
+ * @return The descriptor containing rescaling information.
+ */
+ public InflightDataRescalingDescriptor getRescalingDescriptor() {
+ return rescalingDescriptor;
+ }
+
+ /**
+ * Gets the current subtask index.
+ *
+ * @return The subtask index.
+ */
+ public int getSubtaskIndex() {
+ return subtaskIndex;
+ }
+
+ /**
+ * Gets the maximum parallelism.
+ *
+ * @return The maximum parallelism value.
+ */
+ public int getMaxParallelism() {
+ return maxParallelism;
+ }
+
+ /**
+ * Gets the temporary directories for spilling spanning records.
+ *
+ * @return The temporary directories, never null (may be empty array).
+ */
+ public String[] getTmpDirectories() {
+ return tmpDirectories;
+ }
+
+ /**
+ * Checks if a specific gate and subtask combination is ambiguous (requires filtering).
+ *
+ * @param gateIndex The gate index.
+ * @param oldSubtaskIndex The old subtask index.
+ * @return true if enabled and the channel is ambiguous and records need filtering.
+ */
+ public boolean isAmbiguous(int gateIndex, int oldSubtaskIndex) {
+ return unalignedDuringRecoveryEnabled
+ && rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex);
+ }
+
+ /**
+ * Creates a disabled RecordFilterContext for testing or when filtering is not needed.
+ *
+ * <p>The returned context has empty inputConfigs and enabled=false, so {@link
+ * #isUnalignedDuringRecoveryEnabled()} will always return false.
+ *
+ * @return A disabled RecordFilterContext.
+ */
+ public static RecordFilterContext disabled() {
+ return new RecordFilterContext(
+ new InputFilterConfig[0],
+ InflightDataRescalingDescriptor.NO_RESCALE,
+ 0,
+ 0,
+ new String[0],
+ false);
+ }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java
index ac8bc4109a595..cc8c75825f325 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RescalingStreamTaskNetworkInput.java
@@ -40,7 +40,6 @@
import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.StreamTask.CanEmitBatchOfRecordsChecker;
import org.apache.flink.streaming.runtime.watermarkstatus.StatusWatermarkValve;
import org.apache.flink.util.CollectionUtil;
@@ -54,7 +53,6 @@
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
-import java.util.function.Predicate;
import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY;
import static org.apache.flink.util.Preconditions.checkState;
@@ -201,8 +199,7 @@ public CompletableFuture prepareSnapshot(
* Filters must not be shared across different virtual channels to ensure that the state is
* in-sync across different subtasks.
*/
- static class RecordFilterFactory
- implements Function>> {
+    static class RecordFilterFactory<T> implements Function<InputChannelInfo, RecordFilter<T>> {
private final Map> partitionerCache =
CollectionUtil.newHashMapWithExpectedSize(1);
private final Function> gatePartitioners;
@@ -225,7 +222,7 @@ public RecordFilterFactory(
}
@Override
- public Predicate> apply(InputChannelInfo channelInfo) {
+ public RecordFilter<T> apply(InputChannelInfo channelInfo) {
// retrieving the partitioner for one input task is rather costly so cache them all
final StreamPartitioner partitioner =
partitionerCache.computeIfAbsent(
@@ -234,7 +231,7 @@ public Predicate> apply(InputChannelInfo channelInfo) {
// have the same state across several subtasks
StreamPartitioner partitionerCopy = partitioner.copy();
partitionerCopy.setup(numberOfChannels);
- return new RecordFilter<>(partitionerCopy, inputSerializer, subtaskIndex);
+ return new PartitionerRecordFilter<>(partitionerCopy, inputSerializer, subtaskIndex);
}
private StreamPartitioner createPartitioner(Integer index) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannel.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannel.java
new file mode 100644
index 0000000000000..ddcbba79f250a
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannel.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+import java.io.IOException;
+
+/**
+ * Represents a virtual channel for demultiplexing records during recovery.
+ *
+ * <p>A virtual channel wraps a {@link RecordDeserializer} and adds record filtering capability,
+ * along with tracking watermark and watermark status state.
+ *
+ * @param <T> The type of record values.
+ */
+@Internal
+public class VirtualChannel<T> {
+ private final RecordDeserializer<DeserializationDelegate<StreamElement>> deserializer;
+ private final RecordFilter<T> recordFilter;
+
+ private Watermark lastWatermark = Watermark.UNINITIALIZED;
+ private WatermarkStatus watermarkStatus = WatermarkStatus.ACTIVE;
+ private DeserializationResult lastResult;
+
+ public VirtualChannel(
+ RecordDeserializer<DeserializationDelegate<StreamElement>> deserializer,
+ RecordFilter<T> recordFilter) {
+ this.deserializer = deserializer;
+ this.recordFilter = recordFilter;
+ }
+
+ /**
+ * Deserializes the next record from the buffer, applying the record filter.
+ *
+ * <p>This method loops through records until it finds one that passes the filter or the buffer
+ * is consumed. Watermarks and watermark statuses are always accepted and their state is
+ * updated.
+ *
+ * @param delegate The deserialization delegate to populate with the record.
+ * @return The deserialization result indicating whether a full record was read.
+ * @throws IOException If an I/O error occurs during deserialization.
+ */
+ public DeserializationResult getNextRecord(DeserializationDelegate<StreamElement> delegate)
+ throws IOException {
+ do {
+ lastResult = deserializer.getNextRecord(delegate);
+
+ if (lastResult.isFullRecord()) {
+ final StreamElement element = delegate.getInstance();
+ // test if record belongs to this subtask if it comes from ambiguous channel
+ if (element.isRecord() && recordFilter.filter(element.asRecord())) {
+ return lastResult;
+ } else if (element.isWatermark()) {
+ lastWatermark = element.asWatermark();
+ return lastResult;
+ } else if (element.isWatermarkStatus()) {
+ watermarkStatus = element.asWatermarkStatus();
+ return lastResult;
+ }
+ }
+ // loop is only re-executed for filtered full records
+ } while (!lastResult.isBufferConsumed());
+ return DeserializationResult.PARTIAL_RECORD;
+ }
+
+ /**
+ * Sets the next buffer to be deserialized.
+ *
+ * @param buffer The buffer containing serialized records.
+ * @throws IOException If an I/O error occurs.
+ */
+ public void setNextBuffer(Buffer buffer) throws IOException {
+ deserializer.setNextBuffer(buffer);
+ }
+
+ /** Clears the deserializer state. */
+ public void clear() {
+ deserializer.clear();
+ }
+
+ /**
+ * Checks if there is partial data remaining in the buffer.
+ *
+ * @return true if the last result indicates the buffer was not fully consumed.
+ */
+ public boolean hasPartialData() {
+ return lastResult != null && !lastResult.isBufferConsumed();
+ }
+
+ /**
+ * Gets the last watermark received on this virtual channel.
+ *
+ * @return The last watermark, or {@link Watermark#UNINITIALIZED} if none received yet.
+ */
+ public Watermark getLastWatermark() {
+ return lastWatermark;
+ }
+
+ /**
+ * Gets the current watermark status of this virtual channel.
+ *
+ * @return The current watermark status.
+ */
+ public WatermarkStatus getWatermarkStatus() {
+ return watermarkStatus;
+ }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java
new file mode 100644
index 0000000000000..a2093f1143707
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
+import org.apache.flink.runtime.plugable.SerializationDelegate;
+import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+/**
+ * Factory for creating record filters used in Virtual Channels during channel state recovery.
+ *
+ * <p>This factory provides methods to create {@link RecordFilter} instances that determine whether
+ * a record belongs to the current subtask based on the partitioner logic.
+ *
+ * @param <T> The type of record values.
+ */
+@Internal
+public class VirtualChannelRecordFilterFactory<T> {
+
+ private final TypeSerializer<T> typeSerializer;
+ private final StreamPartitioner<T> partitioner;
+ private final int subtaskIndex;
+ private final int numberOfChannels;
+ private final int maxParallelism;
+
+ /**
+ * Creates a new VirtualChannelRecordFilterFactory.
+ *
+ * @param typeSerializer Serializer for the record type.
+ * @param partitioner Partitioner used to determine record ownership.
+ * @param subtaskIndex Current subtask index.
+ * @param numberOfChannels Number of parallel subtasks.
+ * @param maxParallelism Maximum parallelism for configuring partitioners.
+ */
+ public VirtualChannelRecordFilterFactory(
+ TypeSerializer<T> typeSerializer,
+ StreamPartitioner<T> partitioner,
+ int subtaskIndex,
+ int numberOfChannels,
+ int maxParallelism) {
+ this.typeSerializer = typeSerializer;
+ this.partitioner = partitioner;
+ this.subtaskIndex = subtaskIndex;
+ this.numberOfChannels = numberOfChannels;
+ this.maxParallelism = maxParallelism;
+ }
+
+ /**
+ * Creates a new VirtualChannelRecordFilterFactory from a RecordFilterContext and input index.
+ *
+ * @param context The record filter context.
+ * @param inputIndex The input index to get configuration from.
+ * @param <T> The type of record values.
+ * @return A new factory instance.
+ */
+ @SuppressWarnings("unchecked")
+ public static <T> VirtualChannelRecordFilterFactory<T> fromContext(
+ RecordFilterContext context, int inputIndex) {
+ RecordFilterContext.InputFilterConfig inputConfig = context.getInputConfig(inputIndex);
+ return new VirtualChannelRecordFilterFactory<>(
+ (TypeSerializer<T>) inputConfig.getTypeSerializer(),
+ (StreamPartitioner<T>) inputConfig.getPartitioner(),
+ context.getSubtaskIndex(),
+ inputConfig.getNumberOfChannels(),
+ context.getMaxParallelism());
+ }
+
+ /**
+ * Creates a record filter for ambiguous channels that requires actual filtering.
+ *
+ * @return A RecordFilter that tests if a record belongs to this subtask.
+ */
+ public RecordFilter<T> createFilter() {
+ StreamPartitioner<T> configuredPartitioner = configurePartitioner();
+ @SuppressWarnings("unchecked")
+ ChannelSelector<SerializationDelegate<StreamRecord<T>>> channelSelector =
+ configuredPartitioner;
+ return new PartitionerRecordFilter<>(channelSelector, typeSerializer, subtaskIndex);
+ }
+
+ /**
+ * Creates a pass-through filter that accepts all records.
+ *
+ * @param <T> The type of record values.
+ * @return A RecordFilter that always returns true.
+ */
+ public static <T> RecordFilter<T> createPassThroughFilter() {
+ return RecordFilter.acceptAll();
+ }
+
+ /**
+ * Configures the partitioner with the correct number of channels and max parallelism.
+ *
+ * @return A configured copy of the partitioner.
+ */
+ private StreamPartitioner<T> configurePartitioner() {
+ StreamPartitioner<T> copy = partitioner.copy();
+ copy.setup(numberOfChannels);
+ if (copy instanceof ConfigurableStreamPartitioner) {
+ ((ConfigurableStreamPartitioner) copy).configure(maxParallelism);
+ }
+ return copy;
+ }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
index 353582674e01a..676175514d922 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
@@ -203,6 +203,9 @@ private StreamTaskInput createTaskInput(CheckpointedInputGate inputGate) {
Set> watermarkDeclarationSet =
configuration.getWatermarkDeclarations(getUserCodeClassLoader());
+ boolean unalignedDuringRecoveryEnabled =
+ CheckpointingOptions.isUnalignedDuringRecoveryEnabled(getJobConfiguration());
+
return StreamTaskNetworkInputFactory.create(
inputGate,
inSerializer,
@@ -217,7 +220,8 @@ private StreamTaskInput createTaskInput(CheckpointedInputGate inputGate) {
.getPartitioner(),
getEnvironment().getTaskInfo(),
getCanEmitBatchOfRecords(),
- watermarkDeclarationSet);
+ watermarkDeclarationSet,
+ unalignedDuringRecoveryEnabled);
}
/**
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index a07d5ee3915aa..5eb6e0ebea4ed 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -21,6 +21,7 @@
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.operators.MailboxExecutor;
import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
@@ -84,6 +85,7 @@
import org.apache.flink.runtime.util.ConfigurationParserUtils;
import org.apache.flink.streaming.api.graph.NonChainedOutput;
import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.graph.StreamEdge;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManagerImpl;
import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -94,6 +96,7 @@
import org.apache.flink.streaming.runtime.io.StreamInputProcessor;
import org.apache.flink.streaming.runtime.io.checkpointing.BarrierAlignmentUtil;
import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHandler;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
@@ -881,7 +884,7 @@ private CompletableFuture restoreStateAndGates(
channelIOExecutor.execute(
() -> {
try {
- reader.readInputData(inputGates);
+ reader.readInputData(inputGates, createRecordFilterContext());
} catch (Exception e) {
asyncExceptionHandler.handleAsyncException(
"Unable to read channel state", e);
@@ -1956,6 +1959,83 @@ public final Environment getEnvironment() {
return environment;
}
+ /**
+ * Creates a RecordFilterContext for filtering recovered channel state buffers.
+ *
+ * <p>This method builds the complete context using information available in StreamTask,
+ * including input configurations for all network inputs.
+ *
+ * @return A RecordFilterContext with input configurations. The context may have empty
+ * inputConfigs (e.g., for source tasks) or enabled=false when filtering is not needed.
+ */
+ protected RecordFilterContext createRecordFilterContext() {
+ boolean unalignedDuringRecoveryEnabled =
+ CheckpointingOptions.isUnalignedDuringRecoveryEnabled(getJobConfiguration());
+ if (!unalignedDuringRecoveryEnabled) {
+ return RecordFilterContext.disabled();
+ }
+
+ ClassLoader cl = getUserCodeClassLoader();
+ StreamConfig.InputConfig[] inputs = configuration.getInputs(cl);
+ List<StreamEdge> inEdges = configuration.getInPhysicalEdges(cl);
+
+ // Create array sized to match the number of physical input gates.
+ // For source tasks, this will be 0. For tasks with network inputs, each physical gate
+ // must have a corresponding config entry.
+ int numGates = getEnvironment().getAllInputGates().length;
+ RecordFilterContext.InputFilterConfig[] inputConfigs =
+ new RecordFilterContext.InputFilterConfig[numGates];
+
+ Preconditions.checkState(
+ numGates == inEdges.size(),
+ "Number of input gates (%s) does not match number of physical edges (%s)",
+ numGates,
+ inEdges.size());
+
+ // Iterate through all physical edges (inEdges) instead of logical inputs.
+ // This is critical for Union scenarios where multiple physical gates map to one logical
+ // input. The order of inEdges matches the order of physical input gates.
+ int numberOfChannels = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks();
+ for (int gateIndex = 0; gateIndex < inEdges.size(); gateIndex++) {
+ StreamEdge edge = inEdges.get(gateIndex);
+ // Calculate logical input index from typeNumber
+ // typeNumber = 0 means single input, typeNumber >= 1 means multi-input (1-indexed)
+ int inputIndex = edge.getTypeNumber() == 0 ? 0 : edge.getTypeNumber() - 1;
+
+ Preconditions.checkState(
+ inputIndex < inputs.length
+ && inputs[inputIndex] instanceof StreamConfig.NetworkInputConfig,
+ "Physical edge at gateIndex %s has invalid inputIndex %s or non-network input",
+ gateIndex,
+ inputIndex);
+
+ StreamConfig.NetworkInputConfig networkInput =
+ (StreamConfig.NetworkInputConfig) inputs[inputIndex];
+ TypeSerializer<?> typeSerializer = networkInput.getTypeSerializer();
+ StreamPartitioner<?> partitioner = edge.getPartitioner();
+
+ inputConfigs[gateIndex] =
+ new RecordFilterContext.InputFilterConfig(
+ typeSerializer, partitioner, numberOfChannels);
+ }
+
+ for (int i = 0; i < inputConfigs.length; i++) {
+ Preconditions.checkState(
+ inputConfigs[i] != null,
+ "InputFilterConfig at index %s is null. "
+ + "All physical gates must have corresponding configurations.",
+ i);
+ }
+
+ return new RecordFilterContext(
+ inputConfigs,
+ getEnvironment().getTaskStateManager().getInputRescalingDescriptor(),
+ getEnvironment().getTaskInfo().getIndexOfThisSubtask(),
+ getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks(),
+ getEnvironment().getIOManager().getSpillingDirectoriesPaths(),
+ true);
+ }
+
/** Check whether records can be emitted in batch. */
@FunctionalInterface
public interface CanEmitBatchOfRecordsChecker {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
index e2b1d69c56ab1..39ce6c7d4bf7d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
@@ -77,7 +77,8 @@ private InputChannelRecoveredStateHandler buildInputChannelStateHandler(
InflightDataRescalingDescriptor
.InflightDataGateOrPartitionRescalingDescriptor
.MappingType.IDENTITY)
- }));
+ }),
+ null);
}
private InputChannelRecoveredStateHandler buildMultiChannelHandler() {
@@ -103,7 +104,8 @@ private InputChannelRecoveredStateHandler buildMultiChannelHandler() {
InflightDataRescalingDescriptor
.InflightDataGateOrPartitionRescalingDescriptor
.MappingType.RESCALING)
- }));
+ }),
+ null);
}
@Test
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java
index 5f05af0d92e59..d80442b8a06f8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java
@@ -40,6 +40,7 @@
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
@@ -143,7 +144,7 @@ void testReadPermutedState() throws Exception {
withInputGates(
gates -> {
- reader.readInputData(gates);
+ reader.readInputData(gates, RecordFilterContext.disabled());
assertBuffersEquals(inputChannelsData, collectBuffers(gates));
assertConsumed(gates);
});
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
index 53c05b1961309..f64a4d9fb9cac 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
@@ -49,6 +49,7 @@
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
import org.apache.flink.util.function.SupplierWithException;
import org.junit.jupiter.api.Test;
@@ -119,7 +120,7 @@ void testReadWritten() throws Exception {
try {
int numChannels = 1;
InputGate gate = buildGate(networkBufferPool, numChannels);
- reader.readInputData(new InputGate[] {gate});
+ reader.readInputData(new InputGate[] {gate}, RecordFilterContext.disabled());
assertThat(collectBytes(gate::pollNext, BufferOrEvent::getBuffer))
.isEqualTo(inputChannelInfoData);
diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializerTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializerTest.java
index 361b78e3f2e34..ed2ea64d9e753 100644
--- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializerTest.java
@@ -95,7 +95,7 @@ void testUpscale() throws IOException {
unused ->
new SpillingAdaptiveSpanningRecordDeserializer<>(
ioManager.getSpillingDirectoriesPaths()),
- unused -> RecordFilter.all());
+ unused -> RecordFilter.acceptAll());
assertThat(deserializer.getVirtualChannelSelectors())
.containsOnly(
@@ -136,7 +136,9 @@ void testAmbiguousChannels() throws IOException {
unused ->
new SpillingAdaptiveSpanningRecordDeserializer<>(
ioManager.getSpillingDirectoriesPaths()),
- unused -> new RecordFilter(new ModSelector(2), LongSerializer.INSTANCE, 1));
+ unused ->
+ new PartitionerRecordFilter<>(
+ new ModSelector(2), LongSerializer.INSTANCE, 1));
assertThat(deserializer.getVirtualChannelSelectors())
.containsOnly(
@@ -179,7 +181,7 @@ void testWatermarks() throws IOException {
unused ->
new SpillingAdaptiveSpanningRecordDeserializer<>(
ioManager.getSpillingDirectoriesPaths()),
- unused -> RecordFilter.all());
+ unused -> RecordFilter.acceptAll());
assertThat(deserializer.getVirtualChannelSelectors()).hasSize(4);
diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java
new file mode 100644
index 0000000000000..27832b4c3c528
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
+
+import org.junit.jupiter.api.Test;
+
+import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.mappings;
+import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.rescalingDescriptor;
+import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.set;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link RecordFilterContext}. */
+class RecordFilterContextTest {
+
+ @Test
+ void testDisabledContextHasNoGates() {
+ RecordFilterContext disabled = RecordFilterContext.disabled();
+ assertThat(disabled.getNumberOfGates()).isEqualTo(0);
+ assertThat(disabled.isUnalignedDuringRecoveryEnabled()).isFalse();
+ }
+
+ @Test
+ void testGetInputConfigReturnsCorrectConfig() {
+ RecordFilterContext.InputFilterConfig config =
+ new RecordFilterContext.InputFilterConfig(
+ LongSerializer.INSTANCE, new ForwardPartitioner<>(), 4);
+
+ RecordFilterContext context =
+ new RecordFilterContext(
+ new RecordFilterContext.InputFilterConfig[] {config},
+ InflightDataRescalingDescriptor.NO_RESCALE,
+ 0,
+ 128,
+ new String[] {"/tmp"},
+ true);
+
+ assertThat(context.getNumberOfGates()).isEqualTo(1);
+ assertThat(context.getInputConfig(0)).isSameAs(config);
+ assertThat(context.getSubtaskIndex()).isEqualTo(0);
+ assertThat(context.getMaxParallelism()).isEqualTo(128);
+ assertThat(context.isUnalignedDuringRecoveryEnabled()).isTrue();
+ }
+
+ @Test
+ void testGetInputConfigThrowsForInvalidIndex() {
+ RecordFilterContext context =
+ new RecordFilterContext(
+ new RecordFilterContext.InputFilterConfig[0],
+ InflightDataRescalingDescriptor.NO_RESCALE,
+ 0,
+ 128,
+ null,
+ false);
+
+ assertThatThrownBy(() -> context.getInputConfig(0))
+ .isInstanceOf(IllegalArgumentException.class);
+ assertThatThrownBy(() -> context.getInputConfig(-1))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
+ @Test
+ void testNullTmpDirectoriesConvertedToEmptyArray() {
+ RecordFilterContext context =
+ new RecordFilterContext(
+ new RecordFilterContext.InputFilterConfig[0],
+ InflightDataRescalingDescriptor.NO_RESCALE,
+ 0,
+ 128,
+ null,
+ false);
+
+ assertThat(context.getTmpDirectories()).isNotNull().isEmpty();
+ }
+
+ @Test
+ void testIsAmbiguousWhenDisabled() {
+ // Create a rescaling descriptor with an ambiguous subtask (oldSubtask 0 is ambiguous)
+ RescaleMappings mapping = mappings(new int[] {0});
+ InflightDataRescalingDescriptor descriptor =
+ rescalingDescriptor(new int[] {0}, new RescaleMappings[] {mapping}, set(0));
+
+ // When unalignedDuringRecoveryEnabled is false, isAmbiguous should always return false
+ RecordFilterContext context =
+ new RecordFilterContext(
+ new RecordFilterContext.InputFilterConfig[0],
+ descriptor,
+ 0,
+ 128,
+ null,
+ false);
+
+ assertThat(context.isAmbiguous(0, 0)).isFalse();
+ }
+
+ @Test
+ void testIsAmbiguousWhenEnabled() {
+ // Create a rescaling descriptor with an ambiguous subtask (oldSubtask 0 is ambiguous)
+ RescaleMappings mapping = mappings(new int[] {0});
+ InflightDataRescalingDescriptor descriptor =
+ rescalingDescriptor(new int[] {0}, new RescaleMappings[] {mapping}, set(0));
+
+ // When unalignedDuringRecoveryEnabled is true, isAmbiguous follows rescalingDescriptor
+ RecordFilterContext context =
+ new RecordFilterContext(
+ new RecordFilterContext.InputFilterConfig[0],
+ descriptor,
+ 0,
+ 128,
+ null,
+ true);
+
+ assertThat(context.isAmbiguous(0, 0)).isTrue();
+ }
+
+ @Test
+ void testIsAmbiguousForNonAmbiguousSubtask() {
+ // Create a rescaling descriptor where oldSubtask 0 is ambiguous but oldSubtask 1 is not
+ RescaleMappings mapping = mappings(new int[] {0});
+ InflightDataRescalingDescriptor descriptor =
+ rescalingDescriptor(new int[] {0, 1}, new RescaleMappings[] {mapping}, set(0));
+
+ RecordFilterContext context =
+ new RecordFilterContext(
+ new RecordFilterContext.InputFilterConfig[0],
+ descriptor,
+ 0,
+ 128,
+ null,
+ true);
+
+ // oldSubtask 0 is ambiguous
+ assertThat(context.isAmbiguous(0, 0)).isTrue();
+ // oldSubtask 1 is NOT in the ambiguous set
+ assertThat(context.isAmbiguous(0, 1)).isFalse();
+ }
+
+ @Test
+ void testInputFilterConfigGetters() {
+ ForwardPartitioner<Long> partitioner = new ForwardPartitioner<>();
+ RecordFilterContext.InputFilterConfig config =
+ new RecordFilterContext.InputFilterConfig(LongSerializer.INSTANCE, partitioner, 4);
+
+ assertThat(config.getTypeSerializer()).isSameAs(LongSerializer.INSTANCE);
+ assertThat(config.getPartitioner()).isSameAs(partitioner);
+ assertThat(config.getNumberOfChannels()).isEqualTo(4);
+ }
+
+ @Test
+ void testMultipleGateConfigs() {
+ RecordFilterContext.InputFilterConfig config0 =
+ new RecordFilterContext.InputFilterConfig(
+ LongSerializer.INSTANCE, new ForwardPartitioner<>(), 2);
+ RecordFilterContext.InputFilterConfig config1 =
+ new RecordFilterContext.InputFilterConfig(
+ LongSerializer.INSTANCE, new ForwardPartitioner<>(), 4);
+
+ RecordFilterContext context =
+ new RecordFilterContext(
+ new RecordFilterContext.InputFilterConfig[] {config0, config1},
+ InflightDataRescalingDescriptor.NO_RESCALE,
+ 1,
+ 256,
+ new String[] {"/tmp"},
+ false);
+
+ assertThat(context.getNumberOfGates()).isEqualTo(2);
+ assertThat(context.getInputConfig(0)).isSameAs(config0);
+ assertThat(context.getInputConfig(1)).isSameAs(config1);
+ assertThat(context.getInputConfig(0).getNumberOfChannels()).isEqualTo(2);
+ assertThat(context.getInputConfig(1).getNumberOfChannels()).isEqualTo(4);
+ }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterTest.java
new file mode 100644
index 0000000000000..10c7e0b9839ef
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterTest.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
+import org.apache.flink.runtime.plugable.SerializationDelegate;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link RecordFilter} interface and {@link PartitionerRecordFilter}. */
+class RecordFilterTest {
+
+ @Test
+ void testAcceptAllFilterAcceptsEveryRecord() {
+ RecordFilter<Long> filter = RecordFilter.acceptAll();
+ assertThat(filter.filter(new StreamRecord<>(0L))).isTrue();
+ assertThat(filter.filter(new StreamRecord<>(1L))).isTrue();
+ assertThat(filter.filter(new StreamRecord<>(42L))).isTrue();
+ assertThat(filter.filter(new StreamRecord<>(-1L))).isTrue();
+ }
+
+ @Test
+ void testPartitionerRecordFilterAcceptsMatchingSubtask() {
+ // Mod-based partitioner with 2 channels, subtask index 0 receives even values
+ PartitionerRecordFilter<Long> filter =
+ new PartitionerRecordFilter<>(
+ new ModChannelSelector(2), LongSerializer.INSTANCE, 0);
+
+ assertThat(filter.filter(new StreamRecord<>(0L))).isTrue();
+ assertThat(filter.filter(new StreamRecord<>(2L))).isTrue();
+ assertThat(filter.filter(new StreamRecord<>(4L))).isTrue();
+ }
+
+ @Test
+ void testPartitionerRecordFilterRejectsNonMatchingSubtask() {
+ // Mod-based partitioner with 2 channels, subtask index 0 should reject odd values
+ PartitionerRecordFilter<Long> filter =
+ new PartitionerRecordFilter<>(
+ new ModChannelSelector(2), LongSerializer.INSTANCE, 0);
+
+ assertThat(filter.filter(new StreamRecord<>(1L))).isFalse();
+ assertThat(filter.filter(new StreamRecord<>(3L))).isFalse();
+ assertThat(filter.filter(new StreamRecord<>(5L))).isFalse();
+ }
+
+ @Test
+ void testPartitionerRecordFilterWithDifferentSubtaskIndex() {
+ // Subtask index 1 should accept odd values (mod 2 == 1)
+ PartitionerRecordFilter<Long> filter =
+ new PartitionerRecordFilter<>(
+ new ModChannelSelector(2), LongSerializer.INSTANCE, 1);
+
+ assertThat(filter.filter(new StreamRecord<>(1L))).isTrue();
+ assertThat(filter.filter(new StreamRecord<>(3L))).isTrue();
+ assertThat(filter.filter(new StreamRecord<>(0L))).isFalse();
+ assertThat(filter.filter(new StreamRecord<>(2L))).isFalse();
+ }
+
+ @Test
+ void testRecordFilterAsFunctionalInterface() {
+ // RecordFilter is a FunctionalInterface and can be implemented as a lambda
+ RecordFilter<Long> onlyPositive = record -> record.getValue() > 0;
+ assertThat(onlyPositive.filter(new StreamRecord<>(5L))).isTrue();
+ assertThat(onlyPositive.filter(new StreamRecord<>(-1L))).isFalse();
+ assertThat(onlyPositive.filter(new StreamRecord<>(0L))).isFalse();
+ }
+
+ /** A simple mod-based channel selector for testing. */
+ private static class ModChannelSelector
+ implements ChannelSelector<SerializationDelegate<StreamRecord<Long>>> {
+ private final int numberOfChannels;
+
+ private ModChannelSelector(int numberOfChannels) {
+ this.numberOfChannels = numberOfChannels;
+ }
+
+ @Override
+ public void setup(int numberOfChannels) {}
+
+ @Override
+ public int selectChannel(SerializationDelegate<StreamRecord<Long>> record) {
+ return (int) (record.getInstance().getValue() % numberOfChannels);
+ }
+
+ @Override
+ public boolean isBroadcast() {
+ return false;
+ }
+ }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java
new file mode 100644
index 0000000000000..bffc42e33294d
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link VirtualChannelRecordFilterFactory}. */
+class VirtualChannelRecordFilterFactoryTest {
+
+ @Test
+ void testCreatePassThroughFilter() {
+ RecordFilter<Long> filter = VirtualChannelRecordFilterFactory.createPassThroughFilter();
+ assertThat(filter.filter(new StreamRecord<>(0L))).isTrue();
+ assertThat(filter.filter(new StreamRecord<>(1L))).isTrue();
+ assertThat(filter.filter(new StreamRecord<>(42L))).isTrue();
+ }
+
+ @Test
+ void testCreateFilterProducesPartitionerBasedFilter() {
+ RebalancePartitioner<Long> partitioner = new RebalancePartitioner<>();
+
+ VirtualChannelRecordFilterFactory<Long> factory =
+ new VirtualChannelRecordFilterFactory<>(
+ LongSerializer.INSTANCE, partitioner, 0, 2, 128);
+
+ RecordFilter<Long> filter = factory.createFilter();
+ // The filter should be a PartitionerRecordFilter that filters based on partitioner
+ assertThat(filter).isInstanceOf(PartitionerRecordFilter.class);
+ }
+
+ @Test
+ void testFromContextCreatesFactory() {
+ RebalancePartitioner<Long> partitioner = new RebalancePartitioner<>();
+ RecordFilterContext.InputFilterConfig config =
+ new RecordFilterContext.InputFilterConfig(LongSerializer.INSTANCE, partitioner, 4);
+
+ RecordFilterContext context =
+ new RecordFilterContext(
+ new RecordFilterContext.InputFilterConfig[] {config},
+ InflightDataRescalingDescriptor.NO_RESCALE,
+ 1,
+ 128,
+ new String[] {"/tmp"},
+ true);
+
+ VirtualChannelRecordFilterFactory<Long> factory =
+ VirtualChannelRecordFilterFactory.fromContext(context, 0);
+ RecordFilter<Long> filter = factory.createFilter();
+
+ // The filter should be a functional PartitionerRecordFilter
+ assertThat(filter).isInstanceOf(PartitionerRecordFilter.class);
+ }
+
+ @Test
+ void testEachFilterCallCreatesIndependentFilter() {
+ RebalancePartitioner<Long> partitioner = new RebalancePartitioner<>();
+
+ VirtualChannelRecordFilterFactory<Long> factory =
+ new VirtualChannelRecordFilterFactory<>(
+ LongSerializer.INSTANCE, partitioner, 0, 2, 128);
+
+ RecordFilter<Long> filter1 = factory.createFilter();
+ RecordFilter<Long> filter2 = factory.createFilter();
+
+ // Each call should produce a distinct filter instance (using a copy of the partitioner)
+ assertThat(filter1).isNotSameAs(filter2);
+ }
+}