Skip to content

Commit d84f212

Browse files
author
wangxiaogang
committed
[Enhance][translation] Add support for re-signaling NoMoreSplitsEvent after reader re-registration
1 parent c3b63c8 commit d84f212

File tree

8 files changed

+383
-6
lines changed

8 files changed

+383
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.seatunnel.connectors.seatunnel.jdbc.source;
19+
20+
import org.apache.seatunnel.api.common.metrics.MetricsContext;
21+
import org.apache.seatunnel.api.event.EventListener;
22+
import org.apache.seatunnel.api.source.SourceEvent;
23+
import org.apache.seatunnel.api.source.SourceSplitEnumerator;
24+
import org.apache.seatunnel.api.table.catalog.CatalogTable;
25+
import org.apache.seatunnel.api.table.catalog.TableIdentifier;
26+
import org.apache.seatunnel.api.table.catalog.TablePath;
27+
import org.apache.seatunnel.api.table.catalog.TableSchema;
28+
import org.apache.seatunnel.connectors.seatunnel.jdbc.config.JdbcConnectionConfig;
29+
import org.apache.seatunnel.connectors.seatunnel.jdbc.config.JdbcSourceConfig;
30+
31+
import org.junit.jupiter.api.Assertions;
32+
import org.junit.jupiter.api.Test;
33+
34+
import java.util.ArrayList;
35+
import java.util.Collections;
36+
import java.util.HashMap;
37+
import java.util.HashSet;
38+
import java.util.List;
39+
import java.util.Map;
40+
import java.util.Set;
41+
import java.util.concurrent.ConcurrentHashMap;
42+
import java.util.concurrent.atomic.AtomicInteger;
43+
44+
class JdbcSourceSplitEnumeratorTest {
45+
46+
@Test
47+
void testRunSignalsNoMoreSplitsOnce() throws Exception {
48+
int parallelism = 1;
49+
TablePath tablePath = TablePath.of("db", "schema", "table");
50+
51+
Map<TablePath, JdbcSourceTable> tables = new HashMap<>();
52+
tables.put(tablePath, createJdbcSourceTable(tablePath));
53+
54+
List<Integer> assignTargets = new ArrayList<>();
55+
Set<Integer> noMoreSplitsReaders = new HashSet<>();
56+
AtomicInteger noMoreSplitsCallCount = new AtomicInteger();
57+
58+
SourceSplitEnumerator.Context<JdbcSourceSplit> context =
59+
new SourceSplitEnumerator.Context<JdbcSourceSplit>() {
60+
@Override
61+
public int currentParallelism() {
62+
return parallelism;
63+
}
64+
65+
@Override
66+
public Set<Integer> registeredReaders() {
67+
return Collections.singleton(0);
68+
}
69+
70+
@Override
71+
public void assignSplit(int subtaskId, List<JdbcSourceSplit> splits) {
72+
assignTargets.add(subtaskId);
73+
}
74+
75+
@Override
76+
public void signalNoMoreSplits(int subtask) {
77+
noMoreSplitsCallCount.incrementAndGet();
78+
noMoreSplitsReaders.add(subtask);
79+
}
80+
81+
@Override
82+
public void sendEventToSourceReader(int subtaskId, SourceEvent event) {}
83+
84+
@Override
85+
public MetricsContext getMetricsContext() {
86+
return null;
87+
}
88+
89+
@Override
90+
public EventListener getEventListener() {
91+
return null;
92+
}
93+
};
94+
95+
JdbcSourceConfig sourceConfig =
96+
JdbcSourceConfig.builder()
97+
.jdbcConnectionConfig(
98+
JdbcConnectionConfig.builder()
99+
.url("jdbc:generic://localhost:0/test")
100+
.driverName("org.example.Driver")
101+
.build())
102+
.build();
103+
104+
JdbcSourceSplitEnumerator enumerator =
105+
new JdbcSourceSplitEnumerator(context, sourceConfig, tables, null);
106+
107+
enumerator.open();
108+
enumerator.run();
109+
110+
Assertions.assertEquals(Collections.singletonList(0), assignTargets);
111+
Assertions.assertEquals(Collections.singleton(0), noMoreSplitsReaders);
112+
Assertions.assertEquals(1, noMoreSplitsCallCount.get());
113+
114+
// NoMoreSplitsEvent is only sent once at the end of run().
115+
enumerator.addSplitsBack(Collections.emptyList(), 0);
116+
enumerator.registerReader(0);
117+
118+
Assertions.assertEquals(1, noMoreSplitsCallCount.get());
119+
}
120+
121+
@Test
122+
void testRunSignalsNoMoreSplitsForAllRegisteredReadersWithHighParallelism() throws Exception {
123+
int parallelism = 8;
124+
125+
Set<Integer> registeredReaders = new HashSet<>();
126+
for (int i = 0; i < parallelism; i++) {
127+
registeredReaders.add(i);
128+
}
129+
130+
Map<TablePath, JdbcSourceTable> tables = new HashMap<>();
131+
for (int i = 0; i < 3; i++) {
132+
TablePath tablePath = TablePath.of("db", "schema", "table_" + i);
133+
tables.put(tablePath, createJdbcSourceTable(tablePath));
134+
}
135+
136+
Map<String, Integer> assignedSplitOwners = new HashMap<>();
137+
Set<Integer> noMoreSplitsReaders = ConcurrentHashMap.newKeySet();
138+
AtomicInteger noMoreSplitsCallCount = new AtomicInteger();
139+
140+
SourceSplitEnumerator.Context<JdbcSourceSplit> context =
141+
new SourceSplitEnumerator.Context<JdbcSourceSplit>() {
142+
@Override
143+
public int currentParallelism() {
144+
return parallelism;
145+
}
146+
147+
@Override
148+
public Set<Integer> registeredReaders() {
149+
return new HashSet<>(registeredReaders);
150+
}
151+
152+
@Override
153+
public void assignSplit(int subtaskId, List<JdbcSourceSplit> splits) {
154+
for (JdbcSourceSplit split : splits) {
155+
assignedSplitOwners.put(split.splitId(), subtaskId);
156+
}
157+
}
158+
159+
@Override
160+
public void signalNoMoreSplits(int subtask) {
161+
noMoreSplitsCallCount.incrementAndGet();
162+
noMoreSplitsReaders.add(subtask);
163+
}
164+
165+
@Override
166+
public void sendEventToSourceReader(int subtaskId, SourceEvent event) {}
167+
168+
@Override
169+
public MetricsContext getMetricsContext() {
170+
return null;
171+
}
172+
173+
@Override
174+
public EventListener getEventListener() {
175+
return null;
176+
}
177+
};
178+
179+
JdbcSourceConfig sourceConfig =
180+
JdbcSourceConfig.builder()
181+
.jdbcConnectionConfig(
182+
JdbcConnectionConfig.builder()
183+
.url("jdbc:generic://localhost:0/test")
184+
.driverName("org.example.Driver")
185+
.build())
186+
.build();
187+
188+
JdbcSourceSplitEnumerator enumerator =
189+
new JdbcSourceSplitEnumerator(context, sourceConfig, tables, null);
190+
191+
enumerator.open();
192+
enumerator.run();
193+
194+
Assertions.assertEquals(tables.size(), assignedSplitOwners.size());
195+
assignedSplitOwners.forEach(
196+
(splitId, owner) -> {
197+
int expectedOwner = (splitId.hashCode() & Integer.MAX_VALUE) % parallelism;
198+
Assertions.assertEquals(expectedOwner, owner);
199+
});
200+
201+
Assertions.assertEquals(registeredReaders, noMoreSplitsReaders);
202+
Assertions.assertEquals(parallelism, noMoreSplitsCallCount.get());
203+
Assertions.assertEquals(0, enumerator.currentUnassignedSplitSize());
204+
}
205+
206+
private JdbcSourceTable createJdbcSourceTable(TablePath tablePath) {
207+
TableIdentifier tableId = TableIdentifier.of("default", tablePath);
208+
TableSchema tableSchema = TableSchema.builder().columns(Collections.emptyList()).build();
209+
CatalogTable catalogTable =
210+
CatalogTable.of(
211+
tableId, tableSchema, Collections.emptyMap(), Collections.emptyList(), "");
212+
return JdbcSourceTable.builder().tablePath(tablePath).catalogTable(catalogTable).build();
213+
}
214+
}

