Skip to content

Commit 6e6ab6e

Browse files
authored
FIX: sts headers override in AWS secret extension (#5506)
Signed-off-by: George Chen <[email protected]>
1 parent 1098421 commit 6e6ab6e

File tree

3 files changed

+62
-0
lines changed

3 files changed

+62
-0
lines changed

Diff for: data-prepper-plugins/aws-plugin/src/main/java/org/opensearch/dataprepper/plugins/aws/AwsSecretManagerConfiguration.java

+10
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
2222

2323
import java.time.Duration;
24+
import java.util.Map;
2425
import java.util.Optional;
2526
import java.util.UUID;
2627

@@ -42,6 +43,10 @@ public class AwsSecretManagerConfiguration {
4243
@Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters")
4344
private String awsStsRoleArn;
4445

46+
@JsonProperty("sts_header_overrides")
47+
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
48+
private Map<String, String> awsStsHeaderOverrides;
49+
4550
@JsonProperty("refresh_interval")
4651
@NotNull(message = "refresh_interval must not be null")
4752
@DurationMin(hours = 1L, message = "Refresh interval must be at least 1 hour.")
@@ -101,6 +106,11 @@ private AwsCredentialsProvider authenticateAwsConfiguration() {
101106
.roleSessionName("aws-secret-" + UUID.randomUUID())
102107
.roleArn(awsStsRoleArn);
103108

109+
if (awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) {
110+
assumeRoleRequestBuilder = assumeRoleRequestBuilder.overrideConfiguration(
111+
configuration -> awsStsHeaderOverrides.forEach(configuration::putHeader));
112+
}
113+
104114
awsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder()
105115
.stsClient(stsClient)
106116
.refreshRequest(assumeRoleRequestBuilder.build())

Diff for: data-prepper-plugins/aws-plugin/src/test/java/org/opensearch/dataprepper/plugins/aws/AwsSecretManagerConfigurationTest.java

+47
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,19 @@
3030
import org.mockito.junit.jupiter.MockitoExtension;
3131
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
3232
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
33+
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
3334
import software.amazon.awssdk.regions.Region;
3435
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
3536
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClientBuilder;
3637
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
3738
import software.amazon.awssdk.services.secretsmanager.model.PutSecretValueRequest;
3839
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
40+
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
3941

4042
import java.io.IOException;
4143
import java.io.InputStream;
4244
import java.time.Duration;
45+
import java.util.List;
4346
import java.util.Set;
4447

4548
import static org.hamcrest.CoreMatchers.equalTo;
@@ -49,6 +52,7 @@
4952
import static org.junit.jupiter.api.Assertions.assertThrows;
5053
import static org.mockito.ArgumentMatchers.any;
5154
import static org.mockito.ArgumentMatchers.anyString;
55+
import static org.mockito.Mockito.mock;
5256
import static org.mockito.Mockito.mockStatic;
5357
import static org.mockito.Mockito.verify;
5458
import static org.mockito.Mockito.when;
@@ -207,6 +211,49 @@ void testCreateSecretManagerClientWithStsCredential() throws IOException {
207211
assertThat(awsCredentialsProvider, instanceOf(StsAssumeRoleCredentialsProvider.class));
208212
}
209213

214+
@Test
215+
void testCreateSecretManagerClientWithStsHeaderOverrides() throws IOException {
216+
final InputStream inputStream = AwsSecretPluginConfigTest.class.getResourceAsStream(
217+
"/test-aws-secret-manager-configuration-with-sts-headers.yaml");
218+
final AwsSecretManagerConfiguration awsSecretManagerConfiguration = objectMapper.readValue(
219+
inputStream, AwsSecretManagerConfiguration.class);
220+
assertThat(awsSecretManagerConfiguration.getAwsSecretId(), equalTo("test-secret"));
221+
final StsAssumeRoleCredentialsProvider.Builder stsAssumeRoleCredentialsProviderBuilder =
222+
mock(StsAssumeRoleCredentialsProvider.Builder.class);
223+
final StsAssumeRoleCredentialsProvider stsAssumeRoleCredentialsProvider =
224+
mock(StsAssumeRoleCredentialsProvider.class);
225+
when(stsAssumeRoleCredentialsProviderBuilder.stsClient(any()))
226+
.thenReturn(stsAssumeRoleCredentialsProviderBuilder);
227+
when(stsAssumeRoleCredentialsProviderBuilder.refreshRequest(any(AssumeRoleRequest.class)))
228+
.thenReturn(stsAssumeRoleCredentialsProviderBuilder);
229+
when(stsAssumeRoleCredentialsProviderBuilder.build()).thenReturn(stsAssumeRoleCredentialsProvider);
230+
when(secretsManagerClientBuilder.region(any(Region.class))).thenReturn(secretsManagerClientBuilder);
231+
when(secretsManagerClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class)))
232+
.thenReturn(secretsManagerClientBuilder);
233+
when(secretsManagerClientBuilder.build()).thenReturn(secretsManagerClient);
234+
try (final MockedStatic<SecretsManagerClient> secretsManagerClientMockedStatic = mockStatic(
235+
SecretsManagerClient.class);
236+
final MockedStatic<StsAssumeRoleCredentialsProvider> stsAssumeRoleCredentialsProviderMockedStatic =
237+
mockStatic(StsAssumeRoleCredentialsProvider.class)) {
238+
secretsManagerClientMockedStatic.when(SecretsManagerClient::builder).thenReturn(secretsManagerClientBuilder);
239+
stsAssumeRoleCredentialsProviderMockedStatic.when(StsAssumeRoleCredentialsProvider::builder).thenReturn(
240+
stsAssumeRoleCredentialsProviderBuilder);
241+
assertThat(awsSecretManagerConfiguration.createSecretManagerClient(), is(secretsManagerClient));
242+
}
243+
verify(secretsManagerClientBuilder).credentialsProvider(awsCredentialsProviderArgumentCaptor.capture());
244+
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsProviderArgumentCaptor.getValue();
245+
assertThat(awsCredentialsProvider, instanceOf(StsAssumeRoleCredentialsProvider.class));
246+
final ArgumentCaptor<AssumeRoleRequest> assumeRoleRequestArgumentCaptor =
247+
ArgumentCaptor.forClass(AssumeRoleRequest.class);
248+
verify(stsAssumeRoleCredentialsProviderBuilder).refreshRequest(assumeRoleRequestArgumentCaptor.capture());
249+
final AssumeRoleRequest assumeRoleRequest = assumeRoleRequestArgumentCaptor.getValue();
250+
assertThat(assumeRoleRequest.overrideConfiguration().isPresent(), is(true));
251+
final AwsRequestOverrideConfiguration awsRequestOverrideConfiguration = assumeRoleRequest
252+
.overrideConfiguration().get();
253+
assertThat(awsRequestOverrideConfiguration.headers().size(), equalTo(1));
254+
assertThat(awsRequestOverrideConfiguration.headers().get("test-header"), equalTo(List.of("test-value")));
255+
}
256+
210257
@ParameterizedTest
211258
@ValueSource(strings = {
212259
"/test-aws-secret-manager-configuration-invalid-sts-1.yaml",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
secret_id: test-secret
2+
region: us-east-1
3+
sts_role_arn: arn:aws:iam::123456789012:role/test-role
4+
sts_header_overrides:
5+
test-header: test-value

0 commit comments

Comments
 (0)