diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java index 81f4eb3f2..5beda1765 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractContainerOptions.java @@ -20,6 +20,7 @@ import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter; import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter; import java.time.Duration; +import java.util.function.Supplier; import org.springframework.core.task.TaskExecutor; import org.springframework.lang.Nullable; import org.springframework.retry.backoff.BackOffPolicy; @@ -47,12 +48,16 @@ public abstract class AbstractContainerOptions, private final Duration maxDelayBetweenPolls; + private final Duration standbyLimitPollingInterval; + private final Duration listenerShutdownTimeout; private final Duration acknowledgementShutdownTimeout; private final BackPressureMode backPressureMode; + private final Supplier backPressureHandlerSupplier; + private final ListenerMode listenerMode; private final MessagingMessageConverter messageConverter; @@ -80,10 +85,12 @@ protected AbstractContainerOptions(Builder builder) { this.autoStartup = builder.autoStartup; this.pollTimeout = builder.pollTimeout; this.pollBackOffPolicy = builder.pollBackOffPolicy; + this.standbyLimitPollingInterval = builder.standbyLimitPollingInterval; this.maxDelayBetweenPolls = builder.maxDelayBetweenPolls; this.listenerShutdownTimeout = builder.listenerShutdownTimeout; this.acknowledgementShutdownTimeout = builder.acknowledgementShutdownTimeout; this.backPressureMode = builder.backPressureMode; + this.backPressureHandlerSupplier = builder.backPressureHandlerSupplier; this.listenerMode = builder.listenerMode; this.messageConverter = builder.messageConverter; this.acknowledgementMode = builder.acknowledgementMode; @@ -122,6 +129,11 @@ public BackOffPolicy getPollBackOffPolicy() { return this.pollBackOffPolicy; } + @Override + public Duration getStandbyLimitPollingInterval() { + return this.standbyLimitPollingInterval; + } + @Override public Duration getMaxDelayBetweenPolls() { return this.maxDelayBetweenPolls; @@ -154,6 +166,11 @@ public BackPressureMode getBackPressureMode() { return this.backPressureMode; } + @Override + public Supplier getBackPressureHandlerSupplier() { + return this.backPressureHandlerSupplier; + } + @Override public ListenerMode getListenerMode() { return this.listenerMode; @@ -206,6 +223,8 @@ protected abstract static class Builder, private static final BackOffPolicy DEFAULT_POLL_BACK_OFF_POLICY = buildDefaultBackOffPolicy(); + private static final Duration DEFAULT_STANDBY_LIMIT_POLLING_INTERVAL = Duration.ofMillis(100); + private static final Duration DEFAULT_SEMAPHORE_TIMEOUT = Duration.ofSeconds(10); private static final Duration DEFAULT_LISTENER_SHUTDOWN_TIMEOUT = Duration.ofSeconds(20); @@ -214,6 +233,8 @@ protected abstract static class Builder, private static final BackPressureMode DEFAULT_THROUGHPUT_CONFIGURATION = BackPressureMode.AUTO; + private static final Supplier DEFAULT_BACKPRESSURE_LIMITER = null; + private static final ListenerMode DEFAULT_MESSAGE_DELIVERY_STRATEGY = ListenerMode.SINGLE_MESSAGE; private static final MessagingMessageConverter DEFAULT_MESSAGE_CONVERTER = new SqsMessagingMessageConverter(); @@ -230,10 +251,14 @@ protected abstract static class Builder, private BackOffPolicy pollBackOffPolicy = DEFAULT_POLL_BACK_OFF_POLICY; + private Duration standbyLimitPollingInterval = DEFAULT_STANDBY_LIMIT_POLLING_INTERVAL; + private Duration maxDelayBetweenPolls = DEFAULT_SEMAPHORE_TIMEOUT; private BackPressureMode backPressureMode = DEFAULT_THROUGHPUT_CONFIGURATION; + private Supplier backPressureHandlerSupplier = DEFAULT_BACKPRESSURE_LIMITER; + private Duration listenerShutdownTimeout = DEFAULT_LISTENER_SHUTDOWN_TIMEOUT; private Duration acknowledgementShutdownTimeout = DEFAULT_ACKNOWLEDGEMENT_SHUTDOWN_TIMEOUT; @@ -272,6 +297,7 @@ protected Builder(AbstractContainerOptions options) { this.listenerShutdownTimeout = options.listenerShutdownTimeout; this.acknowledgementShutdownTimeout = options.acknowledgementShutdownTimeout; this.backPressureMode = options.backPressureMode; + this.backPressureHandlerSupplier = options.backPressureHandlerSupplier; this.listenerMode = options.listenerMode; this.messageConverter = options.messageConverter; this.acknowledgementMode = options.acknowledgementMode; @@ -315,6 +341,13 @@ public B pollBackOffPolicy(BackOffPolicy pollBackOffPolicy) { return self(); } + @Override + public B standbyLimitPollingInterval(Duration standbyLimitPollingInterval) { + Assert.notNull(standbyLimitPollingInterval, "standbyLimitPollingInterval cannot be null"); + this.standbyLimitPollingInterval = standbyLimitPollingInterval; + return self(); + } + @Override public B maxDelayBetweenPolls(Duration maxDelayBetweenPolls) { Assert.notNull(maxDelayBetweenPolls, "semaphoreAcquireTimeout cannot be null"); @@ -364,6 +397,12 @@ public B backPressureMode(BackPressureMode backPressureMode) { return self(); } + @Override + public B backPressureHandlerSupplier(Supplier backPressureHandlerSupplier) { + this.backPressureHandlerSupplier = backPressureHandlerSupplier; + return self(); + } + @Override public B acknowledgementInterval(Duration acknowledgementInterval) { Assert.notNull(acknowledgementInterval, "acknowledgementInterval cannot be null"); diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java index 6808f647a..6d7a2637a 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java @@ -225,6 +225,10 @@ private TaskExecutor validateCustomExecutor(TaskExecutor taskExecutor) { } protected BackPressureHandler createBackPressureHandler() { + O containerOptions = getContainerOptions(); + if (containerOptions.getBackPressureHandlerSupplier() != null) { + return containerOptions.getBackPressureHandlerSupplier().get(); + } return SemaphoreBackPressureHandler.builder().batchSize(getContainerOptions().getMaxMessagesPerPoll()) .totalPermits(getContainerOptions().getMaxConcurrentMessages()) .acquireTimeout(getContainerOptions().getMaxDelayBetweenPolls()) diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java index 1d76d6589..55e5a25f0 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BackPressureHandler.java @@ -29,7 +29,7 @@ public interface BackPressureHandler { /** - * Request a number of permits. Each obtained permit allows the + * Requests a number of permits. Each obtained permit allows the * {@link io.awspring.cloud.sqs.listener.source.MessageSource} to retrieve one message. * @param amount the amount of permits to request. * @return the amount of permits obtained. @@ -37,12 +37,40 @@ public interface BackPressureHandler { */ int request(int amount) throws InterruptedException; + /** + * Releases the specified amount of permits for processed messages. Each message that has been processed should + * release one permit, whether processing was successful or not. + *

+ * This method can is called in the following use cases: + *

    + *
  • {@link ReleaseReason#LIMITED}: permits were not used because another BackPressureHandler has a lower permits + * limit and the difference in permits needs to be returned.
  • + *
  • {@link ReleaseReason#NONE_FETCHED}: none of the permits were actually used because no messages were retrieved + * from SQS. Permits need to be returned.
  • + *
  • {@link ReleaseReason#PARTIAL_FETCH}: some of the permits were used (some messages were retrieved from SQS). + * The unused ones need to be returned. The amount to be returned might be {@literal 0}, in which case it means all + * the permits will be used as the same number of messages were fetched from SQS.
  • + *
  • {@link ReleaseReason#PROCESSED}: a message processing finished, successfully or not.
  • + *
+ * @param amount the amount of permits to release. + * @param reason the reason why the permits were released. + */ + default void release(int amount, ReleaseReason reason) { + release(amount); + } + /** * Release the specified amount of permits. Each message that has been processed should release one permit, whether * processing was successful or not. * @param amount the amount of permits to release. + * + * @deprecated This method is deprecated and will not be called by the Spring Cloud AWS SQS listener anymore. + * Implement {@link #release(int, ReleaseReason)} instead. */ - void release(int amount); + @Deprecated + default void release(int amount) { + release(amount, ReleaseReason.PROCESSED); + } /** * Attempts to acquire all permits up to the specified timeout. If successful, means all permits were returned and @@ -52,4 +80,24 @@ public interface BackPressureHandler { */ boolean drain(Duration timeout); + enum ReleaseReason { + /** + * Permits were not used because another BackPressureHandler has a lower permits limit and the difference need + * to be aligned across all handlers. + */ + LIMITED, + /** + * No messages were retrieved from SQS, so all permits need to be returned. + */ + NONE_FETCHED, + /** + * Some messages were fetched from SQS. Unused permits need to be returned. + */ + PARTIAL_FETCH, + /** + * The processing of one or more messages finished, successfully or not. + */ + PROCESSED; + } + } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java index 51e12e0a0..c5ccf0ba4 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/BatchAwareBackPressureHandler.java @@ -35,13 +35,24 @@ public interface BatchAwareBackPressureHandler extends BackPressureHandler { * Release a batch of permits. This has the semantics of letting the {@link BackPressureHandler} know that all * permits from a batch are being released, in opposition to {@link #release(int)} in which any number of permits * can be specified. + * + * @deprecated This method is deprecated and will not be called by the Spring Cloud AWS SQS listener anymore. + * Implement {@link BackPressureHandler#release(int, ReleaseReason)} instead. */ - void releaseBatch(); + @Deprecated + default void releaseBatch() { + release(getBatchSize(), ReleaseReason.NONE_FETCHED); + } /** * Return the configured batch size for this handler. * @return the batch size. + * + * @deprecated This method is deprecated and will not be used by the Spring Cloud AWS SQS listener anymore. */ - int getBatchSize(); + @Deprecated + default int getBatchSize() { + return 0; + } } diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandler.java new file mode 100644 index 000000000..930f7dc6e --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/CompositeBackPressureHandler.java @@ -0,0 +1,149 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class CompositeBackPressureHandler implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(CompositeBackPressureHandler.class); + + private final List backPressureHandlers; + + private final int batchSize; + + private final ReentrantLock noPermitsReturnedWaitLock = new ReentrantLock(); + + private final Condition permitsReleasedCondition = noPermitsReturnedWaitLock.newCondition(); + + private final Duration noPermitsReturnedWaitTimeout; + + private String id; + + public CompositeBackPressureHandler(List backPressureHandlers, int batchSize, + Duration waitTimeout) { + this.backPressureHandlers = backPressureHandlers; + this.batchSize = batchSize; + this.noPermitsReturnedWaitTimeout = waitTimeout; + } + + @Override + public void setId(String id) { + this.id = id; + backPressureHandlers.stream().filter(IdentifiableContainerComponent.class::isInstance) + .map(IdentifiableContainerComponent.class::cast) + .forEach(bph -> bph.setId(bph.getClass().getSimpleName() + "-" + id)); + } + + @Override + public String getId() { + return id; + } + + @Override + public int requestBatch() throws InterruptedException { + return request(batchSize); + } + + @Override + public int request(int amount) throws InterruptedException { + logger.debug("[{}] Requesting {} permits", this.id, amount); + int obtained = amount; + int[] obtainedPerBph = new int[backPressureHandlers.size()]; + for (int i = 0; i < backPressureHandlers.size() && obtained > 0; i++) { + obtainedPerBph[i] = backPressureHandlers.get(i).request(obtained); + obtained = Math.min(obtained, obtainedPerBph[i]); + } + for (int i = 0; i < backPressureHandlers.size(); i++) { + int obtainedForBph = obtainedPerBph[i]; + if (obtainedForBph > obtained) { + backPressureHandlers.get(i).release(obtainedForBph - obtained, ReleaseReason.LIMITED); + } + } + if (obtained == 0) { + waitForPermitsToBeReleased(); + } + logger.debug("[{}] Obtained {} permits ({} requested)", this.id, obtained, amount); + return obtained; + } + + @Override + public void release(int amount, ReleaseReason reason) { + logger.debug("[{}] Releasing {} permits ({})", this.id, amount, reason); + for (BackPressureHandler handler : backPressureHandlers) { + handler.release(amount, reason); + } + if (amount > 0) { + signalPermitsWereReleased(); + } + } + + /** + * Waits for permits to be released up to {@link #noPermitsReturnedWaitTimeout}. If no permits were released within + * the configured {@link #noPermitsReturnedWaitTimeout}, returns immediately. This allows {@link #request(int)} to + * return {@code 0} permits and will trigger another round of back-pressure handling. + * + * @throws InterruptedException if the Thread is interrupted while waiting for permits. + */ + @SuppressWarnings({ "java:S899" // we are not interested in the await return value here + }) + private void waitForPermitsToBeReleased() throws InterruptedException { + noPermitsReturnedWaitLock.lock(); + try { + logger.trace("[{}] No permits were obtained, waiting for a release up to {}", this.id, + noPermitsReturnedWaitTimeout); + permitsReleasedCondition.await(noPermitsReturnedWaitTimeout.toMillis(), TimeUnit.MILLISECONDS); + } + finally { + noPermitsReturnedWaitLock.unlock(); + } + } + + private void signalPermitsWereReleased() { + noPermitsReturnedWaitLock.lock(); + try { + permitsReleasedCondition.signal(); + } + finally { + noPermitsReturnedWaitLock.unlock(); + } + } + + @Override + public boolean drain(Duration timeout) { + logger.debug("[{}] Draining back-pressure handlers initiated", this.id); + boolean result = true; + Instant start = Instant.now(); + for (BackPressureHandler handler : backPressureHandlers) { + Duration remainingTimeout = maxDuration(timeout.minus(Duration.between(start, Instant.now())), + Duration.ZERO); + result &= handler.drain(remainingTimeout); + } + logger.debug("[{}] Draining back-pressure handlers completed", this.id); + return result; + } + + private static Duration maxDuration(Duration first, Duration second) { + return first.compareTo(second) > 0 ? first : second; + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandler.java new file mode 100644 index 000000000..e389ba7c3 --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ConcurrencyLimiterBlockingBackPressureHandler.java @@ -0,0 +1,163 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + +/** + * {@link BackPressureHandler} implementation that uses a {@link Semaphore} for handling backpressure. + * + * @author Tomaz Fernandes + * @see io.awspring.cloud.sqs.listener.source.PollingMessageSource + * @since 3.0 + */ +public class ConcurrencyLimiterBlockingBackPressureHandler + implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(ConcurrencyLimiterBlockingBackPressureHandler.class); + + private final Semaphore semaphore; + + private final int batchSize; + + private final int totalPermits; + + private final Duration acquireTimeout; + + private final boolean alwaysPollMasMessages; + + private String id = getClass().getSimpleName(); + + private ConcurrencyLimiterBlockingBackPressureHandler(Builder builder) { + this.batchSize = builder.batchSize; + this.totalPermits = builder.totalPermits; + this.acquireTimeout = builder.acquireTimeout; + this.alwaysPollMasMessages = BackPressureMode.ALWAYS_POLL_MAX_MESSAGES.equals(builder.backPressureMode); + this.semaphore = new Semaphore(totalPermits); + logger.debug( + "ConcurrencyLimiterBlockingBackPressureHandler created with configuration " + + "totalPermits: {}, batchSize: {}, acquireTimeout: {}, an alwaysPollMasMessages: {}", + this.totalPermits, this.batchSize, this.acquireTimeout, this.alwaysPollMasMessages); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void setId(String id) { + this.id = id; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public int requestBatch() throws InterruptedException { + return request(this.batchSize); + } + + @Override + public int request(int amount) throws InterruptedException { + int acquiredPermits = tryAcquire(amount, this.acquireTimeout); + if (alwaysPollMasMessages || acquiredPermits > 0) { + return acquiredPermits; + } + int availablePermits = Math.min(this.semaphore.availablePermits(), amount); + if (availablePermits > 0) { + return tryAcquire(availablePermits, this.acquireTimeout); + } + return 0; + } + + private int tryAcquire(int amount, Duration duration) throws InterruptedException { + if (this.semaphore.tryAcquire(amount, duration.toMillis(), TimeUnit.MILLISECONDS)) { + logger.debug("[{}] Acquired {} permits ({} / {} available)", this.id, amount, + this.semaphore.availablePermits(), this.totalPermits); + return amount; + } + return 0; + } + + @Override + public void release(int amount, ReleaseReason reason) { + this.semaphore.release(amount); + logger.debug("[{}] Released {} permits ({}) ({} / {} available)", this.id, amount, reason, + this.semaphore.availablePermits(), this.totalPermits); + } + + @Override + public boolean drain(Duration timeout) { + logger.debug("[{}] Waiting for up to {} for approx. {} permits to be released", this.id, timeout, + this.totalPermits - this.semaphore.availablePermits()); + try { + return tryAcquire(this.totalPermits, timeout) > 0; + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.debug("[{}] Draining interrupted", this.id); + return false; + } + } + + public static class Builder { + + private int batchSize; + + private int totalPermits; + + private Duration acquireTimeout; + + private BackPressureMode backPressureMode; + + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + public Builder totalPermits(int totalPermits) { + this.totalPermits = totalPermits; + return this; + } + + public Builder acquireTimeout(Duration acquireTimeout) { + this.acquireTimeout = acquireTimeout; + return this; + } + + public Builder throughputConfiguration(BackPressureMode backPressureConfiguration) { + this.backPressureMode = backPressureConfiguration; + return this; + } + + public ConcurrencyLimiterBlockingBackPressureHandler build() { + Assert.noNullElements( + Arrays.asList(this.batchSize, this.totalPermits, this.acquireTimeout, this.backPressureMode), + "Missing configuration"); + Assert.isTrue(this.batchSize > 0, "The batch size must be greater than 0"); + Assert.isTrue(this.totalPermits >= this.batchSize, "Total permits must be greater than the batch size"); + return new ConcurrencyLimiterBlockingBackPressureHandler(this); + } + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java index ad7313cf6..95921f33e 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptions.java @@ -20,6 +20,7 @@ import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter; import java.time.Duration; import java.util.Collection; +import java.util.function.Supplier; import org.springframework.core.task.TaskExecutor; import org.springframework.lang.Nullable; import org.springframework.retry.backoff.BackOffPolicy; @@ -59,7 +60,13 @@ public interface ContainerOptions, B extends Co boolean isAutoStartup(); /** - * Set the maximum time the polling thread should wait for a full batch of permits to be available before trying to + * {@return the amount of time to wait before checking again for the current limit when the queue processing is on + * standby} Default is 100 milliseconds. + */ + Duration getStandbyLimitPollingInterval(); + + /** + * Sets the maximum time the polling thread should wait for a full batch of permits to be available before trying to * acquire a partial batch if so configured. A poll is only actually executed if at least one permit is available. * Default is 10 seconds. * @@ -127,6 +134,12 @@ default BackOffPolicy getPollBackOffPolicy() { */ BackPressureMode getBackPressureMode(); + /** + * Return the a {@link Supplier} to create a {@link BackPressureHandler} for this container. + * @return the BackPressureHandler supplier. + */ + Supplier getBackPressureHandlerSupplier(); + /** * Return the {@link ListenerMode} mode for this container. * @return the listener mode. diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java index 9d03b7964..161687b6c 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ContainerOptionsBuilder.java @@ -19,6 +19,7 @@ import io.awspring.cloud.sqs.listener.acknowledgement.handler.AcknowledgementMode; import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter; import java.time.Duration; +import java.util.function.Supplier; import org.springframework.core.task.TaskExecutor; import org.springframework.retry.backoff.BackOffPolicy; @@ -56,6 +57,15 @@ public interface ContainerOptionsBuilder */ B autoStartup(boolean autoStartup); + /** + * Sets the amount of time to wait before checking again for the current limit when the queue processing is on + * standby. + * + * @param standbyLimitPollingInterval the limit polling interval when the queue processing is on standby. + * @return this instance. + */ + B standbyLimitPollingInterval(Duration standbyLimitPollingInterval); + /** * Set the maximum time the polling thread should wait for a full batch of permits to be available before trying to * acquire a partial batch if so configured. A poll is only actually executed if at least one permit is available. @@ -145,6 +155,68 @@ default B pollBackOffPolicy(BackOffPolicy pollBackOffPolicy) { */ B backPressureMode(BackPressureMode backPressureMode); + /** + * Sets the {@link Supplier} of {@link BackPressureHandler} for this container. Default is {@code null} which + * results in a default {@link SemaphoreBackPressureHandler} to be instantiated. In case a supplier is provided, the + * {@link BackPressureHandler} will be instantiated by the supplier. + *

+ * NOTE: it is important for the supplier to always return a new instance as otherwise it might + * result in a BackPressureHandler internal resources (counters, semaphores, ...) to be shared by multiple + * containers which is very likely not the desired behavior. + *

+ * Spring Cloud AWS provides the following {@link BackPressureHandler} implementations: + *

    + *
  • {@link ConcurrencyLimiterBlockingBackPressureHandler}: Limits the maximum number of messages that can be + * processed concurrently by the application.
  • + *
  • {@link ThroughputBackPressureHandler}: Adapts the throughput dynamically between high and low modes in order + * to reduce SQS pull costs when few messages are coming in.
  • + *
  • {@link CompositeBackPressureHandler}: Allows combining multiple {@link BackPressureHandler} together and + * ensures they cooperate.
  • + *
+ *

+ * Below are a few examples of how common use cases can be achieved. Keep in mind you can always create your own + * {@link BackPressureHandler} implementation and if needed combine it with the provided ones thanks to the + * {@link CompositeBackPressureHandler}. + * + *

A BackPressureHandler limiting the max concurrency with high throughput

+ * + *
{@code
+	 * containerOptionsBuilder.backPressureHandlerSupplier(() -> {
+	 * 		return ConcurrencyLimiterBlockingBackPressureHandler.builder()
+	 * 			.batchSize(batchSize)
+	 * 			.totalPermits(maxConcurrentMessages)
+	 * 			.acquireTimeout(acquireTimeout)
+	 * 			.throughputConfiguration(BackPressureMode.FIXED_HIGH_THROUGHPUT)
+	 * 			.build()
+	 * }}
+ * + *

A BackPressureHandler limiting the max concurrency with dynamic throughput

+ * + *
{@code
+	 * containerOptionsBuilder.backPressureHandlerSupplier(() -> {
+	 * 		var concurrencyLimiterBlockingBackPressureHandler = ConcurrencyLimiterBlockingBackPressureHandler.builder()
+	 * 			.batchSize(batchSize)
+	 * 			.totalPermits(maxConcurrentMessages)
+	 * 			.acquireTimeout(acquireTimeout)
+	 * 			.throughputConfiguration(BackPressureMode.AUTO)
+	 * 			.build()
+	 * 		var throughputBackPressureHandler = ThroughputBackPressureHandler.builder()
+	 * 			.batchSize(batchSize)
+	 * 			.build();
+	 * 		return new CompositeBackPressureHandler(List.of(
+	 * 				concurrencyLimiterBlockingBackPressureHandler,
+	 * 				throughputBackPressureHandler
+	 * 			),
+	 * 			batchSize,
+	 * 			standbyLimitPollingInterval
+	 * 		);
+	 * }}
+ * + * @param backPressureHandlerSupplier the BackPressureHandler supplier. + * @return this instance. + */ + B backPressureHandlerSupplier(Supplier backPressureHandlerSupplier); + /** * Set the maximum interval between acknowledgements for batch acknowledgements. The default depends on the specific * {@link ContainerComponentFactory} implementation. diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/SemaphoreBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/SemaphoreBackPressureHandler.java index 310b64519..31617c405 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/SemaphoreBackPressureHandler.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/SemaphoreBackPressureHandler.java @@ -180,6 +180,16 @@ public void release(int amount) { this.semaphore.availablePermits()); } + @Override + public void release(int amount, ReleaseReason reason) { + if (amount == this.batchSize && reason == ReleaseReason.NONE_FETCHED) { + releaseBatch(); + } + else { + release(amount); + } + } + private int getPermitsToRelease(int amount) { return this.hasAcquiredFullPermits.compareAndSet(true, false) // The first process that gets here should release all permits except for inflight messages diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandler.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandler.java new file mode 100644 index 000000000..3ef1410d9 --- /dev/null +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/ThroughputBackPressureHandler.java @@ -0,0 +1,154 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.listener; + +import io.awspring.cloud.sqs.listener.source.PollingMessageSource; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.util.Assert; + +/** + * {@link BackPressureHandler} implementation that uses a switches between high and low throughput modes. + *

+ * The initial throughput mode is low, which means, only one batch at a time can be requested. If some messages are + * fetched, then the throughput mode is switched to high, which means, the multiple batches can be requested (i.e. there + * is no need to wait for the previous batch's processing to complete before requesting a new one). If no messages are + * returned fetched by a poll, the throughput mode is switched back to low. + *

+ * This {@link BackPressureHandler} is designed to be used in combination with another {@link BackPressureHandler} like + * the {@link ConcurrencyLimiterBlockingBackPressureHandler} that will handle the maximum concurrency level within the + * application. + * + * @author Tomaz Fernandes + * @see PollingMessageSource + * @since 3.0 + */ +public class ThroughputBackPressureHandler implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + + private static final Logger logger = LoggerFactory.getLogger(ThroughputBackPressureHandler.class); + + private final int batchSize; + + private final AtomicReference currentThroughputMode = new AtomicReference<>( + CurrentThroughputMode.LOW); + + private final AtomicInteger inFlightRequests = new AtomicInteger(0); + + private final AtomicBoolean drained = new AtomicBoolean(false); + + private String id = getClass().getSimpleName(); + + private ThroughputBackPressureHandler(Builder builder) { + this.batchSize = builder.batchSize; + logger.debug("ThroughputBackPressureHandler created with batchSize {}", this.batchSize); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void setId(String id) { + this.id = id; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public int requestBatch() throws InterruptedException { + return request(this.batchSize); + } + + @Override + public int request(int amount) throws InterruptedException { + if (drained.get()) { + return 0; + } + int permits; + int inFlight = inFlightRequests.get(); + if (CurrentThroughputMode.LOW == this.currentThroughputMode.get()) { + permits = Math.max(0, Math.min(amount, this.batchSize - inFlight)); + logger.debug("[{}] Acquired {} permits (low throughput mode), in flight: {}", this.id, amount, inFlight); + } + else { + permits = amount; + logger.debug("[{}] Acquired {} permits (high throughput mode), in flight: {}", this.id, amount, inFlight); + } + inFlightRequests.addAndGet(permits); + return permits; + } + + @Override + public void release(int amount, ReleaseReason reason) { + if (drained.get()) { + return; + } + logger.debug("[{}] Releasing {} permits ({})", this.id, amount, reason); + inFlightRequests.addAndGet(-amount); + switch (reason) { + case NONE_FETCHED -> updateThroughputMode(CurrentThroughputMode.HIGH, CurrentThroughputMode.LOW); + case PARTIAL_FETCH -> updateThroughputMode(CurrentThroughputMode.LOW, CurrentThroughputMode.HIGH); + case LIMITED, PROCESSED -> { + // No need to switch throughput mode + } + } + } + + private void updateThroughputMode(CurrentThroughputMode currentTarget, CurrentThroughputMode newTarget) { + if (this.currentThroughputMode.compareAndSet(currentTarget, newTarget)) { + logger.debug("[{}] throughput mode updated to {}", this.id, newTarget); + } + } + + @Override + public boolean drain(Duration timeout) { + logger.debug("[{}] Draining", this.id); + drained.set(true); + return true; + } + + private enum CurrentThroughputMode { + + HIGH, + + LOW; + + } + + public static class Builder { + + private int batchSize; + + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + public ThroughputBackPressureHandler build() { + Assert.noNullElements(List.of(this.batchSize), "Missing configuration"); + Assert.isTrue(this.batchSize > 0, "batch size must be greater than 0"); + return new ThroughputBackPressureHandler(this); + } + } +} diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java index e71dc4319..9041cd9c8 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSource.java @@ -17,6 +17,7 @@ import io.awspring.cloud.sqs.ConfigUtils; import io.awspring.cloud.sqs.listener.BackPressureHandler; +import io.awspring.cloud.sqs.listener.BackPressureHandler.ReleaseReason; import io.awspring.cloud.sqs.listener.BatchAwareBackPressureHandler; import io.awspring.cloud.sqs.listener.ContainerOptions; import io.awspring.cloud.sqs.listener.IdentifiableContainerComponent; @@ -214,7 +215,7 @@ private void pollAndEmitMessages() { if (!isRunning()) { logger.debug("MessageSource was stopped after permits where acquired. Returning {} permits", acquiredPermits); - this.backPressureHandler.release(acquiredPermits); + this.backPressureHandler.release(acquiredPermits, ReleaseReason.NONE_FETCHED); continue; } // @formatter:off @@ -252,15 +253,12 @@ private void handlePollBackOff() { protected abstract CompletableFuture> doPollForMessages(int messagesToRequest); public Collection> releaseUnusedPermits(int permits, Collection> msgs) { - if (msgs.isEmpty() && permits == this.backPressureHandler.getBatchSize()) { - this.backPressureHandler.releaseBatch(); - logger.trace("Released batch of unused permits for queue {}", this.pollingEndpointName); - } - else { - int permitsToRelease = permits - msgs.size(); - this.backPressureHandler.release(permitsToRelease); - logger.trace("Released {} unused permits for queue {}", permitsToRelease, this.pollingEndpointName); - } + int polledMessages = msgs.size(); + int permitsToRelease = permits - polledMessages; + ReleaseReason releaseReason = polledMessages == 0 ? ReleaseReason.NONE_FETCHED : ReleaseReason.PARTIAL_FETCH; + this.backPressureHandler.release(permitsToRelease, releaseReason); + logger.trace("Released {} unused ({}) permits for queue {} (messages polled {})", permitsToRelease, + releaseReason, this.pollingEndpointName, polledMessages); return msgs; } @@ -285,7 +283,7 @@ protected AcknowledgementCallback getAcknowledgementCallback() { private void releaseBackPressure() { logger.debug("Releasing permit for queue {}", this.pollingEndpointName); - this.backPressureHandler.release(1); + this.backPressureHandler.release(1, ReleaseReason.PROCESSED); } private Void handleSinkException(Throwable t) { diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsBackPressureIntegrationTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsBackPressureIntegrationTests.java new file mode 100644 index 000000000..8038f70d2 --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsBackPressureIntegrationTests.java @@ -0,0 +1,532 @@ +/* + * Copyright 2013-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.integration; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.awspring.cloud.sqs.config.SqsBootstrapConfiguration; +import io.awspring.cloud.sqs.listener.*; +import io.awspring.cloud.sqs.operations.SqsTemplate; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Queue; +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntUnaryOperator; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; + +/** + * Integration tests for SQS containers back pressure management. + * + * @author Loïc Rouchon + */ +@SpringBootTest +class SqsBackPressureIntegrationTests extends BaseSqsIntegrationTest { + + private static final Logger logger = LoggerFactory.getLogger(SqsBackPressureIntegrationTests.class); + + @Autowired + SqsTemplate sqsTemplate; + + static final class NonBlockingExternalConcurrencyLimiterBackPressureHandler implements BackPressureHandler { + private final AtomicInteger limit; + private final AtomicInteger inFlight = new AtomicInteger(0); + private final AtomicBoolean draining = new AtomicBoolean(false); + + NonBlockingExternalConcurrencyLimiterBackPressureHandler(int max) { + limit = new AtomicInteger(max); + } + + public void setLimit(int value) { + logger.info("adjusting limit from {} to {}", limit.get(), value); + limit.set(value); + } + + @Override + public int request(int amount) { + if (draining.get()) { + return 0; + } + int permits = Math.max(0, Math.min(limit.get() - inFlight.get(), amount)); + inFlight.addAndGet(permits); + return permits; + } + + @Override + public void release(int amount, ReleaseReason reason) { + inFlight.addAndGet(-amount); + } + + @Override + public boolean drain(Duration timeout) { + Duration drainingTimeout = Duration.ofSeconds(10L); + Duration drainingPollingIntervalCheck = Duration.ofMillis(50L); + draining.set(true); + limit.set(0); + Instant start = Instant.now(); + while (Duration.between(start, Instant.now()).compareTo(drainingTimeout) < 0) { + if (inFlight.get() == 0) { + return true; + } + sleep(drainingPollingIntervalCheck.toMillis()); + } + return false; + } + } + + @ParameterizedTest + @CsvSource({ "2,2", "4,4", "5,5", "20,5" }) + void staticBackPressureLimitShouldCapQueueProcessingCapacity(int staticLimit, int expectedMaxConcurrentRequests) + throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + staticLimit); + String queueName = "BACK_PRESSURE_LIMITER_QUEUE_NAME_STATIC_LIMIT_" + staticLimit; + IntStream.range(0, 10).forEach(index -> { + List> messages = create10Messages("staticBackPressureLimit" + staticLimit); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent 100 messages to queue {}", queueName); + var latch = new CountDownLatch(100); + var container = SqsMessageListenerContainer + .builder().sqsAsyncClient( + BaseSqsIntegrationTest.createAsyncClient()) + .queueNames( + queueName) + .configure(options -> options.pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerSupplier(() -> new CompositeBackPressureHandler( + List.of(limiter, + ConcurrencyLimiterBlockingBackPressureHandler.builder().batchSize(5) + .totalPermits(5).acquireTimeout(Duration.ofSeconds(1L)) + .throughputConfiguration(BackPressureMode.AUTO).build()), + 5, Duration.ofMillis(50L)))) + .messageListener(msg -> { + int concurrentRqs = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, concurrentRqs)); + sleep(50L); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + latch.countDown(); + concurrentRequest.decrementAndGet(); + }).build(); + container.start(); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(maxConcurrentRequest.get()).isEqualTo(expectedMaxConcurrentRequests); + container.stop(); + } + + @Test + void zeroBackPressureLimitShouldStopQueueProcessing() throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + 0); + String queueName = "BACK_PRESSURE_LIMITER_QUEUE_NAME_STATIC_LIMIT_0"; + IntStream.range(0, 10).forEach(index -> { + List> messages = create10Messages("staticBackPressureLimit0"); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent 100 messages to queue {}", queueName); + var latch = new CountDownLatch(100); + var container = SqsMessageListenerContainer + .builder().sqsAsyncClient( + BaseSqsIntegrationTest.createAsyncClient()) + .queueNames( + queueName) + .configure(options -> options.pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerSupplier(() -> new CompositeBackPressureHandler( + List.of(limiter, + ConcurrencyLimiterBlockingBackPressureHandler.builder().batchSize(5) + .totalPermits(5).acquireTimeout(Duration.ofSeconds(1L)) + .throughputConfiguration(BackPressureMode.AUTO).build()), + 5, Duration.ofMillis(50L)))) + .messageListener(msg -> { + int concurrentRqs = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, concurrentRqs)); + sleep(50L); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + latch.countDown(); + concurrentRequest.decrementAndGet(); + }).build(); + container.start(); + assertThat(latch.await(2, TimeUnit.SECONDS)).isFalse(); + assertThat(maxConcurrentRequest.get()).isZero(); + assertThat(latch.getCount()).isEqualTo(100L); + container.stop(); + } + + @Test + void changeInBackPressureLimitShouldAdaptQueueProcessingCapacity() throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + 5); + String queueName = "BACK_PRESSURE_LIMITER_QUEUE_NAME_SYNC_ADAPTIVE_LIMIT"; + int nbMessages = 280; + IntStream.range(0, nbMessages / 10).forEach(index -> { + List> messages = create10Messages("syncAdaptiveBackPressureLimit"); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent {} messages to queue {}", nbMessages, queueName); + var latch = new CountDownLatch(nbMessages); + var controlSemaphore = new Semaphore(0); + var advanceSemaphore = new Semaphore(0); + var processingFailed = new AtomicBoolean(false); + var isDraining = new AtomicBoolean(false); + var container = SqsMessageListenerContainer + .builder().sqsAsyncClient( + BaseSqsIntegrationTest.createAsyncClient()) + .queueNames( + queueName) + .configure(options -> options.pollTimeout(Duration.ofSeconds(1)) + .backPressureHandlerSupplier(() -> new CompositeBackPressureHandler( + List.of(limiter, + ConcurrencyLimiterBlockingBackPressureHandler.builder().batchSize(5) + .totalPermits(5).acquireTimeout(Duration.ofSeconds(1L)) + .throughputConfiguration(BackPressureMode.AUTO).build()), + 5, Duration.ofMillis(50L)))) + .messageListener(msg -> { + try { + if (!controlSemaphore.tryAcquire(5, TimeUnit.SECONDS) && !isDraining.get()) { + processingFailed.set(true); + throw new IllegalStateException("Failed to wait for control semaphore"); + } + } + catch (InterruptedException e) { + if (!isDraining.get()) { + processingFailed.set(true); + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + int concurrentRqs = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, concurrentRqs)); + latch.countDown(); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + sleep(10L); + concurrentRequest.decrementAndGet(); + advanceSemaphore.release(); + }).build(); + class Controller { + private final Semaphore advanceSemaphore; + private final Semaphore controlSemaphore; + private final NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter; + private final AtomicInteger maxConcurrentRequest; + private final AtomicBoolean processingFailed; + + Controller(Semaphore advanceSemaphore, Semaphore controlSemaphore, + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter, + AtomicInteger maxConcurrentRequest, AtomicBoolean processingFailed) { + this.advanceSemaphore = advanceSemaphore; + this.controlSemaphore = controlSemaphore; + this.limiter = limiter; + this.maxConcurrentRequest = maxConcurrentRequest; + this.processingFailed = processingFailed; + } + + public void updateLimit(int newLimit) { + limiter.setLimit(newLimit); + } + + void updateLimitAndWaitForReset(int newLimit) throws InterruptedException { + updateLimit(newLimit); + int atLeastTwoPollingCycles = 2 * 5; + controlSemaphore.release(atLeastTwoPollingCycles); + waitForAdvance(atLeastTwoPollingCycles); + maxConcurrentRequest.set(0); + } + + void advance(int permits) { + controlSemaphore.release(permits); + } + + void waitForAdvance(int permits) throws InterruptedException { + assertThat(advanceSemaphore.tryAcquire(permits, 5, TimeUnit.SECONDS)) + .withFailMessage(() -> "Waiting for %d permits timed out. Only %d permits available" + .formatted(permits, advanceSemaphore.availablePermits())) + .isTrue(); + assertThat(processingFailed.get()).isFalse(); + } + } + var controller = new Controller(advanceSemaphore, controlSemaphore, limiter, maxConcurrentRequest, + processingFailed); + try { + container.start(); + + controller.advance(50); + controller.waitForAdvance(50); + // not limiting queue processing capacity + assertThat(controller.maxConcurrentRequest.get()).isEqualTo(5); + controller.updateLimitAndWaitForReset(2); + controller.advance(50); + + controller.waitForAdvance(50); + // limiting queue processing capacity + assertThat(controller.maxConcurrentRequest.get()).isEqualTo(2); + controller.updateLimitAndWaitForReset(7); + controller.advance(50); + + controller.waitForAdvance(50); + // not limiting queue processing capacity + assertThat(controller.maxConcurrentRequest.get()).isEqualTo(5); + controller.updateLimitAndWaitForReset(3); + controller.advance(50); + sleep(10L); + limiter.setLimit(1); + sleep(10L); + limiter.setLimit(2); + sleep(10L); + limiter.setLimit(3); + + controller.waitForAdvance(50); + assertThat(controller.maxConcurrentRequest.get()).isEqualTo(3); + // stopping processing of the queue + controller.updateLimit(0); + controller.advance(50); + assertThat(advanceSemaphore.tryAcquire(10, 5, TimeUnit.SECONDS)) + .withFailMessage("Acquiring semaphore should have timed out as limit was set to 0").isFalse(); + + // resume queue processing + controller.updateLimit(6); + + controller.waitForAdvance(50); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(controller.maxConcurrentRequest.get()).isEqualTo(5); + assertThat(processingFailed.get()).isFalse(); + } + finally { + isDraining.set(true); + container.stop(); + } + } + + static class EventsCsvWriter { + private final Queue events = new ConcurrentLinkedQueue<>(List.of("event,time,value")); + + void registerEvent(String event, int value) { + events.add("%s,%s,%d".formatted(event, Instant.now(), value)); + } + + void write(Path path) throws Exception { + Files.writeString(path, String.join("\n", events), StandardCharsets.UTF_8, StandardOpenOption.CREATE, + StandardOpenOption.TRUNCATE_EXISTING); + } + } + + static class StatisticsBphDecorator implements BatchAwareBackPressureHandler, IdentifiableContainerComponent { + private final BatchAwareBackPressureHandler delegate; + private final EventsCsvWriter eventCsv; + private String id; + + StatisticsBphDecorator(BatchAwareBackPressureHandler delegate, EventsCsvWriter eventsCsvWriter) { + this.delegate = delegate; + this.eventCsv = eventsCsvWriter; + } + + @Override + public int requestBatch() throws InterruptedException { + int permits = delegate.requestBatch(); + if (permits > 0) { + eventCsv.registerEvent("obtained_permits", permits); + } + return permits; + } + + @Override + public int request(int amount) throws InterruptedException { + int permits = delegate.request(amount); + if (permits > 0) { + eventCsv.registerEvent("obtained_permits", permits); + } + return permits; + } + + @Override + public void release(int amount, ReleaseReason reason) { + if (amount > 0) { + eventCsv.registerEvent("release_" + reason, amount); + } + delegate.release(amount, reason); + } + + @Override + public boolean drain(Duration timeout) { + eventCsv.registerEvent("drain", 1); + return delegate.drain(timeout); + } + + @Override + public void setId(String id) { + this.id = id; + if (delegate instanceof IdentifiableContainerComponent icc) { + icc.setId("delegate-" + id); + } + } + + @Override + public String getId() { + return id; + } + } + + /** + * This test simulates a progressive change in the back pressure limit. Unlike + * {@link #changeInBackPressureLimitShouldAdaptQueueProcessingCapacity()}, this test does not block message + * consumption while updating the limit. + *

+ * The limit is updated in a loop until all messages are consumed. The update follows a triangle wave pattern with a + * minimum of 0, a maximum of 15, and a period of 30 iterations. After each update of the limit, the test waits up + * to 10ms and samples the maximum number of concurrent messages that were processed since the update. This number + * can be higher than the defined limit during the adaptation period of the decreasing limit wave. For the + * increasing limit wave, it is usually lower due to the adaptation delay. In both cases, the maximum number of + * concurrent messages being processed rapidly converges toward the defined limit. + *

+ * The test passes if the sum of the sampled maximum number of concurrently processed messages is lower than the sum + * of the limits at those points in time. + */ + @Test + void unsynchronizedChangesInBackPressureLimitShouldAdaptQueueProcessingCapacity() throws Exception { + AtomicInteger concurrentRequest = new AtomicInteger(); + AtomicInteger maxConcurrentRequest = new AtomicInteger(); + NonBlockingExternalConcurrencyLimiterBackPressureHandler limiter = new NonBlockingExternalConcurrencyLimiterBackPressureHandler( + 0); + String queueName = "REACTIVE_BACK_PRESSURE_LIMITER_QUEUE_NAME_ADAPTIVE_LIMIT"; + int nbMessages = 1000; + Semaphore advanceSemaphore = new Semaphore(0); + IntStream.range(0, nbMessages / 10).forEach(index -> { + List> messages = create10Messages("reactAdaptiveBackPressureLimit"); + sqsTemplate.sendMany(queueName, messages); + }); + logger.debug("Sent {} messages to queue {}", nbMessages, queueName); + var latch = new CountDownLatch(nbMessages); + EventsCsvWriter eventsCsvWriter = new EventsCsvWriter(); + var container = SqsMessageListenerContainer.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()) + .queueNames(queueName) + .configure(options -> options.pollTimeout(Duration.ofSeconds(1)) + .standbyLimitPollingInterval(Duration.ofMillis(1)) + .backPressureHandlerSupplier(() -> new StatisticsBphDecorator(new CompositeBackPressureHandler( + List.of(limiter, + ConcurrencyLimiterBlockingBackPressureHandler.builder().batchSize(10) + .totalPermits(10).acquireTimeout(Duration.ofSeconds(1L)) + .throughputConfiguration(BackPressureMode.AUTO).build()), + 10, Duration.ofMillis(50L)), eventsCsvWriter))) + .messageListener(msg -> { + int currentConcurrentRq = concurrentRequest.incrementAndGet(); + maxConcurrentRequest.updateAndGet(max -> Math.max(max, currentConcurrentRq)); + sleep(ThreadLocalRandom.current().nextInt(10)); + latch.countDown(); + logger.debug("concurrent rq {}, max concurrent rq {}, latch count {}", concurrentRequest.get(), + maxConcurrentRequest.get(), latch.getCount()); + concurrentRequest.decrementAndGet(); + advanceSemaphore.release(); + }).build(); + IntUnaryOperator progressiveLimitChange = (int x) -> { + int period = 30; + int halfPeriod = period / 2; + if (x % period < halfPeriod) { + return (x % halfPeriod); + } + else { + return (halfPeriod - (x % halfPeriod)); + } + }; + try { + container.start(); + Random random = new Random(); + int limitsSum = 0; + int maxConcurrentRqSum = 0; + int changeLimitCount = 0; + while (latch.getCount() > 0 && changeLimitCount < nbMessages) { + changeLimitCount++; + int limit = progressiveLimitChange.applyAsInt(changeLimitCount); + int expectedMax = Math.min(10, limit); + limiter.setLimit(limit); + maxConcurrentRequest.set(0); + sleep(random.nextInt(20)); + int actualLimit = Math.min(10, limit); + int max = maxConcurrentRequest.get(); + if (max > 0) { + // Ignore iterations where nothing was polled (messages consumption slower than iteration) + limitsSum += actualLimit; + maxConcurrentRqSum += max; + } + eventsCsvWriter.registerEvent("max_concurrent_rq", max); + eventsCsvWriter.registerEvent("concurrent_rq", concurrentRequest.get()); + eventsCsvWriter.registerEvent("limit", limit); + eventsCsvWriter.registerEvent("in_flight", limiter.inFlight.get()); + eventsCsvWriter.registerEvent("expected_max", expectedMax); + eventsCsvWriter.registerEvent("max_minus_expected_max", max - expectedMax); + } + eventsCsvWriter.write(Path.of("target/stats-%s.csv".formatted(queueName))); + assertThat(maxConcurrentRqSum).isLessThanOrEqualTo(limitsSum); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + } + finally { + container.stop(); + } + } + + private static void sleep(long millis) { + try { + Thread.sleep(millis); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + private List> create10Messages(String testName) { + return IntStream.range(0, 10).mapToObj(index -> testName + "-payload-" + index) + .map(payload -> MessageBuilder.withPayload(payload).build()).collect(Collectors.toList()); + } + + @Import(SqsBootstrapConfiguration.class) + @Configuration + static class SQSConfiguration { + + @Bean + SqsTemplate sqsTemplate() { + return SqsTemplate.builder().sqsAsyncClient(BaseSqsIntegrationTest.createAsyncClient()).build(); + } + } +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java index 50bded839..76a7a65f7 100644 --- a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java @@ -269,6 +269,7 @@ void manuallyCreatesInactiveContainer() throws Exception { logger.debug("Sent message to queue {} with messageBody {}", MANUALLY_CREATE_INACTIVE_CONTAINER_QUEUE_NAME, messageBody); assertThat(latchContainer.manuallyInactiveCreatedContainerLatch.await(10, TimeUnit.SECONDS)).isTrue(); + inactiveMessageListenerContainer.stop(); } // @formatter:off diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java index b03b308c6..14e80cb07 100644 --- a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/source/AbstractPollingMessageSourceTests.java @@ -23,27 +23,17 @@ import static org.mockito.Mockito.times; import io.awspring.cloud.sqs.MessageExecutionThreadFactory; -import io.awspring.cloud.sqs.listener.BackPressureMode; -import io.awspring.cloud.sqs.listener.SemaphoreBackPressureHandler; -import io.awspring.cloud.sqs.listener.SqsContainerOptions; +import io.awspring.cloud.sqs.listener.*; import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementCallback; import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementProcessor; import io.awspring.cloud.sqs.support.converter.MessageConversionContext; import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter; import java.time.Duration; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Semaphore; -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.TimeUnit; +import java.util.*; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; import org.assertj.core.api.InstanceOfAssertFactories; import org.awaitility.Awaitility; @@ -69,13 +59,77 @@ class AbstractPollingMessageSourceTests { void shouldAcquireAndReleaseFullPermits() { String testName = "shouldAcquireAndReleaseFullPermits"; - SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder() + BackPressureHandler backPressureHandler = ConcurrencyLimiterBlockingBackPressureHandler.builder() .acquireTimeout(Duration.ofMillis(200)).batchSize(10).totalPermits(10) .throughputConfiguration(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES).build(); ExecutorService threadPool = Executors.newCachedThreadPool(); CountDownLatch pollingCounter = new CountDownLatch(3); CountDownLatch processingCounter = new CountDownLatch(1); + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { + + private final AtomicBoolean hasReceived = new AtomicBoolean(false); + + @Override + protected CompletableFuture> doPollForMessages(int messagesToRequest) { + return CompletableFuture.supplyAsync(() -> { + try { + // Since BackPressureMode.ALWAYS_POLL_MAX_MESSAGES, should always be 10. + assertThat(messagesToRequest).isEqualTo(10); + assertAvailablePermits(backPressureHandler, 0); + boolean firstPoll = hasReceived.compareAndSet(false, true); + return firstPoll + ? (Collection) List.of(Message.builder() + .messageId(UUID.randomUUID().toString()).body("message").build()) + : Collections. emptyList(); + } + catch (Throwable t) { + logger.error("Error", t); + throw new RuntimeException(t); + } + }, threadPool).whenComplete((v, t) -> { + if (t == null) { + pollingCounter.countDown(); + } + }); + } + }; + + source.setBackPressureHandler(backPressureHandler); + source.setMessageSink((msgs, context) -> { + assertAvailablePermits(backPressureHandler, 9); + msgs.forEach(msg -> context.runBackPressureReleaseCallback()); + return CompletableFuture.runAsync(processingCounter::countDown); + }); + + source.setId(testName + " source"); + source.configure(SqsContainerOptions.builder().build()); + source.setTaskExecutor(createTaskExecutor(testName)); + source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); + source.start(); + assertThat(doAwait(pollingCounter)).isTrue(); + assertThat(doAwait(processingCounter)).isTrue(); + } + + @Test + void shouldAdaptThroughputMode() { + String testName = "shouldAdaptThroughputMode"; + + int totalPermits = 20; + int batchSize = 10; + var concurrencyLimiterBlockingBackPressureHandler = ConcurrencyLimiterBlockingBackPressureHandler.builder() + .batchSize(batchSize).totalPermits(totalPermits) + .throughputConfiguration(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .acquireTimeout(Duration.ofSeconds(5L)).build(); + var throughputBackPressureHandler = ThroughputBackPressureHandler.builder().batchSize(batchSize).build(); + var backPressureHandler = new CompositeBackPressureHandler( + List.of(concurrencyLimiterBlockingBackPressureHandler, throughputBackPressureHandler), batchSize, + Duration.ofMillis(100L)); + ExecutorService threadPool = Executors.newCachedThreadPool(); + CountDownLatch pollingCounter = new CountDownLatch(3); + CountDownLatch processingCounter = new CountDownLatch(1); + Collection errors = new ConcurrentLinkedQueue<>(); + AbstractPollingMessageSource source = new AbstractPollingMessageSource<>() { private final AtomicBoolean hasReceived = new AtomicBoolean(false); @@ -88,20 +142,20 @@ protected CompletableFuture> doPollForMessages(int messagesT try { // Since BackPressureMode.ALWAYS_POLL_MAX_MESSAGES, should always be 10. assertThat(messagesToRequest).isEqualTo(10); - assertAvailablePermits(backPressureHandler, 0); + // assertAvailablePermits(backPressureHandler, 10); boolean firstPoll = hasReceived.compareAndSet(false, true); if (firstPoll) { - logger.debug("First poll"); + logger.warn("First poll"); // No permits released yet, should be TM low assertThroughputMode(backPressureHandler, "low"); } else if (hasMadeSecondPoll.compareAndSet(false, true)) { - logger.debug("Second poll"); + logger.warn("Second poll"); // Permits returned, should be high assertThroughputMode(backPressureHandler, "high"); } else { - logger.debug("Third poll"); + logger.warn("Third poll"); // Already returned full permits, should be low assertThroughputMode(backPressureHandler, "low"); } @@ -111,20 +165,25 @@ else if (hasMadeSecondPoll.compareAndSet(false, true)) { : Collections. emptyList(); } catch (Throwable t) { - logger.error("Error", t); + logger.error("Error (not expecting it)", t); + errors.add(t); throw new RuntimeException(t); } }, threadPool).whenComplete((v, t) -> { if (t == null) { + logger.warn("pas boom", t); pollingCounter.countDown(); } + else { + logger.warn("BOOOOOOOM", t); + errors.add(t); + } }); } }; source.setBackPressureHandler(backPressureHandler); source.setMessageSink((msgs, context) -> { - assertAvailablePermits(backPressureHandler, 9); msgs.forEach(msg -> context.runBackPressureReleaseCallback()); return CompletableFuture.runAsync(processingCounter::countDown); }); @@ -133,9 +192,16 @@ else if (hasMadeSecondPoll.compareAndSet(false, true)) { source.configure(SqsContainerOptions.builder().build()); source.setTaskExecutor(createTaskExecutor(testName)); source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor()); - source.start(); - assertThat(doAwait(pollingCounter)).isTrue(); - assertThat(doAwait(processingCounter)).isTrue(); + try { + logger.warn("Yolo, let's start"); + source.start(); + assertThat(doAwait(pollingCounter)).isTrue(); + assertThat(doAwait(processingCounter)).isTrue(); + assertThat(errors).isEmpty(); + } + finally { + source.stop(); + } } private static final AtomicInteger testCounter = new AtomicInteger(); @@ -143,8 +209,8 @@ else if (hasMadeSecondPoll.compareAndSet(false, true)) { @Test void shouldAcquireAndReleasePartialPermits() { String testName = "shouldAcquireAndReleasePartialPermits"; - SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder() - .acquireTimeout(Duration.ofMillis(150)).batchSize(10).totalPermits(10) + ConcurrencyLimiterBlockingBackPressureHandler backPressureHandler = ConcurrencyLimiterBlockingBackPressureHandler + .builder().acquireTimeout(Duration.ofMillis(150)).batchSize(10).totalPermits(10) .throughputConfiguration(BackPressureMode.AUTO).build(); ExecutorService threadPool = Executors .newCachedThreadPool(new MessageExecutionThreadFactory("test " + testCounter.incrementAndGet())); @@ -159,8 +225,6 @@ void shouldAcquireAndReleasePartialPermits() { private final AtomicBoolean hasAcquired9 = new AtomicBoolean(false); - private final AtomicBoolean hasMadeThirdPoll = new AtomicBoolean(false); - @Override protected CompletableFuture> doPollForMessages(int messagesToRequest) { return CompletableFuture.supplyAsync(() -> { @@ -176,31 +240,20 @@ protected CompletableFuture> doPollForMessages(int messagesT assertThat(messagesToRequest).isEqualTo(10); assertAvailablePermits(backPressureHandler, 0); // No permits have been released yet - assertThroughputMode(backPressureHandler, "low"); } else if (hasAcquired9.compareAndSet(false, true)) { // Second poll, should have 9 logger.debug("Second poll - should request 9 messages"); assertThat(messagesToRequest).isEqualTo(9); assertAvailablePermitsLessThanOrEqualTo(backPressureHandler, 1); - // Has released 9 permits, should be TM HIGH - assertThroughputMode(backPressureHandler, "high"); + // Has released 9 permits processingLatch.countDown(); // Release processing now } else { - boolean thirdPoll = hasMadeThirdPoll.compareAndSet(false, true); // Third poll or later, should have 10 again logger.debug("Third poll - should request 10 messages"); assertThat(messagesToRequest).isEqualTo(10); assertAvailablePermits(backPressureHandler, 0); - if (thirdPoll) { - // Hasn't yet returned a full batch, should be TM High - assertThroughputMode(backPressureHandler, "high"); - } - else { - // Has returned all permits in third poll - assertThroughputMode(backPressureHandler, "low"); - } } if (shouldReturnMessage) { logger.debug("shouldReturnMessage, returning one message"); @@ -241,8 +294,8 @@ else if (hasAcquired9.compareAndSet(false, true)) { @Test void shouldReleasePermitsOnConversionErrors() { String testName = "shouldReleasePermitsOnConversionErrors"; - SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder() - .acquireTimeout(Duration.ofMillis(150)).batchSize(10).totalPermits(10) + ConcurrencyLimiterBlockingBackPressureHandler backPressureHandler = ConcurrencyLimiterBlockingBackPressureHandler + .builder().acquireTimeout(Duration.ofMillis(150)).batchSize(10).totalPermits(10) .throughputConfiguration(BackPressureMode.AUTO).build(); AtomicInteger convertedMessages = new AtomicInteger(0); @@ -304,9 +357,16 @@ void shouldBackOffIfPollingThrowsAnError() { var testName = "shouldBackOffIfPollingThrowsAnError"; - var backPressureHandler = SemaphoreBackPressureHandler.builder().acquireTimeout(Duration.ofMillis(200)) - .batchSize(10).totalPermits(40).throughputConfiguration(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) - .build(); + int totalPermits = 40; + int batchSize = 10; + var concurrencyLimiterBlockingBackPressureHandler = ConcurrencyLimiterBlockingBackPressureHandler.builder() + .batchSize(batchSize).totalPermits(totalPermits) + .throughputConfiguration(BackPressureMode.ALWAYS_POLL_MAX_MESSAGES) + .acquireTimeout(Duration.ofMillis(200)).build(); + var throughputBackPressureHandler = ThroughputBackPressureHandler.builder().batchSize(batchSize).build(); + var backPressureHandler = new CompositeBackPressureHandler( + List.of(concurrencyLimiterBlockingBackPressureHandler, throughputBackPressureHandler), batchSize, + Duration.ofSeconds(5L)); var currentPoll = new AtomicInteger(0); var waitThirdPollLatch = new CountDownLatch(4); @@ -363,24 +423,45 @@ private static boolean doAwait(CountDownLatch processingLatch) { } } - private void assertThroughputMode(SemaphoreBackPressureHandler backPressureHandler, String expectedThroughputMode) { - assertThat(ReflectionTestUtils.getField(backPressureHandler, "currentThroughputMode")) - .extracting(Object::toString).extracting(String::toLowerCase) + private void assertThroughputMode(BackPressureHandler backPressureHandler, String expectedThroughputMode) { + var bph = extractBackPressureHandler(backPressureHandler, ThroughputBackPressureHandler.class); + assertThat(getThroughputModeValue(bph, "currentThroughputMode")) .isEqualTo(expectedThroughputMode.toLowerCase()); } - private void assertAvailablePermits(SemaphoreBackPressureHandler backPressureHandler, int expectedPermits) { - assertThat(ReflectionTestUtils.getField(backPressureHandler, "semaphore")).asInstanceOf(type(Semaphore.class)) + private static String getThroughputModeValue(ThroughputBackPressureHandler bph, String targetThroughputMode) { + return ((AtomicReference) ReflectionTestUtils.getField(bph, targetThroughputMode)).get().toString() + .toLowerCase(Locale.ROOT); + } + + private void assertAvailablePermits(BackPressureHandler backPressureHandler, int expectedPermits) { + var bph = extractBackPressureHandler(backPressureHandler, ConcurrencyLimiterBlockingBackPressureHandler.class); + assertThat(ReflectionTestUtils.getField(bph, "semaphore")).asInstanceOf(type(Semaphore.class)) .extracting(Semaphore::availablePermits).isEqualTo(expectedPermits); } - private void assertAvailablePermitsLessThanOrEqualTo(SemaphoreBackPressureHandler backPressureHandler, - int maxExpectedPermits) { - assertThat(ReflectionTestUtils.getField(backPressureHandler, "semaphore")).asInstanceOf(type(Semaphore.class)) + private void assertAvailablePermitsLessThanOrEqualTo( + ConcurrencyLimiterBlockingBackPressureHandler backPressureHandler, int maxExpectedPermits) { + var bph = extractBackPressureHandler(backPressureHandler, ConcurrencyLimiterBlockingBackPressureHandler.class); + assertThat(ReflectionTestUtils.getField(bph, "semaphore")).asInstanceOf(type(Semaphore.class)) .extracting(Semaphore::availablePermits).asInstanceOf(InstanceOfAssertFactories.INTEGER) .isLessThanOrEqualTo(maxExpectedPermits); } + private T extractBackPressureHandler(BackPressureHandler bph, Class type) { + if (type.isInstance(bph)) { + return type.cast(bph); + } + if (bph instanceof CompositeBackPressureHandler cbph) { + List backPressureHandlers = (List) ReflectionTestUtils + .getField(cbph, "backPressureHandlers"); + return extractBackPressureHandler( + backPressureHandlers.stream().filter(type::isInstance).map(type::cast).findFirst().orElseThrow(), + type); + } + throw new NoSuchElementException("%s not found in %s".formatted(type.getSimpleName(), bph)); + } + // Used to slow down tests while developing private void doSleep(int time) { try {