seatunnel-engine/seatunnel-engine-server/src/main/java/org/apache/seatunnel/engine/server/task/SourceSplitEnumeratorTask.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,16 @@ public void receivedReader(TaskLocation readerId, Address memberAddr)
223223
log.info("received reader register, readerID: " + readerId);
224224

225225
SourceSplitEnumerator<SplitT, Serializable> enumerator = getEnumerator();
226+
int readerIndex = readerId.getTaskIndex();
226227
this.addTaskMemberMapping(readerId, memberAddr);
227228
synchronized (this) {
228-
enumerator.registerReader(readerId.getTaskIndex());
229+
enumerator.registerReader(readerIndex);
230+
if (enumeratorContext.hasNoMoreSplitsSignaled(readerIndex)) {
231+
log.info(
232+
"Reader [{}] re-registered after failover. Re-signaling NoMoreSplitsEvent.",
233+
readerIndex);
234+
enumeratorContext.signalNoMoreSplits(readerIndex);
235+
}
229236
}
230237
int taskSize = taskMemberMapping.size();
231238
if (maxReaderSize == taskSize) {

seatunnel-engine/seatunnel-engine-server/src/main/java/org/apache/seatunnel/engine/server/task/context/SeaTunnelSplitEnumeratorContext.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.HashSet;
3232
import java.util.List;
3333
import java.util.Set;
34+
import java.util.concurrent.ConcurrentHashMap;
3435
import java.util.stream.Collectors;
3536

3637
import static org.apache.seatunnel.engine.common.utils.ExceptionUtil.sneaky;
@@ -46,6 +47,8 @@ public class SeaTunnelSplitEnumeratorContext<SplitT extends SourceSplit>
4647
private final MetricsContext metricsContext;
4748
private final EventListener eventListener;
4849

50+
private final Set<Integer> noMoreSplitsSignaledReaders = ConcurrentHashMap.newKeySet();
51+
4952
public SeaTunnelSplitEnumeratorContext(
5053
int parallelism,
5154
SourceSplitEnumeratorTask<SplitT> task,
@@ -88,6 +91,7 @@ public void assignSplit(int subtaskIndex, List<SplitT> splits) {
8891

8992
@Override
9093
public void signalNoMoreSplits(int subtaskIndex) {
94+
noMoreSplitsSignaledReaders.add(subtaskIndex);
9195
List<byte[]> emptySplits = Collections.emptyList();
9296
task.getExecutionContext()
9397
.sendToMember(
@@ -109,4 +113,8 @@ public MetricsContext getMetricsContext() {
109113
public EventListener getEventListener() {
110114
return eventListener;
111115
}
116+
117+
public boolean hasNoMoreSplitsSignaled(int subtaskIndex) {
118+
return noMoreSplitsSignaledReaders.contains(subtaskIndex);
119+
}
112120
}

seatunnel-engine/seatunnel-engine-server/src/test/java/org/apache/seatunnel/engine/server/task/SourceSplitEnumeratorTaskTest.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.seatunnel.engine.server.execution.TaskExecutionContext;
2626
import org.apache.seatunnel.engine.server.execution.TaskGroupLocation;
2727
import org.apache.seatunnel.engine.server.execution.TaskLocation;
28+
import org.apache.seatunnel.engine.server.task.context.SeaTunnelSplitEnumeratorContext;
2829

2930
import org.junit.jupiter.api.Assertions;
3031
import org.junit.jupiter.api.Test;
@@ -37,6 +38,7 @@
3738
import java.util.Collections;
3839
import java.util.HashSet;
3940
import java.util.concurrent.atomic.AtomicLong;
41+
import java.util.concurrent.atomic.AtomicReference;
4042

4143
public class SourceSplitEnumeratorTaskTest {
4244

@@ -102,4 +104,60 @@ void testOpenShouldBeforeReaderRegister() throws Exception {
102104

103105
Assertions.assertTrue(openTime.get() < registerReaderTime.get());
104106
}
107+
108+
@Test
109+
void testResignalNoMoreSplitsAfterReaderReregister() throws Exception {
110+
SeaTunnelSource source = Mockito.mock(SeaTunnelSource.class);
111+
SourceSplitEnumerator enumerator = Mockito.mock(SourceSplitEnumerator.class);
112+
113+
AtomicReference<SeaTunnelSplitEnumeratorContext> enumeratorContextRef =
114+
new AtomicReference<>();
115+
Mockito.when(source.createEnumerator(Mockito.any()))
116+
.thenAnswer(
117+
invocation -> {
118+
enumeratorContextRef.set(
119+
(SeaTunnelSplitEnumeratorContext) invocation.getArgument(0));
120+
return enumerator;
121+
});
122+
123+
SourceAction action =
124+
new SourceAction<>(1, "fake", source, new HashSet<>(), Collections.emptySet());
125+
SourceSplitEnumeratorTask enumeratorTask =
126+
new SourceSplitEnumeratorTask<>(
127+
1, new TaskLocation(new TaskGroupLocation(1, 1, 1), 1, 1), action);
128+
129+
TaskExecutionContext context = Mockito.mock(TaskExecutionContext.class);
130+
InvocationFuture future = Mockito.mock(InvocationFuture.class);
131+
Mockito.when(context.getOrCreateMetricsContext(Mockito.any())).thenReturn(null);
132+
Mockito.when(context.sendToMaster(Mockito.any())).thenReturn(future);
133+
Mockito.when(context.sendToMember(Mockito.any(), Mockito.any())).thenReturn(future);
134+
Mockito.when(future.join()).thenReturn(null);
135+
TaskExecutionService taskExecutionService = Mockito.mock(TaskExecutionService.class);
136+
Mockito.when(context.getTaskExecutionService()).thenReturn(taskExecutionService);
137+
138+
enumeratorTask.setTaskExecutionContext(context);
139+
enumeratorTask.init();
140+
enumeratorTask.restoreState(new ArrayList<>());
141+
142+
TaskLocation readerLocation = new TaskLocation(new TaskGroupLocation(1, 1, 1), 1, 1);
143+
Address address = Address.createUnresolvedAddress("localhost", 5701);
144+
145+
// Initial register
146+
enumeratorTask.receivedReader(readerLocation, address);
147+
148+
SeaTunnelSplitEnumeratorContext enumeratorContext = enumeratorContextRef.get();
149+
Assertions.assertNotNull(enumeratorContext);
150+
151+
Mockito.clearInvocations(context);
152+
153+
// Simulate that NoMoreSplitsEvent has been signaled once.
154+
enumeratorContext.signalNoMoreSplits(readerLocation.getTaskIndex());
155+
Assertions.assertTrue(
156+
enumeratorContext.hasNoMoreSplitsSignaled(readerLocation.getTaskIndex()));
157+
158+
// Reader re-registers after failover, framework should re-signal.
159+
enumeratorTask.receivedReader(readerLocation, address);
160+
161+
Mockito.verify(context, Mockito.times(2)).sendToMember(Mockito.any(), Mockito.any());
162+
}
105163
}

seatunnel-translation/seatunnel-translation-flink/seatunnel-translation-flink-common/src/main/java/org/apache/seatunnel/translation/flink/source/FlinkSource.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838

3939
import java.io.Serializable;
4040
import java.sql.DriverManager;
41+
import java.util.Set;
42+
import java.util.concurrent.ConcurrentHashMap;
4143

4244
/**
4345
* The source implementation of {@link Source}, used for proxy all {@link SeaTunnelSource} in flink.
@@ -91,21 +93,25 @@ public SourceReader<SeaTunnelRow, SplitWrapper<SplitT>> createReader(
9193
@Override
9294
public SplitEnumerator<SplitWrapper<SplitT>, EnumStateT> createEnumerator(
9395
SplitEnumeratorContext<SplitWrapper<SplitT>> enumContext) throws Exception {
96+
Set<Integer> noMoreSplitsSignaledReaders = ConcurrentHashMap.newKeySet();
9497
SourceSplitEnumerator.Context<SplitT> context =
95-
new FlinkSourceSplitEnumeratorContext<>(enumContext);
98+
new FlinkSourceSplitEnumeratorContext<>(
99+
enumContext, noMoreSplitsSignaledReaders::add);
96100
SourceSplitEnumerator<SplitT, EnumStateT> enumerator = source.createEnumerator(context);
97-
return new FlinkSourceEnumerator<>(enumerator, enumContext);
101+
return new FlinkSourceEnumerator<>(enumerator, enumContext, noMoreSplitsSignaledReaders);
98102
}
99103

100104
@Override
101105
public SplitEnumerator<SplitWrapper<SplitT>, EnumStateT> restoreEnumerator(
102106
SplitEnumeratorContext<SplitWrapper<SplitT>> enumContext, EnumStateT checkpoint)
103107
throws Exception {
108+
Set<Integer> noMoreSplitsSignaledReaders = ConcurrentHashMap.newKeySet();
104109
FlinkSourceSplitEnumeratorContext<SplitT> context =
105-
new FlinkSourceSplitEnumeratorContext<>(enumContext);
110+
new FlinkSourceSplitEnumeratorContext<>(
111+
enumContext, noMoreSplitsSignaledReaders::add);
106112
SourceSplitEnumerator<SplitT, EnumStateT> enumerator =
107113
source.restoreEnumerator(context, checkpoint);
108-
return new FlinkSourceEnumerator<>(enumerator, enumContext);
114+
return new FlinkSourceEnumerator<>(enumerator, enumContext, noMoreSplitsSignaledReaders);
109115
}
110116

111117
@Override

0 commit comments

Comments
 (0)