Commit f4b04ec

[FLINK-34504][autoscaler] Avoid the parallelism adjustment when the upstream shuffle type doesn't have keyBy (#783)
1 parent d738c57 commit f4b04ec

16 files changed: +459 −151 lines changed

Diff for: flink-autoscaler/src/main/java/org/apache/flink/autoscaler/JobVertexScaler.java

+44 −21

@@ -22,6 +22,7 @@
 import org.apache.flink.autoscaler.event.AutoScalerEventHandler;
 import org.apache.flink.autoscaler.metrics.EvaluatedScalingMetric;
 import org.apache.flink.autoscaler.metrics.ScalingMetric;
+import org.apache.flink.autoscaler.topology.ShipStrategy;
 import org.apache.flink.autoscaler.utils.AutoScalerUtils;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
@@ -34,6 +35,7 @@
 import java.time.Duration;
 import java.time.Instant;
 import java.time.ZoneId;
+import java.util.Collection;
 import java.util.Map;
 import java.util.SortedMap;

@@ -48,6 +50,7 @@
 import static org.apache.flink.autoscaler.metrics.ScalingMetric.MAX_PARALLELISM;
 import static org.apache.flink.autoscaler.metrics.ScalingMetric.PARALLELISM;
 import static org.apache.flink.autoscaler.metrics.ScalingMetric.TRUE_PROCESSING_RATE;
+import static org.apache.flink.autoscaler.topology.ShipStrategy.HASH;

 /** Component responsible for computing vertex parallelism based on the scaling metrics. */
 public class JobVertexScaler<KEY, Context extends JobAutoScalerContext<KEY>> {
@@ -71,6 +74,7 @@ public JobVertexScaler(AutoScalerEventHandler<KEY, Context> autoScalerEventHandl
     public int computeScaleTargetParallelism(
             Context context,
             JobVertexID vertex,
+            Collection<ShipStrategy> inputShipStrategies,
             Map<ScalingMetric, EvaluatedScalingMetric> evaluatedMetrics,
             SortedMap<Instant, ScalingSummary> history,
             Duration restartTime) {
@@ -121,6 +125,7 @@ public int computeScaleTargetParallelism(
         int newParallelism =
                 scale(
                         currentParallelism,
+                        inputShipStrategies,
                         (int) evaluatedMetrics.get(MAX_PARALLELISM).getCurrent(),
                         scaleFactor,
                         Math.min(currentParallelism, conf.getInteger(VERTEX_MIN_PARALLELISM)),
@@ -245,50 +250,68 @@ private boolean detectIneffectiveScaleUp(
         }
     }

+    /**
+     * Computing the newParallelism. In general, newParallelism = currentParallelism * scaleFactor.
+     * But we limit newParallelism between parallelismLowerLimit and min(parallelismUpperLimit,
+     * maxParallelism).
+     *
+     * <p>Also, in order to ensure the data is evenly spread across subtasks, we try to adjust the
+     * parallelism for source and keyed vertex such that it divides the maxParallelism without a
+     * remainder.
+     */
     @VisibleForTesting
     protected static int scale(
-            int parallelism,
-            int numKeyGroups,
+            int currentParallelism,
+            Collection<ShipStrategy> inputShipStrategies,
+            int maxParallelism,
             double scaleFactor,
-            int minParallelism,
-            int maxParallelism) {
+            int parallelismLowerLimit,
+            int parallelismUpperLimit) {
         Preconditions.checkArgument(
-                minParallelism <= maxParallelism,
-                "The minimum parallelism must not be greater than the maximum parallelism.");
-        if (minParallelism > numKeyGroups) {
+                parallelismLowerLimit <= parallelismUpperLimit,
+                "The parallelism lower limitation must not be greater than the parallelism upper limitation.");
+        if (parallelismLowerLimit > maxParallelism) {
             LOG.warn(
                     "Specified autoscaler minimum parallelism {} is greater than the operator max parallelism {}. The min parallelism will be set to the operator max parallelism.",
-                    minParallelism,
-                    numKeyGroups);
+                    parallelismLowerLimit,
+                    maxParallelism);
         }
-        if (numKeyGroups < maxParallelism && maxParallelism != Integer.MAX_VALUE) {
+        if (maxParallelism < parallelismUpperLimit && parallelismUpperLimit != Integer.MAX_VALUE) {
             LOG.debug(
                     "Specified autoscaler maximum parallelism {} is greater than the operator max parallelism {}. This means the operator max parallelism can never be reached.",
-                    maxParallelism,
-                    numKeyGroups);
+                    parallelismUpperLimit,
+                    maxParallelism);
         }

         int newParallelism =
                 // Prevent integer overflow when converting from double to integer.
                 // We do not have to detect underflow because doubles cannot
                 // underflow.
-                (int) Math.min(Math.ceil(scaleFactor * parallelism), Integer.MAX_VALUE);
+                (int) Math.min(Math.ceil(scaleFactor * currentParallelism), Integer.MAX_VALUE);

-        // Cap parallelism at either number of key groups or parallelism limit
-        final int upperBound = Math.min(numKeyGroups, maxParallelism);
+        // Cap parallelism at either maxParallelism(number of key groups or source partitions) or
+        // parallelism upper limit
+        final int upperBound = Math.min(maxParallelism, parallelismUpperLimit);

         // Apply min/max parallelism
-        newParallelism = Math.min(Math.max(minParallelism, newParallelism), upperBound);
+        newParallelism = Math.min(Math.max(parallelismLowerLimit, newParallelism), upperBound);
+
+        var adjustByMaxParallelism =
+                inputShipStrategies.isEmpty() || inputShipStrategies.contains(HASH);
+        if (!adjustByMaxParallelism) {
+            return newParallelism;
+        }

-        // Try to adjust the parallelism such that it divides the number of key groups without a
-        // remainder => state is evenly spread across subtasks
-        for (int p = newParallelism; p <= numKeyGroups / 2 && p <= upperBound; p++) {
-            if (numKeyGroups % p == 0) {
+        // When the shuffle type of vertex inputs contains keyBy or vertex is a source, we try to
+        // adjust the parallelism such that it divides the maxParallelism without a remainder
+        // => data is evenly spread across subtasks
+        for (int p = newParallelism; p <= maxParallelism / 2 && p <= upperBound; p++) {
+            if (maxParallelism % p == 0) {
                 return p;
             }
         }

-        // If key group adjustment fails, use originally computed parallelism
+        // If parallelism adjustment fails, use originally computed parallelism
         return newParallelism;
     }
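
To make the new rule concrete, here is a small, self-contained sketch (not part of the commit; the class and method names are made up for illustration) of the adjustment that scale now applies: only source vertices (no inputs) and vertices with at least one HASH (keyBy) input get their parallelism rounded up to a divisor of maxParallelism, everything else keeps the computed value.

import java.util.Collection;
import java.util.List;

public class ParallelismAdjustmentSketch {

    enum ShipStrategy { HASH, REBALANCE, RESCALE, FORWARD }

    // Mirrors the loop above: sources (no inputs) and keyed (HASH) inputs get the divisor
    // adjustment; other shuffle types return the computed parallelism unchanged.
    static int adjust(
            int newParallelism,
            int maxParallelism,
            int upperBound,
            Collection<ShipStrategy> inputShipStrategies) {
        boolean adjustByMaxParallelism =
                inputShipStrategies.isEmpty() || inputShipStrategies.contains(ShipStrategy.HASH);
        if (!adjustByMaxParallelism) {
            return newParallelism;
        }
        for (int p = newParallelism; p <= maxParallelism / 2 && p <= upperBound; p++) {
            if (maxParallelism % p == 0) {
                return p; // divides the key groups / source partitions without remainder
            }
        }
        return newParallelism;
    }

    public static void main(String[] args) {
        // Keyed vertex, maxParallelism 720: a request for 7 is bumped to 8 (720 % 8 == 0).
        System.out.println(adjust(7, 720, 200, List.of(ShipStrategy.HASH)));      // 8
        // Pure REBALANCE input: 7 is kept as-is, which is the change this commit introduces.
        System.out.println(adjust(7, 720, 200, List.of(ShipStrategy.REBALANCE))); // 7
        // Source vertex (no inputs) with 24 partitions: 7 is bumped to 8 (24 % 8 == 0).
        System.out.println(adjust(7, 24, 200, List.of()));                        // 8
    }
}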

Diff for: flink-autoscaler/src/main/java/org/apache/flink/autoscaler/ScalingExecutor.java

+6 −2

@@ -102,7 +102,8 @@ public boolean scaleResource(
         var restartTime = scalingTracking.getMaxRestartTimeOrDefault(conf);

         var scalingSummaries =
-                computeScalingSummary(context, evaluatedMetrics, scalingHistory, restartTime);
+                computeScalingSummary(
+                        context, evaluatedMetrics, scalingHistory, restartTime, jobTopology);

         if (scalingSummaries.isEmpty()) {
             LOG.info("All job vertices are currently running at their target parallelism.");
@@ -203,7 +204,8 @@ Map<JobVertexID, ScalingSummary> computeScalingSummary(
             Context context,
             EvaluatedMetrics evaluatedMetrics,
             Map<JobVertexID, SortedMap<Instant, ScalingSummary>> scalingHistory,
-            Duration restartTime) {
+            Duration restartTime,
+            JobTopology jobTopology) {
         LOG.debug("Restart time used in scaling summary computation: {}", restartTime);

         if (isJobUnderMemoryPressure(context, evaluatedMetrics.getGlobalMetrics())) {
@@ -225,10 +227,12 @@ Map<JobVertexID, ScalingSummary> computeScalingSummary(
             } else {
                 var currentParallelism =
                         (int) metrics.get(ScalingMetric.PARALLELISM).getCurrent();
+
                 var newParallelism =
                         jobVertexScaler.computeScaleTargetParallelism(
                                 context,
                                 v,
+                                jobTopology.get(v).getInputs().values(),
                                 metrics,
                                 scalingHistory.getOrDefault(
                                         v, Collections.emptySortedMap()),
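
For context, a toy sketch (not part of the commit; the map below is illustrative) of what the new argument carries: for each vertex being scaled, the executor passes the ship strategies of all of its input edges to the scaler.

import java.util.Map;

public class InputShipStrategiesSketch {

    enum ShipStrategy { HASH, REBALANCE }

    public static void main(String[] args) {
        // Inputs of one vertex, keyed by upstream vertex (simplified to strings here).
        Map<String, ShipStrategy> inputs =
                Map.of("source-a", ShipStrategy.HASH, "source-b", ShipStrategy.REBALANCE);
        // Equivalent of jobTopology.get(v).getInputs().values(): one entry per input edge.
        System.out.println(inputs.values()); // e.g. [HASH, REBALANCE] (Map.of order is unspecified)
    }
}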

Diff for: flink-autoscaler/src/main/java/org/apache/flink/autoscaler/topology/JobTopology.java

+3 −3

@@ -61,7 +61,7 @@ public JobTopology(VertexInfo... vertexInfo) {

     public JobTopology(Set<VertexInfo> vertexInfo) {

-        Map<JobVertexID, Map<JobVertexID, String>> vertexOutputs = new HashMap<>();
+        Map<JobVertexID, Map<JobVertexID, ShipStrategy>> vertexOutputs = new HashMap<>();
         vertexInfos =
                 ImmutableMap.copyOf(
                         vertexInfo.stream().collect(Collectors.toMap(VertexInfo::getId, v -> v)));
@@ -145,7 +145,7 @@ public static JobTopology fromJsonPlan(

         for (JsonNode node : nodes) {
             var vertexId = JobVertexID.fromHexString(node.get("id").asText());
-            var inputs = new HashMap<JobVertexID, String>();
+            var inputs = new HashMap<JobVertexID, ShipStrategy>();
             var ioMetrics = metrics.get(vertexId);
             var finished = finishedVertices.contains(vertexId);
             vertexInfo.add(
@@ -160,7 +160,7 @@ public static JobTopology fromJsonPlan(
             for (JsonNode input : node.get("inputs")) {
                 inputs.put(
                         JobVertexID.fromHexString(input.get("id").asText()),
-                        input.get("ship_strategy").asText());
+                        ShipStrategy.of(input.get("ship_strategy").asText()));
             }
         }
     }

Diff for: flink-autoscaler/src/main/java/org/apache/flink/autoscaler/topology/ShipStrategy.java

+55 −0

@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.autoscaler.topology;
+
+import javax.annotation.Nonnull;
+
+/** The ship strategy between 2 JobVertices. */
+public enum ShipStrategy {
+    HASH,
+
+    REBALANCE,
+
+    RESCALE,
+
+    FORWARD,
+
+    CUSTOM,
+
+    BROADCAST,
+
+    GLOBAL,
+
+    SHUFFLE,
+
+    UNKNOWN;
+
+    /**
+     * Generates a ShipStrategy from a string, or returns {@link #UNKNOWN} if the value cannot match
+     * any ShipStrategy.
+     */
+    @Nonnull
+    public static ShipStrategy of(String value) {
+        for (ShipStrategy shipStrategy : ShipStrategy.values()) {
+            if (shipStrategy.toString().equalsIgnoreCase(value)) {
+                return shipStrategy;
+            }
+        }
+        return UNKNOWN;
+    }
+}
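
A quick usage sketch (not part of the commit) of the new parser: matching against the enum constants is case-insensitive, and any string without a matching constant falls back to UNKNOWN; the input values below are only examples.

import org.apache.flink.autoscaler.topology.ShipStrategy;

public class ShipStrategyOfSketch {
    public static void main(String[] args) {
        System.out.println(ShipStrategy.of("HASH"));     // HASH
        System.out.println(ShipStrategy.of("forward"));  // FORWARD (case-insensitive match)
        System.out.println(ShipStrategy.of("whatever")); // UNKNOWN (no matching constant)
    }
}

This is what JobTopology.fromJsonPlan relies on above: ship_strategy strings taken from the JSON plan map to typed values, and unexpected strings degrade to UNKNOWN instead of failing.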

Diff for: flink-autoscaler/src/main/java/org/apache/flink/autoscaler/topology/VertexInfo.java

+8 −5

@@ -31,10 +31,10 @@ public class VertexInfo {
     private final JobVertexID id;

     // All input vertices and the ship_strategy
-    private final Map<JobVertexID, String> inputs;
+    private final Map<JobVertexID, ShipStrategy> inputs;

     // All output vertices and the ship_strategy
-    private Map<JobVertexID, String> outputs;
+    private Map<JobVertexID, ShipStrategy> outputs;

     private final int parallelism;

@@ -48,7 +48,7 @@ public class VertexInfo {

     public VertexInfo(
             JobVertexID id,
-            Map<JobVertexID, String> inputs,
+            Map<JobVertexID, ShipStrategy> inputs,
             int parallelism,
             int maxParallelism,
             boolean finished,
@@ -65,7 +65,7 @@ public VertexInfo(
     @VisibleForTesting
     public VertexInfo(
             JobVertexID id,
-            Map<JobVertexID, String> inputs,
+            Map<JobVertexID, ShipStrategy> inputs,
             int parallelism,
             int maxParallelism,
             IOMetrics ioMetrics) {
@@ -74,7 +74,10 @@ public VertexInfo(

     @VisibleForTesting
     public VertexInfo(
-            JobVertexID id, Map<JobVertexID, String> inputs, int parallelism, int maxParallelism) {
+            JobVertexID id,
+            Map<JobVertexID, ShipStrategy> inputs,
+            int parallelism,
+            int maxParallelism) {
         this(id, inputs, parallelism, maxParallelism, null);
     }
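
As a small illustration (not part of the commit), vertex infos are now constructed with typed ship strategies instead of raw strings; this uses the @VisibleForTesting convenience constructor shown above:

import org.apache.flink.autoscaler.topology.ShipStrategy;
import org.apache.flink.autoscaler.topology.VertexInfo;
import org.apache.flink.runtime.jobgraph.JobVertexID;

import java.util.Map;

public class VertexInfoSketch {
    public static void main(String[] args) {
        var source = new JobVertexID();
        // A keyed vertex behind the source: its single input edge uses the HASH ship strategy.
        var keyedVertex =
                new VertexInfo(new JobVertexID(), Map.of(source, ShipStrategy.HASH), 4, 720);
        System.out.println(keyedVertex.getInputs().values()); // [HASH]
    }
}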

Diff for: flink-autoscaler/src/main/java/org/apache/flink/autoscaler/tuning/MemoryTuning.java

+12 −9

@@ -26,6 +26,7 @@
 import org.apache.flink.autoscaler.metrics.EvaluatedScalingMetric;
 import org.apache.flink.autoscaler.metrics.ScalingMetric;
 import org.apache.flink.autoscaler.topology.JobTopology;
+import org.apache.flink.autoscaler.topology.ShipStrategy;
 import org.apache.flink.autoscaler.topology.VertexInfo;
 import org.apache.flink.autoscaler.utils.ResourceCheckUtils;
 import org.apache.flink.configuration.Configuration;
@@ -53,6 +54,8 @@
 import static org.apache.flink.autoscaler.metrics.ScalingMetric.HEAP_MEMORY_USED;
 import static org.apache.flink.autoscaler.metrics.ScalingMetric.MANAGED_MEMORY_USED;
 import static org.apache.flink.autoscaler.metrics.ScalingMetric.METASPACE_MEMORY_USED;
+import static org.apache.flink.autoscaler.topology.ShipStrategy.FORWARD;
+import static org.apache.flink.autoscaler.topology.ShipStrategy.RESCALE;

 /** Tunes the TaskManager memory. */
 public class MemoryTuning {
@@ -254,9 +257,9 @@ private static MemorySize adjustNetworkMemory(
         long maxNetworkMemory = 0;
         for (VertexInfo vertexInfo : jobTopology.getVertexInfos().values()) {
             // Add max amount of memory for each input gate
-            for (Map.Entry<JobVertexID, String> inputEntry : vertexInfo.getInputs().entrySet()) {
-                final JobVertexID inputVertexId = inputEntry.getKey();
-                final String shipStrategy = inputEntry.getValue();
+            for (var inputEntry : vertexInfo.getInputs().entrySet()) {
+                var inputVertexId = inputEntry.getKey();
+                var shipStrategy = inputEntry.getValue();
                 maxNetworkMemory +=
                         calculateNetworkSegmentNumber(
                                 updatedParallelisms.get(vertexInfo.getId()),
@@ -268,9 +271,9 @@ private static MemorySize adjustNetworkMemory(
             }
             // Add max amount of memory for each output gate
             // Usually, there is just one output per task
-            for (Map.Entry<JobVertexID, String> outputEntry : vertexInfo.getOutputs().entrySet()) {
-                final JobVertexID outputVertexId = outputEntry.getKey();
-                final String shipStrategy = outputEntry.getValue();
+            for (var outputEntry : vertexInfo.getOutputs().entrySet()) {
+                var outputVertexId = outputEntry.getKey();
+                var shipStrategy = outputEntry.getValue();
                 maxNetworkMemory +=
                         calculateNetworkSegmentNumber(
                                 updatedParallelisms.get(vertexInfo.getId()),
@@ -300,15 +303,15 @@ private static MemorySize adjustNetworkMemory(
     static int calculateNetworkSegmentNumber(
             int currentVertexParallelism,
             int connectedVertexParallelism,
-            String shipStrategy,
+            ShipStrategy shipStrategy,
             int buffersPerChannel,
             int floatingBuffers) {
         // TODO When the parallelism is changed via the rescale api, the FORWARD may be changed to
         // RESCALE. This logic may needs to be updated after FLINK-33123.
         if (currentVertexParallelism == connectedVertexParallelism
-                && "FORWARD".equals(shipStrategy)) {
+                && FORWARD.equals(shipStrategy)) {
             return buffersPerChannel + floatingBuffers;
-        } else if ("FORWARD".equals(shipStrategy) || "RESCALE".equals(shipStrategy)) {
+        } else if (FORWARD.equals(shipStrategy) || RESCALE.equals(shipStrategy)) {
             final int channelCount =
                     (int) Math.ceil(connectedVertexParallelism / (double) currentVertexParallelism);
             return channelCount * buffersPerChannel + floatingBuffers;
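
To make the branch arithmetic above concrete, a standalone sketch (not part of the commit; the buffer counts are illustrative values, not Flink defaults) of the two cases visible in calculateNetworkSegmentNumber:

public class NetworkSegmentSketch {

    enum ShipStrategy { FORWARD, RESCALE }

    // Mirrors only the two branches shown in the hunk above; other ship strategies are
    // handled by code outside this diff.
    static int segments(
            int currentParallelism,
            int connectedParallelism,
            ShipStrategy shipStrategy,
            int buffersPerChannel,
            int floatingBuffers) {
        if (currentParallelism == connectedParallelism && shipStrategy == ShipStrategy.FORWARD) {
            return buffersPerChannel + floatingBuffers;
        }
        int channelCount =
                (int) Math.ceil(connectedParallelism / (double) currentParallelism);
        return channelCount * buffersPerChannel + floatingBuffers;
    }

    public static void main(String[] args) {
        // FORWARD with equal parallelism (4 -> 4): one channel, 2 + 8 = 10 segments.
        System.out.println(segments(4, 4, ShipStrategy.FORWARD, 2, 8));
        // RESCALE fanning out (4 -> 12): ceil(12 / 4) = 3 channels, 3 * 2 + 8 = 14 segments.
        System.out.println(segments(4, 12, ShipStrategy.RESCALE, 2, 8));
    }
}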

Diff for: flink-autoscaler/src/test/java/org/apache/flink/autoscaler/BacklogBasedScalingTest.java

+5 −4

@@ -45,6 +45,7 @@

 import static org.apache.flink.autoscaler.JobAutoScalerImpl.AUTOSCALER_ERROR;
 import static org.apache.flink.autoscaler.TestingAutoscalerUtils.createDefaultJobAutoScalerContext;
+import static org.apache.flink.autoscaler.topology.ShipStrategy.REBALANCE;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -81,7 +82,7 @@ public void setup() {
                         new VertexInfo(source1, Map.of(), 1, 720, new IOMetrics(0, 0, 0)),
                         new VertexInfo(
                                 sink,
-                                Map.of(source1, "REBALANCE"),
+                                Map.of(source1, REBALANCE),
                                 1,
                                 720,
                                 new IOMetrics(0, 0, 0))));
@@ -157,7 +158,7 @@ public void test() throws Exception {
         metricsCollector.setJobTopology(
                 new JobTopology(
                         new VertexInfo(source1, Map.of(), 4, 24),
-                        new VertexInfo(sink, Map.of(source1, "REBALANCE"), 4, 720)));
+                        new VertexInfo(sink, Map.of(source1, REBALANCE), 4, 720)));

         metricsCollector.updateMetrics(
                 source1,
@@ -239,7 +240,7 @@ public void test() throws Exception {
         metricsCollector.setJobTopology(
                 new JobTopology(
                         new VertexInfo(source1, Map.of(), 2, 24),
-                        new VertexInfo(sink, Map.of(source1, "REBALANCE"), 2, 720)));
+                        new VertexInfo(sink, Map.of(source1, REBALANCE), 2, 720)));

         /* Test stability while processing backlog. */

@@ -361,7 +362,7 @@ public void shouldTrackRestartDurationCorrectly() throws Exception {
         metricsCollector.setJobTopology(
                 new JobTopology(
                         new VertexInfo(source1, Map.of(), 4, 720),
-                        new VertexInfo(sink, Map.of(source1, "REBALANCE"), 4, 720)));
+                        new VertexInfo(sink, Map.of(source1, REBALANCE), 4, 720)));

         var expectedEndTime = Instant.ofEpochMilli(10);
         metricsCollector.setJobUpdateTs(expectedEndTime);