Skip to content

Commit f4b6086

Browse files
authored
Fix race condition when resolving new location for multiple shards at once (#128062)
1 parent 27a3eb0 commit f4b6086

File tree

2 files changed

+48
-22
lines changed

2 files changed

+48
-22
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,7 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu
192192
var pendingRetries = new HashSet<ShardId>();
193193
for (ShardId shardId : pendingShardIds) {
194194
if (targetShards.getShard(shardId).remainingNodes.isEmpty()) {
195-
var failure = shardFailures.get(shardId);
196-
if (failure != null && failure.fatal == false && failure.failure instanceof NoShardAvailableActionException) {
195+
if (isRetryableFailure(shardFailures.get(shardId))) {
197196
pendingRetries.add(shardId);
198197
}
199198
}
@@ -204,7 +203,8 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu
204203
}
205204
}
206205
for (ShardId shardId : pendingShardIds) {
207-
if (targetShards.getShard(shardId).remainingNodes.isEmpty()) {
206+
if (targetShards.getShard(shardId).remainingNodes.isEmpty()
207+
&& (isRetryableFailure(shardFailures.get(shardId)) == false || pendingRetries.contains(shardId))) {
208208
shardFailures.compute(
209209
shardId,
210210
(k, v) -> new ShardFailure(
@@ -378,6 +378,10 @@ record NodeRequest(DiscoveryNode node, List<ShardId> shardIds, Map<Index, AliasF
378378

379379
private record ShardFailure(boolean fatal, Exception failure) {}
380380

381+
private static boolean isRetryableFailure(ShardFailure failure) {
382+
return failure != null && failure.fatal == false && failure.failure instanceof NoShardAvailableActionException;
383+
}
384+
381385
/**
382386
* Selects the next nodes to send requests to. Limits to at most one outstanding request per node.
383387
* If there is already a request in-flight to a node, another request will not be sent to the same node

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@
5555
import java.util.concurrent.atomic.AtomicBoolean;
5656
import java.util.concurrent.atomic.AtomicInteger;
5757
import java.util.function.Function;
58-
import java.util.stream.Collectors;
5958

59+
import static java.util.stream.Collectors.toMap;
6060
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_COLD_NODE_ROLE;
6161
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_FROZEN_NODE_ROLE;
6262
import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_HOT_NODE_ROLE;
@@ -450,6 +450,32 @@ public void testRetryMovedShard() {
450450
assertThat(attempt.get(), equalTo(3));
451451
}
452452

453+
public void testRetryMultipleMovedShards() {
454+
var attempt = new AtomicInteger(0);
455+
var response = safeGet(
456+
sendRequests(
457+
randomBoolean(),
458+
-1,
459+
List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node3)),
460+
shardIds -> shardIds.stream().collect(toMap(Function.identity(), shardId -> List.of(randomFrom(node1, node2, node3)))),
461+
(node, shardIds, aliasFilters, listener) -> runWithDelay(
462+
() -> listener.onResponse(
463+
attempt.incrementAndGet() <= 6
464+
? new DataNodeComputeResponse(
465+
DriverCompletionInfo.EMPTY,
466+
shardIds.stream().collect(toMap(Function.identity(), ShardNotFoundException::new))
467+
)
468+
: new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of())
469+
)
470+
)
471+
)
472+
);
473+
assertThat(response.totalShards, equalTo(3));
474+
assertThat(response.successfulShards, equalTo(3));
475+
assertThat(response.skippedShards, equalTo(0));
476+
assertThat(response.failedShards, equalTo(0));
477+
}
478+
453479
public void testDoesNotRetryMovedShardIndefinitely() {
454480
var attempt = new AtomicInteger(0);
455481
var response = safeGet(sendRequests(true, -1, List.of(targetShard(shard1, node1)), shardIds -> {
@@ -517,28 +543,28 @@ public void testRetryUnassignedShardWithoutPartialResults() {
517543

518544
);
519545
expectThrows(NoShardAvailableActionException.class, containsString("no such shard"), future::actionGet);
546+
assertThat(attempt.get(), equalTo(1));
520547
}
521548

522549
public void testRetryUnassignedShardWithPartialResults() {
523-
var response = safeGet(
524-
sendRequests(
525-
true,
526-
-1,
527-
List.of(targetShard(shard1, node1), targetShard(shard2, node2)),
528-
shardIds -> Map.of(shard1, List.of()),
529-
(node, shardIds, aliasFilters, listener) -> runWithDelay(
530-
() -> listener.onResponse(
531-
Objects.equals(shardIds, List.of(shard2))
532-
? new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of())
533-
: new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of(shard1, new ShardNotFoundException(shard1)))
534-
)
550+
var attempt = new AtomicInteger(0);
551+
var response = safeGet(sendRequests(true, -1, List.of(targetShard(shard1, node1), targetShard(shard2, node2)), shardIds -> {
552+
attempt.incrementAndGet();
553+
return Map.of(shard1, List.of());
554+
},
555+
(node, shardIds, aliasFilters, listener) -> runWithDelay(
556+
() -> listener.onResponse(
557+
Objects.equals(shardIds, List.of(shard2))
558+
? new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of())
559+
: new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of(shard1, new ShardNotFoundException(shard1)))
535560
)
536561
)
537-
);
562+
));
538563
assertThat(response.totalShards, equalTo(2));
539564
assertThat(response.successfulShards, equalTo(1));
540565
assertThat(response.skippedShards, equalTo(0));
541566
assertThat(response.failedShards, equalTo(1));
567+
assertThat(attempt.get(), equalTo(1));
542568
}
543569

544570
static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... nodes) {
@@ -621,11 +647,7 @@ PlainActionFuture<ComputeResponse> sendRequests(
621647
void searchShards(Set<String> concreteIndices, ActionListener<TargetShards> listener) {
622648
runWithDelay(
623649
() -> listener.onResponse(
624-
new TargetShards(
625-
shards.stream().collect(Collectors.toMap(TargetShard::shardId, Function.identity())),
626-
shards.size(),
627-
0
628-
)
650+
new TargetShards(shards.stream().collect(toMap(TargetShard::shardId, Function.identity())), shards.size(), 0)
629651
)
630652
);
631653
}

0 commit comments

Comments
 (0)