Add cancellation support in TransportGetAllocationStatsAction #127371

Open · wants to merge 5 commits into main
6 changes: 6 additions & 0 deletions docs/changelog/127371.yaml
@@ -0,0 +1,6 @@
pr: 127371
summary: Add cancellation support in `TransportGetAllocationStatsAction`
area: Allocation
type: feature
issues:
- 123248
@@ -15,7 +15,6 @@
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.SingleResultDeduplicator;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.SubscribableListener;
@@ -31,10 +30,12 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.util.CancellableSingleObjectCache;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
@@ -43,7 +44,8 @@
import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.ExecutorService;
import java.util.function.BooleanSupplier;

public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAction<
TransportGetAllocationStatsAction.Request,
@@ -62,7 +64,6 @@ public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAc
);

private final AllocationStatsCache allocationStatsCache;
private final SingleResultDeduplicator<Map<String, NodeAllocationStats>> allocationStatsSupplier;
private final DiskThresholdSettings diskThresholdSettings;

@Inject
@@ -85,21 +86,7 @@ public TransportGetAllocationStatsAction(
// very cheaply.
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
final var managementExecutor = threadPool.executor(ThreadPool.Names.MANAGEMENT);
this.allocationStatsCache = new AllocationStatsCache(threadPool, DEFAULT_CACHE_TTL);
this.allocationStatsSupplier = new SingleResultDeduplicator<>(threadPool.getThreadContext(), l -> {
final var cachedStats = allocationStatsCache.get();
if (cachedStats != null) {
l.onResponse(cachedStats);
return;
}

managementExecutor.execute(ActionRunnable.supply(l, () -> {
final var stats = allocationStatsService.stats();
allocationStatsCache.put(stats);
return stats;
}));
});
this.allocationStatsCache = new AllocationStatsCache(threadPool, allocationStatsService, DEFAULT_CACHE_TTL);
this.diskThresholdSettings = new DiskThresholdSettings(clusterService.getSettings(), clusterService.getClusterSettings());
clusterService.getClusterSettings().initializeAndWatch(CACHE_TTL_SETTING, this.allocationStatsCache::setTTL);
}
@@ -118,8 +105,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener) throws Exception {
// NB we are still on a transport thread here - if adding more functionality here make sure to fork to a different pool

assert task instanceof CancellableTask;
final var cancellableTask = (CancellableTask) task;

final SubscribableListener<Map<String, NodeAllocationStats>> allocationStatsStep = request.metrics().contains(Metric.ALLOCATIONS)
? SubscribableListener.newForked(allocationStatsSupplier::execute)
? SubscribableListener.newForked(l -> allocationStatsCache.get(cancellableTask::isCancelled, l))
: SubscribableListener.newSucceeded(Map.of());

allocationStatsStep.andThenApply(
@@ -167,6 +157,11 @@ public EnumSet<Metric> metrics() {
public ActionRequestValidationException validate() {
return null;
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, "", parentTaskId, headers);
}
}

public static class Response extends ActionResponse {
@@ -209,39 +204,60 @@ public DiskThresholdSettings getDiskThresholdSettings() {
}
}

private record CachedAllocationStats(Map<String, NodeAllocationStats> stats, long timestampMillis) {}

private static class AllocationStatsCache {
private static class AllocationStatsCache extends CancellableSingleObjectCache<Long, Long, Map<String, NodeAllocationStats>> {
private volatile long ttlMillis;
private final ThreadPool threadPool;
private final AtomicReference<CachedAllocationStats> cachedStats;
private final ExecutorService executorService;
private final AllocationStatsService allocationStatsService;

AllocationStatsCache(ThreadPool threadPool, TimeValue ttl) {
AllocationStatsCache(ThreadPool threadPool, AllocationStatsService allocationStatsService, TimeValue ttl) {
super(threadPool.getThreadContext());
this.threadPool = threadPool;
this.cachedStats = new AtomicReference<>();
this.executorService = threadPool.executor(ThreadPool.Names.MANAGEMENT);
this.allocationStatsService = allocationStatsService;
setTTL(ttl);
}

void setTTL(TimeValue ttl) {
ttlMillis = ttl.millis();
if (ttlMillis == 0L) {
cachedStats.set(null);
}
clearCacheIfDisabled();
}

Map<String, NodeAllocationStats> get() {
if (ttlMillis == 0L) {
return null;
void get(BooleanSupplier isCancelled, ActionListener<Map<String, NodeAllocationStats>> listener) {
get(threadPool.relativeTimeInMillis(), isCancelled, listener);
}

@Override
protected void refresh(
Long aLong,
Runnable ensureNotCancelled,
BooleanSupplier supersedeIfStale,
ActionListener<Map<String, NodeAllocationStats>> listener
) {
if (supersedeIfStale.getAsBoolean() == false) {
executorService.execute(
ActionRunnable.supply(
// If caching is disabled the item is only cached long enough to prevent duplicate concurrent requests.
ActionListener.runBefore(listener, this::clearCacheIfDisabled),
() -> allocationStatsService.stats(ensureNotCancelled)
)
);
}
}

// We don't set the atomic ref to null here upon expiration since we know it is about to be replaced with a fresh instance.
final var stats = cachedStats.get();
return stats == null || threadPool.relativeTimeInMillis() - stats.timestampMillis > ttlMillis ? null : stats.stats;
@Override
protected Long getKey(Long timestampMillis) {
return timestampMillis;
}

@Override
protected boolean isFresh(Long currentKey, Long newKey) {
return ttlMillis == 0 || newKey - currentKey <= ttlMillis;
}

void put(Map<String, NodeAllocationStats> stats) {
if (ttlMillis > 0L) {
cachedStats.set(new CachedAllocationStats(stats, threadPool.relativeTimeInMillis()));
private void clearCacheIfDisabled() {
if (ttlMillis == 0) {
clearCurrentCachedItem();
}
}
}
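Taken together, the hunks above replace the old SingleResultDeduplicator plus AtomicReference cache with an AllocationStatsCache built on CancellableSingleObjectCache: the cache key is a timestamp, concurrent callers share a single in-flight computation, and a cancelled caller can be failed early while the shared refresh continues for the others. The sketch below is a much simplified, standalone illustration of the TTL-keyed idea using plain JDK types only; it does not deduplicate concurrent refreshes the way CancellableSingleObjectCache does, and names such as TtlStatsCache and Entry are invented for the example.

```java
// Illustrative sketch only, not part of the PR: a timestamp-keyed TTL cache in the spirit
// of the new AllocationStatsCache, written against plain JDK types. A TTL of zero disables
// caching; a cancelled caller fails fast instead of waiting for the computation.
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.BooleanSupplier;

final class TtlStatsCache {
    private record Entry(long timestampMillis, Map<String, Long> stats) {}

    private final long ttlMillis;            // 0 means "never serve a cached value"
    private final ExecutorService executor;
    private volatile Entry cached;           // latest computed value, if any

    TtlStatsCache(long ttlMillis, ExecutorService executor) {
        this.ttlMillis = ttlMillis;
        this.executor = executor;
    }

    CompletableFuture<Map<String, Long>> get(BooleanSupplier isCancelled) {
        final long now = System.currentTimeMillis();
        final Entry entry = cached;
        // Serve the cached value while it is younger than the TTL.
        if (ttlMillis > 0 && entry != null && now - entry.timestampMillis() <= ttlMillis) {
            return CompletableFuture.completedFuture(entry.stats());
        }
        return CompletableFuture.supplyAsync(() -> {
            if (isCancelled.getAsBoolean()) {
                throw new CancellationException("stats computation cancelled");
            }
            Map<String, Long> stats = Map.of("node-1", 42L);   // stand-in for the expensive computation
            if (ttlMillis > 0) {
                cached = new Entry(now, stats);
            }
            return stats;
        }, executor);
    }
}
```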
@@ -47,13 +47,22 @@ public AllocationStatsService(
* Returns a map of node IDs to node allocation stats.
*/
public Map<String, NodeAllocationStats> stats() {
return stats(() -> {});
}

/**
* Returns a map of node IDs to node allocation stats, promising to execute the provided {@link Runnable} during the computation to
* test for cancellation.
*/
public Map<String, NodeAllocationStats> stats(Runnable ensureNotCancelled) {
assert Transports.assertNotTransportThread("too expensive for a transport worker");

var clusterState = clusterService.state();
var nodesStatsAndWeights = nodeAllocationStatsAndWeightsCalculator.nodesAllocationStatsAndWeights(
clusterState.metadata(),
clusterState.getRoutingNodes(),
clusterInfoService.getClusterInfo(),
ensureNotCancelled,
desiredBalanceSupplier.get()
);
return nodesStatsAndWeights.entrySet()
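AllocationStatsService now exposes a stats(Runnable ensureNotCancelled) overload and forwards the check into the per-shard loop in NodeAllocationStatsAndWeightsCalculator below, while the original stats() keeps its behaviour by passing a no-op. A minimal standalone sketch of that pattern follows; the names CancellableComputation and sumShardSizes are invented for the example, and the check is assumed to throw once the owning task has been cancelled.

```java
// Illustrative sketch, not the Elasticsearch implementation: a long computation that
// accepts its cancellation check as a Runnable, as the new stats(Runnable) overload does.
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.atomic.AtomicBoolean;

final class CancellableComputation {
    static long sumShardSizes(List<Long> shardSizes, Runnable ensureNotCancelled) {
        long total = 0;
        for (long size : shardSizes) {
            ensureNotCancelled.run();   // cheap per-iteration check, like the per-shard check in the calculator
            total += size;
        }
        return total;
    }

    public static void main(String[] args) {
        AtomicBoolean cancelled = new AtomicBoolean(false);
        Runnable ensureNotCancelled = () -> {
            if (cancelled.get()) {
                throw new CancellationException("task cancelled");
            }
        };
        System.out.println(sumShardSizes(List.of(1L, 2L, 3L), ensureNotCancelled)); // 6
        System.out.println(sumShardSizes(List.of(4L, 5L), () -> {}));               // callers that cannot be cancelled pass a no-op
    }
}
```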
@@ -54,6 +54,7 @@ public Map<String, NodeAllocationStatsAndWeight> nodesAllocationStatsAndWeights(
Metadata metadata,
RoutingNodes routingNodes,
ClusterInfo clusterInfo,
Runnable ensureNotCancelled,
@Nullable DesiredBalance desiredBalance
) {
if (metadata.hasAnyIndices()) {
@@ -78,6 +79,7 @@ public Map<String, NodeAllocationStatsAndWeight> nodesAllocationStatsAndWeights(
long forecastedDiskUsage = 0;
long currentDiskUsage = 0;
for (ShardRouting shardRouting : node) {
ensureNotCancelled.run();
if (shardRouting.relocating()) {
// Skip the shard if it is moving off this node. The node running recovery will count it.
continue;
@@ -400,6 +400,7 @@ private void updateDesireBalanceMetrics(
routingAllocation.metadata(),
routingAllocation.routingNodes(),
routingAllocation.clusterInfo(),
() -> {},
desiredBalance
);
Map<DiscoveryNode, NodeAllocationStatsAndWeightsCalculator.NodeAllocationStatsAndWeight> filteredNodeAllocationStatsAndWeights =
@@ -100,6 +100,13 @@ protected boolean isFresh(Key currentKey, Key newKey) {
return currentKey.equals(newKey);
}

/**
* Sets the currently cached item reference to {@code null}, which will result in a {@code refresh()} on the next {@code get()} call.
*/
protected void clearCurrentCachedItem() {
this.currentCachedItemRef.set(null);
}

/**
* Start a retrieval for the value associated with the given {@code input}, and pass it to the given {@code listener}.
* <p>
@@ -110,7 +117,8 @@ protected boolean isFresh(Key currentKey, Key newKey) {
*
* @param input The input to compute the desired value, converted to a {@link Key} to determine if the value that's currently
* cached or pending is fresh enough.
* @param isCancelled Returns {@code true} if the listener no longer requires the value being computed.
* @param isCancelled Returns {@code true} if the listener no longer requires the value being computed. The listener is expected to be
* completed as soon as possible when cancellation is detected.
* @param listener The listener to notify when the desired value becomes available.
*/
public final void get(Input input, BooleanSupplier isCancelled, ActionListener<Value> listener) {
@@ -230,11 +238,15 @@ boolean addListener(ActionListener<Value> listener, BooleanSupplier isCancelled)
ActionListener.completeWith(listener, future::actionResult);
} else {
// Refresh is still pending; it's not cancelled because there are still references.
future.addListener(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext));
final var cancellableListener = ActionListener.notifyOnce(
ContextPreservingActionListener.wrapPreservingContext(listener, threadContext)
);
future.addListener(cancellableListener);
final AtomicBoolean released = new AtomicBoolean();
cancellationChecks.add(() -> {
if (released.get() == false && isCancelled.getAsBoolean() && released.compareAndSet(false, true)) {
decRef();
cancellableListener.onFailure(new TaskCancelledException("task cancelled"));
}
});
}
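In CancellableSingleObjectCache, each newly added listener is now wrapped with ActionListener.notifyOnce, and the registered cancellation check fails that listener with TaskCancelledException as soon as the caller is cancelled instead of leaving it waiting for the shared refresh; the wrapper guarantees the listener is completed at most once, whichever event happens first. A rough standalone sketch of the same at-most-once behaviour using CompletableFuture, with invented names:

```java
// Illustrative sketch with plain JDK types, not the Elasticsearch API: a per-caller result
// that is completed at most once, either by the shared refresh finishing or by an earlier
// cancellation check failing it, mirroring the notifyOnce wrapping added in this diff.
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BooleanSupplier;

final class NotifyOnceSketch {
    static CompletableFuture<String> subscribe(CompletableFuture<String> sharedRefresh, BooleanSupplier isCancelled) {
        // CompletableFuture already ignores completion attempts after the first one,
        // which stands in here for the ActionListener.notifyOnce guarantee.
        CompletableFuture<String> perCallerResult = new CompletableFuture<>();
        sharedRefresh.whenComplete((value, error) -> {
            if (error != null) {
                perCallerResult.completeExceptionally(error);
            } else {
                perCallerResult.complete(value);
            }
        });
        // Fails only this caller's result; the shared refresh keeps running for other subscribers.
        AtomicBoolean released = new AtomicBoolean();
        Runnable cancellationCheck = () -> {
            if (released.get() == false && isCancelled.getAsBoolean() && released.compareAndSet(false, true)) {
                perCallerResult.completeExceptionally(new CancellationException("task cancelled"));
            }
        };
        cancellationCheck.run();   // in the real cache this check is queued and run on later cancellation polls
        return perCallerResult;
    }
}
```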