Skip to content
34 changes: 34 additions & 0 deletions src/main/java/io/kestra/plugin/aws/cloudwatch/CloudWatchLogs.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.kestra.plugin.aws.cloudwatch;

import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.runners.RunContext;
import io.kestra.plugin.aws.AbstractConnection;
import io.kestra.plugin.aws.ConnectionUtils;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient;

@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
public class CloudWatchLogs extends AbstractConnection {

/**
* Create a CloudWatchLogsClient using the standard Kestra AWS configuration
* (credentials, region, endpoint overrides, etc.).
*/
public CloudWatchLogsClient logsClient(final RunContext runContext)
throws IllegalVariableEvaluationException {

final AwsClientConfig clientConfig = awsClientConfig(runContext);

return ConnectionUtils
.configureSyncClient(clientConfig, CloudWatchLogsClient.builder())
.build();
}
}
81 changes: 70 additions & 11 deletions src/main/java/io/kestra/plugin/aws/lambda/Invoke.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
package io.kestra.plugin.aws.lambda;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.nio.file.Files;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;

import org.apache.http.HttpHeaders;
import org.apache.http.entity.ContentType;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;

import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Metric;
Expand All @@ -15,6 +29,7 @@
import io.kestra.core.serializers.JacksonMapper;
import io.kestra.plugin.aws.AbstractConnection;
import io.kestra.plugin.aws.ConnectionUtils;
import io.kestra.plugin.aws.cloudwatch.CloudWatchLogs;
import io.kestra.plugin.aws.lambda.Invoke.Output;
import io.kestra.plugin.aws.s3.ObjectOutput;
import io.swagger.v3.oas.annotations.media.Schema;
Expand All @@ -25,23 +40,15 @@
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpHeaders;
import org.apache.http.entity.ContentType;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient;
import software.amazon.awssdk.services.cloudwatchlogs.model.FilterLogEventsRequest;
import software.amazon.awssdk.services.cloudwatchlogs.model.FilteredLogEvent;
import software.amazon.awssdk.services.lambda.LambdaClient;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;
import software.amazon.awssdk.services.lambda.model.LambdaException;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.nio.file.Files;
import java.time.Duration;
import java.util.Map;
import java.util.Optional;

@SuperBuilder
@ToString
@EqualsAndHashCode
Expand Down Expand Up @@ -112,6 +119,7 @@ public class Invoke extends AbstractConnection implements RunnableTask<Output> {
@Override
public Output run(RunContext runContext) throws Exception {
final long start = System.nanoTime();
final Instant invocationStart = Instant.now().minusSeconds(5);
var functionArn = runContext.render(this.functionArn).as(String.class).orElseThrow();
var requestPayload = runContext.render(this.functionPayload).asMap(String.class, Object.class).isEmpty() ?
null :
Expand Down Expand Up @@ -141,13 +149,19 @@ public Output run(RunContext runContext) throws Exception {
logger.debug("Lambda {} invoked successfully", functionArn);
}
Output out = handleContent(runContext, functionArn, contentType, res.payload());
fetchAndLogLambdaLogs(runContext, functionArn, invocationStart);
runContext.metric(Timer.of("duration", Duration.ofNanos(System.nanoTime() - start)));
return out;
} catch (LambdaException e) {
throw new LambdaInvokeException("Lambda Invoke task execution failed for function: " + functionArn, e);
}
}

@VisibleForTesting
CloudWatchLogsClient getCloudWatchLogsClient(RunContext runContext) throws IllegalVariableEvaluationException {
return new CloudWatchLogs().logsClient(runContext);
}

