Skip to content

Commit 8ec4082

Browse files
authored
Merge pull request #9 from arnaud-dfns/feat/ssm-profile
feat: add aws ssm profile setting
2 parents cce9dad + b15a509 commit 8ec4082

File tree

6 files changed

+99
-32
lines changed

6 files changed

+99
-32
lines changed

docs/data-sources/ssm.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,14 @@ provider "postgresql" {
3737
### Required
3838

3939
- `ssm_instance` (String) Specify the exact Instance ID of the managed node to connect to for the session
40-
- `ssm_region` (String) AWS Region where the instance is located
4140
- `target_host` (String) The DNS name or IP address of the remote host
4241
- `target_port` (Number) The port number of the remote host
4342

43+
### Optional
44+
45+
- `ssm_profile` (String) AWS profile name as set in credentials files. Can also be set using either the environment variables `AWS_PROFILE` or `AWS_DEFAULT_PROFILE`.
46+
- `ssm_region` (String) AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.
47+
4448
### Read-Only
4549

4650
- `local_host` (String) The DNS name or IP address of the local host

docs/ephemeral-resources/ssm.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,14 @@ provider "postgresql" {
3737
### Required
3838

3939
- `ssm_instance` (String) Specify the exact Instance ID of the managed node to connect to for the session
40-
- `ssm_region` (String) AWS Region where the instance is located
4140
- `target_host` (String) The DNS name or IP address of the remote host
4241
- `target_port` (Number) The port number of the remote host
4342

43+
### Optional
44+
45+
- `ssm_profile` (String) AWS profile name as set in credentials files. Can also be set using either the environment variables `AWS_PROFILE` or `AWS_DEFAULT_PROFILE`.
46+
- `ssm_region` (String) AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.
47+
4448
### Read-Only
4549

4650
- `local_host` (String) The DNS name or IP address of the local host

internal/provider/data_source_ssm.go

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ type SSMDataSource struct{}
2424

2525
// SSMDataSourceModel describes the data source data model.
2626
type SSMDataSourceModel struct {
27-
TargetHost types.String `tfsdk:"target_host"`
28-
TargetPort types.Int64 `tfsdk:"target_port"`
2927
LocalHost types.String `tfsdk:"local_host"`
3028
LocalPort types.Int64 `tfsdk:"local_port"`
3129
SSMInstance types.String `tfsdk:"ssm_instance"`
30+
SSMProfile types.String `tfsdk:"ssm_profile"`
3231
SSMRegion types.String `tfsdk:"ssm_region"`
32+
TargetHost types.String `tfsdk:"target_host"`
33+
TargetPort types.Int64 `tfsdk:"target_port"`
3334
}
3435

3536
func (d *SSMDataSource) Metadata(ctx context.Context, req datasource.MetadataRequest, resp *datasource.MetadataResponse) {
@@ -54,9 +55,15 @@ func (d *SSMDataSource) Schema(ctx context.Context, req datasource.SchemaRequest
5455
MarkdownDescription: "Specify the exact Instance ID of the managed node to connect to for the session",
5556
Required: true,
5657
},
58+
"ssm_profile": schema.StringAttribute{
59+
MarkdownDescription: "AWS profile name as set in credentials files. Can also be set using either the environment variables `AWS_PROFILE` or `AWS_DEFAULT_PROFILE`.",
60+
Optional: true,
61+
Computed: true,
62+
},
5763
"ssm_region": schema.StringAttribute{
58-
MarkdownDescription: "AWS Region where the instance is located",
59-
Required: true,
64+
MarkdownDescription: "AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.",
65+
Optional: true,
66+
Computed: true,
6067
},
6168

6269
// Computed attributes
@@ -94,13 +101,28 @@ func (d *SSMDataSource) Read(ctx context.Context, req datasource.ReadRequest, re
94101
data.LocalHost = types.StringValue("localhost")
95102
data.LocalPort = types.Int64Value(int64(localPort))
96103

97-
_, err = ssm.ForkRemoteTunnel(ctx, ssm.TunnelConfig{
98-
SSMRegion: data.SSMRegion.ValueString(),
104+
tunnelCfg := ssm.TunnelConfig{
105+
LocalPort: strconv.Itoa(localPort),
99106
SSMInstance: data.SSMInstance.ValueString(),
107+
SSMProfile: data.SSMProfile.ValueString(),
108+
SSMRegion: data.SSMRegion.ValueString(),
100109
TargetHost: data.TargetHost.ValueString(),
101110
TargetPort: strconv.Itoa(int(data.TargetPort.ValueInt64())),
102-
LocalPort: strconv.Itoa(localPort),
103-
})
111+
}
112+
113+
awsCfg, err := ssm.GetNewSDKConfig(ctx, tunnelCfg)
114+
if err != nil {
115+
resp.Diagnostics.AddError("Failed to initialize AWS SDK", fmt.Sprintf("Error: %s", err))
116+
return
117+
}
118+
119+
tunnelCfg.SSMRegion = awsCfg.Region
120+
tunnelCfg.SSMProfile = ssm.GetSDKConfigProfile(awsCfg)
121+
122+
data.SSMRegion = types.StringValue(tunnelCfg.SSMRegion)
123+
data.SSMProfile = types.StringValue(tunnelCfg.SSMProfile)
124+
125+
_, err = ssm.ForkRemoteTunnel(ctx, awsCfg, tunnelCfg)
104126
if err != nil {
105127
resp.Diagnostics.AddError("Failed to fork tunnel process", fmt.Sprintf("Error: %s", err))
106128
return

internal/provider/ephemeral_ssm.go

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ type SSMEphemeral struct{}
2424

2525
// SSMEphemeralModel describes the data source data model.
2626
type SSMEphemeralModel struct {
27-
TargetHost types.String `tfsdk:"target_host"`
28-
TargetPort types.Int64 `tfsdk:"target_port"`
2927
LocalHost types.String `tfsdk:"local_host"`
3028
LocalPort types.Int64 `tfsdk:"local_port"`
3129
SSMInstance types.String `tfsdk:"ssm_instance"`
30+
SSMProfile types.String `tfsdk:"ssm_profile"`
3231
SSMRegion types.String `tfsdk:"ssm_region"`
32+
TargetHost types.String `tfsdk:"target_host"`
33+
TargetPort types.Int64 `tfsdk:"target_port"`
3334
}
3435

3536
func (d *SSMEphemeral) Metadata(ctx context.Context, req ephemeral.MetadataRequest, resp *ephemeral.MetadataResponse) {
@@ -54,9 +55,15 @@ func (d *SSMEphemeral) Schema(ctx context.Context, req ephemeral.SchemaRequest,
5455
MarkdownDescription: "Specify the exact Instance ID of the managed node to connect to for the session",
5556
Required: true,
5657
},
58+
"ssm_profile": schema.StringAttribute{
59+
MarkdownDescription: "AWS profile name as set in credentials files. Can also be set using either the environment variables `AWS_PROFILE` or `AWS_DEFAULT_PROFILE`.",
60+
Optional: true,
61+
Computed: true,
62+
},
5763
"ssm_region": schema.StringAttribute{
58-
MarkdownDescription: "AWS Region where the instance is located",
59-
Required: true,
64+
MarkdownDescription: "AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.",
65+
Optional: true,
66+
Computed: true,
6067
},
6168

6269
// Computed attributes
@@ -94,13 +101,28 @@ func (d *SSMEphemeral) Open(ctx context.Context, req ephemeral.OpenRequest, resp
94101
data.LocalHost = types.StringValue("localhost")
95102
data.LocalPort = types.Int64Value(int64(localPort))
96103

97-
cmd, err := ssm.ForkRemoteTunnel(ctx, ssm.TunnelConfig{
98-
SSMRegion: data.SSMRegion.ValueString(),
104+
tunnelCfg := ssm.TunnelConfig{
105+
LocalPort: strconv.Itoa(localPort),
99106
SSMInstance: data.SSMInstance.ValueString(),
107+
SSMProfile: data.SSMProfile.ValueString(),
108+
SSMRegion: data.SSMRegion.ValueString(),
100109
TargetHost: data.TargetHost.ValueString(),
101110
TargetPort: strconv.Itoa(int(data.TargetPort.ValueInt64())),
102-
LocalPort: strconv.Itoa(localPort),
103-
})
111+
}
112+
113+
awsCfg, err := ssm.GetNewSDKConfig(ctx, tunnelCfg)
114+
if err != nil {
115+
resp.Diagnostics.AddError("Failed to initialize AWS SDK", fmt.Sprintf("Error: %s", err))
116+
return
117+
}
118+
119+
tunnelCfg.SSMRegion = awsCfg.Region
120+
tunnelCfg.SSMProfile = ssm.GetSDKConfigProfile(awsCfg)
121+
122+
data.SSMRegion = types.StringValue(tunnelCfg.SSMRegion)
123+
data.SSMProfile = types.StringValue(tunnelCfg.SSMProfile)
124+
125+
cmd, err := ssm.ForkRemoteTunnel(ctx, awsCfg, tunnelCfg)
104126
if err != nil {
105127
resp.Diagnostics.AddError("Failed to fork tunnel process", fmt.Sprintf("Error: %s", err))
106128
return

internal/ssm/session.go

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ import (
1111
const DEFAULT_SSM_ENV_NAME = "AWS_SSM_START_SESSION_RESPONSE"
1212

1313
type TunnelConfig struct {
14-
SSMRegion string
14+
LocalPort string
1515
SSMInstance string
16+
SSMProfile string
17+
SSMRegion string
1618
TargetHost string
1719
TargetPort string
18-
LocalPort string
1920
}
2021

2122
type SessionParams struct {
@@ -24,6 +25,27 @@ type SessionParams struct {
2425
StreamUrl string
2526
}
2627

28+
func GetNewSDKConfig(ctx context.Context, cfg TunnelConfig) (aws.Config, error) {
29+
loadOptions := []func(*config.LoadOptions) error{}
30+
if cfg.SSMRegion != "" {
31+
loadOptions = append(loadOptions, config.WithRegion(cfg.SSMRegion))
32+
}
33+
if cfg.SSMProfile != "" {
34+
loadOptions = append(loadOptions, config.WithSharedConfigProfile(cfg.SSMProfile))
35+
}
36+
37+
return config.LoadDefaultConfig(ctx, loadOptions...)
38+
}
39+
40+
func GetSDKConfigProfile(awsCfg aws.Config) string {
41+
for _, cfg := range awsCfg.ConfigSources {
42+
if p, ok := cfg.(config.SharedConfig); ok {
43+
return p.Profile
44+
}
45+
}
46+
return ""
47+
}
48+
2749
func CreateSessionInput(cfg TunnelConfig) ssm.StartSessionInput {
2850
reqParams := make(map[string][]string)
2951
reqParams["portNumber"] = []string{cfg.TargetPort}
@@ -37,14 +59,7 @@ func CreateSessionInput(cfg TunnelConfig) ssm.StartSessionInput {
3759
}
3860
}
3961

40-
func StartTunnelSession(ctx context.Context, cfg TunnelConfig) (SessionParams, error) {
41-
// Load AWS SDK config
42-
awsCfg, err := config.LoadDefaultConfig(ctx)
43-
if err != nil {
44-
return SessionParams{}, err
45-
}
46-
awsCfg.Region = cfg.SSMRegion
47-
62+
func StartTunnelSession(ctx context.Context, awsCfg aws.Config, cfg TunnelConfig) (SessionParams, error) {
4863
// Create SSM client
4964
ssmClient := ssm.NewFromConfig(awsCfg)
5065

internal/ssm/tunnel.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strconv"
1111
"time"
1212

13+
"github.com/aws/aws-sdk-go-v2/aws"
1314
"github.com/aws/aws-sdk-go-v2/service/ssm"
1415
pluginSession "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session"
1516
_ "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session/portsession"
@@ -30,10 +31,10 @@ func GetEndpoint(ctx context.Context, region string) (string, error) {
3031
return endpoint.URI.String(), nil
3132
}
3233

33-
func ForkRemoteTunnel(ctx context.Context, cfg TunnelConfig) (*exec.Cmd, error) {
34+
func ForkRemoteTunnel(ctx context.Context, awsCfg aws.Config, cfg TunnelConfig) (*exec.Cmd, error) {
3435
// First we start a session using AWS SDK
3536
// see https://github.com/aws/aws-cli/blob/master/awscli/customizations/sessionmanager.py#L104
36-
sessionParams, err := StartTunnelSession(ctx, cfg)
37+
sessionParams, err := StartTunnelSession(ctx, awsCfg, cfg)
3738
if err != nil {
3839
return nil, err
3940
}
@@ -102,7 +103,6 @@ func StartRemoteTunnel(ctx context.Context, cfgJson string, parentPid int) (err
102103
return err
103104
}
104105

105-
profileName := ""
106106
endpointUrl, err := GetEndpoint(ctx, cfg.SSMRegion)
107107
if err != nil {
108108
return err
@@ -113,7 +113,7 @@ func StartRemoteTunnel(ctx context.Context, cfgJson string, parentPid int) (err
113113
DEFAULT_SSM_ENV_NAME,
114114
cfg.SSMRegion,
115115
"StartSession",
116-
profileName,
116+
cfg.SSMProfile,
117117
string(sessionInputJson),
118118
endpointUrl,
119119
}

0 commit comments

Comments
 (0)