Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions examples/src/main/java/com/salesforce/multicloudj/sts/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@

public class Main {

static String provider = "gcp";
static String provider = "aws";

public static void main(String[] args) {
assumeRole();
assumeRoleWebIdentityCredentialsOverrider();
getCallerIdentity();
nativeAuthSignerUtilityWithStsCredentials();
nativeAuthSignerUtilityWithDefaultCredentials();
//assumeRoleWebIdentityCredentialsOverrider();
//getCallerIdentity();
//nativeAuthSignerUtilityWithStsCredentials();
//nativeAuthSignerUtilityWithDefaultCredentials();
}

public static void assumeRole() {
Expand All @@ -55,9 +55,9 @@ public static void assumeRole() {
.build();

AssumedRoleRequest request = AssumedRoleRequest.newBuilder()
.withRole("chameleon@substrate-sdk-gcp-poc1.iam.gserviceaccount.com")
.withRole("arn:aws:iam::654654370895:role/chameleon-multi--f4msu63ppffhs")
.withSessionName("my-session")
.withCredentialScope(credentialScope)
//.withCredentialScope(credentialScope)
.build();
StsCredentials stsCredentials = client.getAssumeRoleCredentials(request);

Expand All @@ -76,7 +76,7 @@ public static void assumeRoleWebIdentityCredentialsOverrider() {
.withWebIdentityTokenSupplier(tokenSupplier)
.build();
BucketClient bucketClient = BucketClient.builder(provider)
.withRegion("us-west-2").withBucket("chameleon-jclouds")
.withRegion("us-west-2").withBucket("chameleon-jcloud")
.withCredentialsOverrider(overrider)
.build();
ListBlobsPageResponse r=bucketClient.listPage(ListBlobsPageRequest.builder().withMaxResults(1).build());
Expand Down
11 changes: 11 additions & 0 deletions multicloudj-common-aws/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>cloudwatch</artifactId>
<version>2.35.0</version>
<exclusions>
<exclusion>
<groupId>*</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>aws-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package com.salesforce.multicloudj.common.aws;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttribute;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.services.cloudwatch.CloudWatchClient;
import software.amazon.awssdk.services.cloudwatch.model.MetricDatum;
import software.amazon.awssdk.services.cloudwatch.model.PutMetricDataRequest;
import software.amazon.awssdk.services.cloudwatch.model.StandardUnit;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;

import java.time.Duration;
import java.time.Instant;

public class StsDiagnosticsInterceptor implements ExecutionInterceptor {
private static final Logger log = LoggerFactory.getLogger(StsDiagnosticsInterceptor.class);
private static final ExecutionAttribute<Instant> START_TIME = new ExecutionAttribute<>("StartTime");
private final CloudWatchClient cwClient;
private final String namespace;

public StsDiagnosticsInterceptor(CloudWatchClient cwClient, String namespace) {
this.cwClient = cwClient;
this.namespace = namespace;
}

@Override
public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
executionAttributes.putAttribute(START_TIME, Instant.now());
}

@Override
public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
// Correctly passing the SdkRequest
processCompletion(context.request(), executionAttributes, "SUCCESS", null);
}

@Override
public void onExecutionFailure(Context.FailedExecution context, ExecutionAttributes executionAttributes) {
// Correctly passing the SdkRequest
processCompletion(context.request(), executionAttributes, "FAILURE", context.exception().getMessage());
}

private void processCompletion(SdkRequest request, ExecutionAttributes executionAttributes, String status, String errorMsg) {
String action = request.getClass().getSimpleName();
Instant start = executionAttributes.getAttribute(START_TIME);
long duration = (start != null) ? Duration.between(start, Instant.now()).toMillis() : 0;

// FIXED: The type-safe way to check and cast
String targetRole = "N/A";
if (request instanceof AssumeRoleRequest) {
// Because AssumeRoleRequest implements SdkRequest through the STS hierarchy,
// this cast is valid and safe within this block.
targetRole = ((AssumeRoleRequest) request).roleArn();
}

// 1. STRUCTURED LOGGING
log.info("STS_COMPLETE | Action: {} | Status: {} | Duration: {}ms | Role: {} | Error: {}",
action, status, duration, targetRole, (errorMsg != null ? errorMsg : "None"));

// 2. METRIC PUBLISHING
publishToCloudWatch(action, status, (double) duration, targetRole);
}

private void publishToCloudWatch(String action, String status, Double duration, String roleArn) {
try {
// Use a background thread or EMF in high-volume prod to avoid blocking
MetricDatum count = MetricDatum.builder()
.metricName("StsCallCount")
.dimensions(d -> d.name("Action").value(action),
d -> d.name("Status").value(status))
.value(1.0).unit(StandardUnit.COUNT).build();

cwClient.putMetricData(PutMetricDataRequest.builder()
.namespace(this.namespace)
.metricData(count)
.build());
} catch (Exception e) {
log.warn("Telemetry Error: Could not publish STS metrics", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.auto.service.AutoService;
import com.salesforce.multicloudj.common.aws.CommonErrorCodeMapping;
import com.salesforce.multicloudj.common.aws.StsDiagnosticsInterceptor;
import com.salesforce.multicloudj.common.exceptions.InvalidArgumentException;
import com.salesforce.multicloudj.common.exceptions.SubstrateSdkException;
import com.salesforce.multicloudj.common.exceptions.UnAuthorizedException;
Expand All @@ -17,6 +18,8 @@
import com.salesforce.multicloudj.sts.model.StsCredentials;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.cloudwatch.CloudWatchClient;
import software.amazon.awssdk.services.cloudwatch.CloudWatchClientBuilder;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
Expand Down Expand Up @@ -45,9 +48,13 @@ public AwsSts(Builder builder) {
super(builder);
Region region = Region.of(builder.getRegion());
StsClientBuilder sb = StsClient.builder().region(region);
CloudWatchClientBuilder cwb = CloudWatchClient.builder().region(region);
if (builder.getEndpoint() != null) {
sb = sb.endpointOverride(builder.getEndpoint());
cwb = cwb.endpointOverride(builder.getEndpoint());
}
CloudWatchClient cw = cwb.build();
sb.overrideConfiguration(cfg -> cfg .addExecutionInterceptor(new StsDiagnosticsInterceptor(cw, this.getClass().getName())));
this.stsClient = sb.build();
}

Expand Down
Loading