Skip to content

Commit c1a16f3

Browse files
authored
feat: optionally assume a role (#13)
1 parent da9d44d commit c1a16f3

File tree

5 files changed

+54
-1
lines changed

5 files changed

+54
-1
lines changed

docs/data-sources/ssm.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ provider "postgresql" {
4444

4545
- `ssm_profile` (String) AWS profile name as set in credentials files. Can also be set using either the environment variables `AWS_PROFILE` or `AWS_DEFAULT_PROFILE`.
4646
- `ssm_region` (String) AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.
47+
- `ssm_role_arn` (String) ARN of an IAM role to assume.
4748

4849
### Read-Only
4950

docs/ephemeral-resources/ssm.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ provider "postgresql" {
4444

4545
- `ssm_profile` (String) AWS profile name as set in credentials files. Can also be set using either the environment variables `AWS_PROFILE` or `AWS_DEFAULT_PROFILE`.
4646
- `ssm_region` (String) AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.
47+
- `ssm_role_arn` (String) ARN of an IAM role to assume.
4748

4849
### Read-Only
4950

internal/provider/data_source_ssm.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ type SSMDataSourceModel struct {
2828
LocalPort types.Int64 `tfsdk:"local_port"`
2929
SSMInstance types.String `tfsdk:"ssm_instance"`
3030
SSMProfile types.String `tfsdk:"ssm_profile"`
31+
SSMRoleARN types.String `tfsdk:"ssm_role_arn"`
3132
SSMRegion types.String `tfsdk:"ssm_region"`
3233
TargetHost types.String `tfsdk:"target_host"`
3334
TargetPort types.Int64 `tfsdk:"target_port"`
@@ -60,6 +61,11 @@ func (d *SSMDataSource) Schema(ctx context.Context, req datasource.SchemaRequest
6061
Optional: true,
6162
Computed: true,
6263
},
64+
"ssm_role_arn": schema.StringAttribute{
65+
MarkdownDescription: "ARN of an IAM role to assume.",
66+
Optional: true,
67+
Computed: true,
68+
},
6369
"ssm_region": schema.StringAttribute{
6470
MarkdownDescription: "AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.",
6571
Optional: true,
@@ -105,6 +111,7 @@ func (d *SSMDataSource) Read(ctx context.Context, req datasource.ReadRequest, re
105111
LocalPort: strconv.Itoa(localPort),
106112
SSMInstance: data.SSMInstance.ValueString(),
107113
SSMProfile: data.SSMProfile.ValueString(),
114+
SSMRoleARN: data.SSMRoleARN.ValueString(),
108115
SSMRegion: data.SSMRegion.ValueString(),
109116
TargetHost: data.TargetHost.ValueString(),
110117
TargetPort: strconv.Itoa(int(data.TargetPort.ValueInt64())),
@@ -119,8 +126,14 @@ func (d *SSMDataSource) Read(ctx context.Context, req datasource.ReadRequest, re
119126
tunnelCfg.SSMRegion = awsCfg.Region
120127
tunnelCfg.SSMProfile = ssm.GetSDKConfigProfile(awsCfg)
121128

129+
// Only update SSMRoleARN if it wasn't explicitly provided
130+
if tunnelCfg.SSMRoleARN == "" {
131+
tunnelCfg.SSMRoleARN = ssm.GetSDKConfigRole(awsCfg)
132+
}
133+
122134
data.SSMRegion = types.StringValue(tunnelCfg.SSMRegion)
123135
data.SSMProfile = types.StringValue(tunnelCfg.SSMProfile)
136+
data.SSMRoleARN = types.StringValue(tunnelCfg.SSMRoleARN)
124137

125138
_, err = ssm.ForkRemoteTunnel(ctx, awsCfg, tunnelCfg)
126139
if err != nil {

internal/provider/ephemeral_ssm.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ type SSMEphemeralModel struct {
2828
LocalPort types.Int64 `tfsdk:"local_port"`
2929
SSMInstance types.String `tfsdk:"ssm_instance"`
3030
SSMProfile types.String `tfsdk:"ssm_profile"`
31+
SSMRoleARN types.String `tfsdk:"ssm_role_arn"`
3132
SSMRegion types.String `tfsdk:"ssm_region"`
3233
TargetHost types.String `tfsdk:"target_host"`
3334
TargetPort types.Int64 `tfsdk:"target_port"`
@@ -60,6 +61,11 @@ func (d *SSMEphemeral) Schema(ctx context.Context, req ephemeral.SchemaRequest,
6061
Optional: true,
6162
Computed: true,
6263
},
64+
"ssm_role_arn": schema.StringAttribute{
65+
MarkdownDescription: "ARN of an IAM role to assume.",
66+
Optional: true,
67+
Computed: true,
68+
},
6369
"ssm_region": schema.StringAttribute{
6470
MarkdownDescription: "AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.",
6571
Optional: true,
@@ -105,6 +111,7 @@ func (d *SSMEphemeral) Open(ctx context.Context, req ephemeral.OpenRequest, resp
105111
LocalPort: strconv.Itoa(localPort),
106112
SSMInstance: data.SSMInstance.ValueString(),
107113
SSMProfile: data.SSMProfile.ValueString(),
114+
SSMRoleARN: data.SSMRoleARN.ValueString(),
108115
SSMRegion: data.SSMRegion.ValueString(),
109116
TargetHost: data.TargetHost.ValueString(),
110117
TargetPort: strconv.Itoa(int(data.TargetPort.ValueInt64())),
@@ -119,8 +126,14 @@ func (d *SSMEphemeral) Open(ctx context.Context, req ephemeral.OpenRequest, resp
119126
tunnelCfg.SSMRegion = awsCfg.Region
120127
tunnelCfg.SSMProfile = ssm.GetSDKConfigProfile(awsCfg)
121128

129+
// Only update SSMRoleARN if it wasn't explicitly provided
130+
if tunnelCfg.SSMRoleARN == "" {
131+
tunnelCfg.SSMRoleARN = ssm.GetSDKConfigRole(awsCfg)
132+
}
133+
122134
data.SSMRegion = types.StringValue(tunnelCfg.SSMRegion)
123135
data.SSMProfile = types.StringValue(tunnelCfg.SSMProfile)
136+
data.SSMRoleARN = types.StringValue(tunnelCfg.SSMRoleARN)
124137

125138
cmd, err := ssm.ForkRemoteTunnel(ctx, awsCfg, tunnelCfg)
126139
if err != nil {

internal/ssm/session.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import (
55

66
"github.com/aws/aws-sdk-go-v2/aws"
77
"github.com/aws/aws-sdk-go-v2/config"
8+
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
89
"github.com/aws/aws-sdk-go-v2/service/ssm"
10+
"github.com/aws/aws-sdk-go-v2/service/sts"
911
)
1012

1113
const DEFAULT_SSM_ENV_NAME = "AWS_SSM_START_SESSION_RESPONSE"
@@ -14,6 +16,7 @@ type TunnelConfig struct {
1416
LocalPort string
1517
SSMInstance string
1618
SSMProfile string
19+
SSMRoleARN string
1720
SSMRegion string
1821
TargetHost string
1922
TargetPort string
@@ -34,7 +37,20 @@ func GetNewSDKConfig(ctx context.Context, cfg TunnelConfig) (aws.Config, error)
3437
loadOptions = append(loadOptions, config.WithSharedConfigProfile(cfg.SSMProfile))
3538
}
3639

37-
return config.LoadDefaultConfig(ctx, loadOptions...)
40+
// Load base config first
41+
awsCfg, err := config.LoadDefaultConfig(ctx, loadOptions...)
42+
if err != nil {
43+
return aws.Config{}, err
44+
}
45+
46+
// If role assumption is required, create STS client and configure assume role
47+
if cfg.SSMRoleARN != "" {
48+
stsClient := sts.NewFromConfig(awsCfg)
49+
assumeRoleProvider := stscreds.NewAssumeRoleProvider(stsClient, cfg.SSMRoleARN)
50+
awsCfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider)
51+
}
52+
53+
return awsCfg, nil
3854
}
3955

4056
func GetSDKConfigProfile(awsCfg aws.Config) string {
@@ -46,6 +62,15 @@ func GetSDKConfigProfile(awsCfg aws.Config) string {
4662
return ""
4763
}
4864

65+
func GetSDKConfigRole(awsCfg aws.Config) string {
66+
for _, cfg := range awsCfg.ConfigSources {
67+
if p, ok := cfg.(config.SharedConfig); ok {
68+
return p.RoleARN
69+
}
70+
}
71+
return ""
72+
}
73+
4974
func CreateSessionInput(cfg TunnelConfig) ssm.StartSessionInput {
5075
reqParams := make(map[string][]string)
5176
reqParams["portNumber"] = []string{cfg.TargetPort}

0 commit comments

Comments
 (0)