diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/operations/AbstractMessagingTemplate.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/operations/AbstractMessagingTemplate.java index bb6ed9e50..f01852c26 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/operations/AbstractMessagingTemplate.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/operations/AbstractMessagingTemplate.java @@ -301,23 +301,31 @@ public CompletableFuture> sendAsync(@Nullable String endpointN public CompletableFuture> sendAsync(@Nullable String endpointName, Message message) { String endpointToUse = getEndpointName(endpointName); logger.trace("Sending message {} to endpoint {}", MessageHeaderUtils.getId(message), endpointName); - return preProcessMessageForSendAsync(endpointToUse, message) - .thenCompose(messageToUse -> observeAndSendAsync(messageToUse, endpointToUse) - .exceptionallyCompose( - t -> CompletableFuture.failedFuture(new MessagingOperationFailedException( - "Message send operation failed for message %s to endpoint %s" - .formatted(MessageHeaderUtils.getId(message), endpointToUse), - endpointToUse, message, t))) - .whenComplete((v, t) -> logSendMessageResult(endpointToUse, message, t))); - } - - private CompletableFuture> observeAndSendAsync(Message message, String endpointToUse) { - AbstractTemplateObservation.Context context = this.observationSpecifics.createContext(message, endpointToUse); - Observation observation = startObservation(context); - Map carrier = Objects.requireNonNull(context.getCarrier(), "No carrier found in context."); - Message messageWithObservationHeader = MessageHeaderUtils.addHeadersIfAbsent(message, carrier); - return doSendAsync(endpointToUse, convertMessageToSend(messageWithObservationHeader), - messageWithObservationHeader) + + // Capture parent observation on the calling thread to propagate trace context across async boundary + var parentObservation = this.observationRegistry.getCurrentObservation(); + + return preProcessMessageForSendAsync(endpointToUse, message).thenCompose( + preprocessedMessage -> observeAndSend(preprocessedMessage, message, endpointToUse, parentObservation)); + } + + private CompletableFuture> observeAndSend(Message preprocessedMessage, + Message originalMessage, String endpointToUse, @Nullable Observation parentObservation) { + var context = this.observationSpecifics.createContext(preprocessedMessage, endpointToUse); + Observation observation = startObservation(context, parentObservation); + var carrier = Objects.requireNonNull(context.getCarrier(), "No carrier found in context."); + var messageWithObservationHeaders = MessageHeaderUtils.addHeadersIfAbsent(preprocessedMessage, carrier); + return doSendAndCompleteObservation(messageWithObservationHeaders, endpointToUse, context, observation) + .exceptionallyCompose(t -> CompletableFuture.failedFuture(new MessagingOperationFailedException( + "Message send operation failed for message %s to endpoint %s" + .formatted(MessageHeaderUtils.getId(originalMessage), endpointToUse), + endpointToUse, originalMessage, t))) + .whenComplete((v, t) -> logSendMessageResult(endpointToUse, originalMessage, t)); + } + + private CompletableFuture> doSendAndCompleteObservation(Message message, String endpointToUse, + AbstractTemplateObservation.Context context, Observation observation) { + return doSendAsync(endpointToUse, convertMessageToSend(message), message) .whenComplete((sendResult, t) -> completeObservation(sendResult, context, t, observation)); } @@ -335,13 +343,18 @@ private void completeObservation(@Nullable SendResult sendResult, AbstractTem } @SuppressWarnings("unchecked") - private Observation startObservation(Context observationContext) { + private Observation startObservation(Context observationContext, + @Nullable Observation parentObservation) { ObservationConvention defaultConvention = (ObservationConvention) observationSpecifics .getDefaultConvention(); ObservationConvention customConvention = (ObservationConvention) this.customObservationConvention; ObservationDocumentation documentation = observationSpecifics.getDocumentation(); - return documentation.start(customConvention, defaultConvention, () -> observationContext, - this.observationRegistry); + Observation observation = documentation.observation(customConvention, defaultConvention, + () -> observationContext, this.observationRegistry); + if (parentObservation != null) { + observation.parentObservation(parentObservation); + } + return observation.start(); } protected abstract Message preProcessMessageForSend(String endpointToUse, Message message); diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsTemplateFifoTracingIntegrationTest.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsTemplateFifoTracingIntegrationTest.java new file mode 100644 index 000000000..f03b8f3ef --- /dev/null +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsTemplateFifoTracingIntegrationTest.java @@ -0,0 +1,239 @@ +/* + * Copyright 2013-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.awspring.cloud.sqs.integration; + +import io.awspring.cloud.sqs.operations.SqsTemplate; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.tracing.CurrentTraceContext; +import io.micrometer.tracing.Span; +import io.micrometer.tracing.TraceContext; +import io.micrometer.tracing.Tracer; +import io.micrometer.tracing.handler.DefaultTracingObservationHandler; +import io.micrometer.tracing.handler.PropagatingReceiverTracingObservationHandler; +import io.micrometer.tracing.handler.PropagatingSenderTracingObservationHandler; +import io.micrometer.tracing.propagation.Propagator; +import io.micrometer.tracing.test.simple.SimpleTraceContext; +import io.micrometer.tracing.test.simple.SimpleTracer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; +import software.amazon.awssdk.services.sqs.model.QueueAttributeName; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for trace context propagation in FIFO queues with SqsTemplate. + *

+ * Verifies that trace headers (traceparent) are correctly propagated from sender to receiver when using + * {@code sendAsync()} with FIFO queues, including scenarios where queue attributes must be resolved asynchronously on + * the first call and when they are cached on subsequent calls. + * + * @author Igor Quintanilha + */ +@SpringBootTest +public class SqsTemplateFifoTracingIntegrationTest extends BaseSqsIntegrationTest { + private static final Logger logger = LoggerFactory.getLogger(SqsTemplateFifoTracingIntegrationTest.class); + + private static final String FIFO_QUEUE_NAME = "trace-context-test-queue.fifo"; + private static final String FIFO_CACHE_HIT_QUEUE_NAME = "trace-context-test-queue-cache-hit.fifo"; + + @Autowired + private SqsTemplate sqsTemplate; + + @Autowired + private TestObservationRegistry observationRegistry; + + @Autowired + private CurrentTraceContext currentTraceContext; + + @BeforeAll + static void beforeTests() { + var client = createAsyncClient(); + createFifoQueue(client, FIFO_QUEUE_NAME, Map.of(QueueAttributeName.CONTENT_BASED_DEDUPLICATION, "false")).join(); + createFifoQueue(client, FIFO_CACHE_HIT_QUEUE_NAME, Map.of(QueueAttributeName.CONTENT_BASED_DEDUPLICATION, "true")).join(); + + } + + @AfterEach + void cleanupAfterEach() { + observationRegistry.clear(); + } + + @Test + void sendAsync_toFifoQueue_shouldPropagateObservationScopeOnFirstCall() { + var parentObservation = Observation.start("parent-observation", observationRegistry); + var payload = new TestEvent(UUID.randomUUID().toString()); + String expectedTraceId; + + try (var ignored = parentObservation.openScope()) { + expectedTraceId = currentTraceContext.context().traceId(); + sqsTemplate.sendAsync(FIFO_QUEUE_NAME, payload).join(); + } + finally { + parentObservation.stop(); + } + + logger.info("expectedTraceId={}", expectedTraceId); + + var receivedMessage = sqsTemplate + .receive(from -> from.queue(FIFO_QUEUE_NAME).pollTimeout(Duration.ofSeconds(5)), TestEvent.class) + .orElseThrow(() -> new AssertionError("Expected message was not received")); + + assertThat(receivedMessage.getPayload()).isEqualTo(payload); + var traceparent = (String) receivedMessage.getHeaders().get("traceparent"); + assertThat(traceparent).as("traceparent header should be present").isNotNull(); + assertThat(traceparent).as("traceparent should contain the traceId").contains(expectedTraceId); + } + + @Test + void sendAsync_toFifoQueue_shouldCreateObservationOnCallingThreadAfterCacheHit() { + // Given - Warm up: send a message to populate the queue attribute cache + var warmupPayload = new TestEvent(UUID.randomUUID().toString()); + sqsTemplate.sendAsync(FIFO_CACHE_HIT_QUEUE_NAME, warmupPayload).join(); + + // Drain the warmup message + sqsTemplate.receive(from -> from.queue(FIFO_CACHE_HIT_QUEUE_NAME).pollTimeout(Duration.ofSeconds(5)), TestEvent.class); + + // Given - Start a NEW observation for the actual test + var observation = Observation.start("test-send-second", observationRegistry); + String expectedTraceId; + + var payload = new TestEvent(UUID.randomUUID().toString()); + try (var ignored = observation.openScope()) { + expectedTraceId = currentTraceContext.context().traceId(); + // When - Second call (cache hit - queue attributes already resolved) + sqsTemplate.sendAsync(FIFO_CACHE_HIT_QUEUE_NAME, payload).join(); + } + finally { + observation.stop(); + } + + logger.info("expectedTraceId={}", expectedTraceId); + + var receivedMessage = sqsTemplate + .receive(from -> from.queue(FIFO_CACHE_HIT_QUEUE_NAME).pollTimeout(Duration.ofSeconds(5)), TestEvent.class) + .orElseThrow(() -> new AssertionError("Expected message was not received")); + + assertThat(receivedMessage.getPayload()).isEqualTo(payload); + var traceparent = (String) receivedMessage.getHeaders().get("traceparent"); + assertThat(traceparent).as("traceparent header should be present").isNotNull(); + assertThat(traceparent).as("traceparent should contain the traceId").contains(expectedTraceId); + } + + @Configuration + static class TestConfiguration { + + @Bean + public SqsAsyncClient sqsAsyncClient() { + return createAsyncClient(); + } + + @Bean + public Tracer tracer() { + return new SimpleTracer(); + } + + @Bean + public CurrentTraceContext currentTraceContext(Tracer tracer) { + return ((SimpleTracer) tracer).currentTraceContext(); + } + + @Bean + public Propagator propagator(Tracer tracer) { + return new SimplePropagator(tracer); + } + + @Bean + public ObservationRegistry observationRegistry(Tracer tracer, Propagator propagator) { + TestObservationRegistry registry = TestObservationRegistry.create(); + registry.observationConfig().observationHandler(new DefaultTracingObservationHandler(tracer)); + registry.observationConfig() + .observationHandler(new PropagatingSenderTracingObservationHandler<>(tracer, propagator)); + registry.observationConfig() + .observationHandler(new PropagatingReceiverTracingObservationHandler<>(tracer, propagator)); + return registry; + } + + @Bean + public SqsTemplate sqsTemplate(SqsAsyncClient sqsAsyncClient, ObservationRegistry observationRegistry) { + return SqsTemplate.builder().sqsAsyncClient(sqsAsyncClient) + .configure(options -> options.observationRegistry(observationRegistry)).build(); + } + } + + /** + * Simple W3C Trace Context propagator for testing. In production, you would use a library like + * micrometer-tracing-bridge-brave or micrometer-tracing-bridge-otel which provide full-featured propagators. + */ + static class SimplePropagator implements Propagator { + + private final Tracer tracer; + + SimplePropagator(Tracer tracer) { + this.tracer = tracer; + } + + @Override + public List fields() { + return List.of("traceparent", "tracestate"); + } + + @Override + public void inject(TraceContext context, C carrier, Setter setter) { + // W3C Trace Context format: version-traceId-spanId-flags + var traceparent = String.format("00-%s-%s-01", context.traceId(), context.spanId()); + setter.set(carrier, "traceparent", traceparent); + } + + @Override + public Span.Builder extract(C carrier, Getter getter) { + var traceparent = getter.get(carrier, "traceparent"); + if (traceparent == null || traceparent.isEmpty()) { + return tracer.spanBuilder().setNoParent(); + } + // Parse W3C format: 00-traceId-spanId-01 + String[] parts = traceparent.split("-"); + if (parts.length < 4) { + return tracer.spanBuilder().setNoParent(); + } + // Use tracer to create span builder with extracted context + Span.Builder builder = tracer.spanBuilder(); + var traceContext = new SimpleTraceContext(); + traceContext.setTraceId(parts[1]); + traceContext.setParentId(parts[2]); + traceContext.setSpanId(parts[3]); + builder.setParent(traceContext); + return builder; + } + } + + record TestEvent(String data) { + } +} diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/operations/SqsTemplateObservationTest.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/operations/SqsTemplateObservationTest.java index 49405ece1..bdca3c448 100644 --- a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/operations/SqsTemplateObservationTest.java +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/operations/SqsTemplateObservationTest.java @@ -22,6 +22,7 @@ import io.awspring.cloud.sqs.listener.SqsHeaders; import io.awspring.cloud.sqs.support.observation.SqsTemplateObservation; import io.micrometer.common.KeyValues; +import io.micrometer.observation.Observation; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; import java.util.UUID; @@ -163,6 +164,44 @@ void shouldCaptureErrorsInObservation() { .hasSingleObservationThat().hasError().assertThatError().isInstanceOf(RuntimeException.class); } + @Test + void shouldApplyCustomConventionWhenParentObservationIsPresent() { + // given + SqsTemplateObservation.Convention customConvention = mock(SqsTemplateObservation.Convention.class); + given(customConvention.supportsContext(any())).willReturn(true); + given(customConvention.getName()).willReturn("spring.aws.sqs.template"); + + String lowCardinalityCustomKeyName = "custom.lowCardinality.key"; + String lowCardinalityCustomValue = "custom-lowCardinality-value"; + String highCardinalityCustomKeyName = "custom.highCardinality.key"; + String highCardinalityCustomValue = "custom-highCardinality-value"; + given(customConvention.getLowCardinalityKeyValues(any())) + .willReturn(KeyValues.of(lowCardinalityCustomKeyName, lowCardinalityCustomValue)); + given(customConvention.getHighCardinalityKeyValues(any())) + .willReturn(KeyValues.of(highCardinalityCustomKeyName, highCardinalityCustomValue)); + + TestObservationRegistry customRegistry = TestObservationRegistry.create(); + + SqsTemplate templateWithCustomConvention = SqsTemplate.builder().sqsAsyncClient(mockSqsAsyncClient) + .configure( + options -> options.observationRegistry(customRegistry).observationConvention(customConvention)) + .build(); + + // when - send within a parent observation scope + Observation parentObservation = Observation.start("parent-observation", customRegistry); + try (var ignored = parentObservation.openScope()) { + templateWithCustomConvention.send(queueName, "test-payload"); + } + finally { + parentObservation.stop(); + } + + // then - custom convention should be applied even with parent observation + TestObservationRegistryAssert.then(customRegistry).hasNumberOfObservationsEqualTo(2) + .hasAnObservationWithAKeyValue(lowCardinalityCustomKeyName, lowCardinalityCustomValue) + .hasAnObservationWithAKeyValue(highCardinalityCustomKeyName, highCardinalityCustomValue); + } + @Test void shouldSupportCustomKeyValuesInActiveSending() { // given