@VisibleForTesting
LambdaClient client(final RunContext runContext) throws IllegalVariableEvaluationException {
final AwsClientConfig clientConfig = awsClientConfig(runContext);
Expand Down Expand Up @@ -195,6 +209,51 @@ Optional<String> readError(String payload) {
return Optional.empty();
}

@VisibleForTesting
void fetchAndLogLambdaLogs(
RunContext runContext,
String functionArn,
Instant startTime
) {
var logger = runContext.logger();

// Extract function name from ARN
String functionName;
try {
functionName = functionArn.split(":function:")[1].split(":")[0];
} catch (Exception e) {
logger.warn("Unable to determine Lambda function name from ARN: {}", functionArn);
return;
}

String logGroupName = "/aws/lambda/" + functionName;

try (CloudWatchLogsClient logsClient = getCloudWatchLogsClient(runContext)) {
FilterLogEventsRequest request = FilterLogEventsRequest.builder()
.logGroupName(logGroupName)
.startTime(startTime.toEpochMilli())
.build();

// Fetch logs using CloudWatch paginator
logsClient.filterLogEventsPaginator(request)
.events()
.stream()
// Hard cap to prevent excessive log volume in a single task execution
.limit(1_000)
.map(FilteredLogEvent::message)
.filter(message -> message != null && !message.isBlank())
.forEach(message -> logger.info("[lambda] {}", message));

} catch (Exception e) {
// Logs must never fail the task execution
logger.warn(
"Failed to fetch CloudWatch logs for Lambda {}: {}",
functionArn,
e.getMessage()
);
}
}

@VisibleForTesting
void handleError(String functionArn, ContentType contentType, SdkBytes payload) {
String errorPayload;
Expand Down
96 changes: 73 additions & 23 deletions src/test/java/io/kestra/plugin/aws/lambda/InvokeUnitTest.java
Original file line number Diff line number Diff line change
@@ -1,45 +1,57 @@
package io.kestra.plugin.aws.lambda;

import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.property.Property;
import io.kestra.core.runners.RunContext;
import io.kestra.core.runners.RunContextProperty;
import io.kestra.core.runners.WorkingDir;
import io.kestra.core.storages.Storage;
import io.kestra.plugin.aws.lambda.Invoke.Output;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.http.entity.ContentType;
import org.junit.jupiter.api.AfterEach;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import org.mockito.Mock;
import org.mockito.Mock.Strictness;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.stubbing.Answer;
import org.slf4j.Logger;

import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.property.Property;
import io.kestra.core.runners.RunContext;
import io.kestra.core.runners.RunContextProperty;
import io.kestra.core.runners.WorkingDir;
import io.kestra.core.storages.Storage;
import io.kestra.plugin.aws.lambda.Invoke.Output;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient;
import software.amazon.awssdk.services.cloudwatchlogs.model.FilterLogEventsRequest;
import software.amazon.awssdk.services.cloudwatchlogs.model.FilterLogEventsResponse;
import software.amazon.awssdk.services.cloudwatchlogs.model.FilteredLogEvent;
import software.amazon.awssdk.services.cloudwatchlogs.paginators.FilterLogEventsIterable;
import software.amazon.awssdk.services.lambda.LambdaClient;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy;

@ExtendWith(MockitoExtension.class)
public class InvokeUnitTest {

Expand All @@ -48,6 +60,9 @@ public class InvokeUnitTest {
@Mock(strictness = Strictness.LENIENT)
private RunContext context;

@Mock
private CloudWatchLogsClient logsClient;

@Mock(strictness = Strictness.LENIENT)
private RunContextProperty runContextProperty;

Expand Down Expand Up @@ -203,4 +218,39 @@ void givenFunctionArnNoParams_whenInvokeLambda_thenOutputWithFile(
// Then
checkOutput(data, res);
}

@Test
void givenLambdaInvocation_whenCloudWatchLogsEnabled_thenLogsAreFetchedAndLogged() throws Exception {
// Given
FilteredLogEvent logEvent = FilteredLogEvent.builder()
.message("Hello from CloudWatch Logs")
.timestamp(Instant.now().toEpochMilli())
.build();

// Mock paginator
FilterLogEventsIterable paginator = mock(FilterLogEventsIterable.class);

// Correct AWS SDK iterable for events()
software.amazon.awssdk.core.pagination.sync.SdkIterable<FilteredLogEvent> eventIterable =
() -> List.of(logEvent).iterator();

given(paginator.events()).willReturn(eventIterable);
given(logsClient.filterLogEventsPaginator(any(FilterLogEventsRequest.class)))
.willReturn(paginator);

// Spy Invoke to inject mocked logs client
Invoke spyInvoke = spy(invoke);
doReturn(logsClient).when(spyInvoke).getCloudWatchLogsClient(any());

// When
spyInvoke.fetchAndLogLambdaLogs(
context,
"arn:aws:lambda:eu-central-1:123456789012:function:test",
Instant.now()
);

// Then
verify(logger).info("[lambda] {}", "Hello from CloudWatch Logs");
}

}
Loading