From 02c85a92234367aedff363e5332d42efc86bc1b9 Mon Sep 17 00:00:00 2001 From: Sylvain Rabot Date: Thu, 29 Jan 2026 11:06:48 +0100 Subject: [PATCH] feat: add new GetAWSIAMAuthConnector function Signed-off-by: Sylvain Rabot --- bun/bunconnect/flags.go | 46 ++++++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/bun/bunconnect/flags.go b/bun/bunconnect/flags.go index 4b68d468..5873caa7 100644 --- a/bun/bunconnect/flags.go +++ b/bun/bunconnect/flags.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" @@ -45,26 +46,47 @@ func WithRuntimeParams(params map[string]string) Option { } } +func GetAWSIAMAuthConnector(cmd *cobra.Command, opts ...Option) (func(s string) (driver.Connector, error), error) { + var cfg aws.Config + var err error + var ctx context.Context + + if cmd != nil { + ctx = cmd.Context() + cfg, err = config.LoadDefaultConfig(ctx, iam.LoadOptionFromCommand(cmd)) + } else { + ctx = context.Background() + cfg, err = config.LoadDefaultConfig(ctx) + } + + if err != nil { + return nil, err + } + + connector := func(s string) (driver.Connector, error) { + return &iamConnector{ + dsn: s, + driver: &iamDriver{ + awsConfig: cfg, + }, + options: opts, + logger: logging.FromContext(ctx), + }, nil + } + + return connector, nil +} + func ConnectionOptionsFromFlags(cmd *cobra.Command, opts ...Option) (*ConnectionOptions, error) { + var err error var connector func(string) (driver.Connector, error) awsEnable, _ := cmd.Flags().GetBool(PostgresAWSEnableIAMFlag) if awsEnable { - cfg, err := config.LoadDefaultConfig(context.Background(), iam.LoadOptionFromCommand(cmd)) + connector, err = GetAWSIAMAuthConnector(cmd, opts...) if err != nil { return nil, err } - - connector = func(s string) (driver.Connector, error) { - return &iamConnector{ - dsn: s, - driver: &iamDriver{ - awsConfig: cfg, - }, - options: opts, - logger: logging.FromContext(cmd.Context()), - }, nil - } } else { connector = func(dsn string) (driver.Connector, error) { parseConfig, err := pgx.ParseConfig(dsn)