diff --git a/.golangci.yml b/.golangci.yml index 8b1c10a1d1..30d895ffcc 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -67,6 +67,7 @@ linters: - "!**/pkg/auth/types/aws_credentials.go" - "!**/pkg/auth/types/github_oidc_credentials.go" - "!**/internal/aws_utils/**" + - "!**/pkg/aws/identity/**" - "!**/pkg/provisioner/backend/**" - "$test" deny: diff --git a/internal/aws_utils/aws_utils.go b/internal/aws_utils/aws_utils.go deleted file mode 100644 index 9a48130f63..0000000000 --- a/internal/aws_utils/aws_utils.go +++ /dev/null @@ -1,179 +0,0 @@ -package aws_utils - -import ( - "context" - "fmt" - "time" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/service/sts" - - errUtils "github.com/cloudposse/atmos/errors" - log "github.com/cloudposse/atmos/pkg/logger" - "github.com/cloudposse/atmos/pkg/perf" - "github.com/cloudposse/atmos/pkg/schema" -) - -// LoadAWSConfigWithAuth loads AWS config, preferring auth context if available. -/* - When authContext is provided, it uses the Atmos-managed credentials files and profile. - Otherwise, it falls back to standard AWS SDK credential resolution. 
- - Standard AWS SDK credential resolution order: - - Environment variables: - AWS_ACCESS_KEY_ID - AWS_SECRET_ACCESS_KEY - AWS_SESSION_TOKEN (optional, for temporary credentials) - - Shared credentials file: - Typically at ~/.aws/credentials - Controlled by: - AWS_PROFILE (defaults to default) - AWS_SHARED_CREDENTIALS_FILE - - Shared config file: - Typically at ~/.aws/config - Also supports named profiles and region settings - - Amazon EC2 Instance Metadata Service (IMDS): - If running on EC2 or ECS - Uses IAM roles attached to the instance/task - - Web Identity Token credentials: - When AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are set (e.g., in EKS) - - SSO credentials (if configured) - - Custom credential sources: - Provided programmatically using config.WithCredentialsProvider(...) -*/ -func LoadAWSConfigWithAuth( - ctx context.Context, - region string, - roleArn string, - assumeRoleDuration time.Duration, - authContext *schema.AWSAuthContext, -) (aws.Config, error) { - defer perf.Track(nil, "aws_utils.LoadAWSConfigWithAuth")() - - var cfgOpts []func(*config.LoadOptions) error - - // If auth context is provided, use Atmos-managed credentials. - if authContext != nil { - log.Debug("Using Atmos auth context for AWS SDK", - "profile", authContext.Profile, - "credentials", authContext.CredentialsFile, - "config", authContext.ConfigFile, - ) - - // Set custom credential and config file paths. - // This overrides the default ~/.aws/credentials and ~/.aws/config. - cfgOpts = append(cfgOpts, - config.WithSharedCredentialsFiles([]string{authContext.CredentialsFile}), - config.WithSharedConfigFiles([]string{authContext.ConfigFile}), - config.WithSharedConfigProfile(authContext.Profile), - ) - - // Use region from auth context if not explicitly provided. - if region == "" && authContext.Region != "" { - region = authContext.Region - } - } else { - log.Debug("Using standard AWS SDK credential resolution (no auth context provided)") - } - - // Set region if provided. 
- if region != "" { - log.Debug("Using explicit region", "region", region) - cfgOpts = append(cfgOpts, config.WithRegion(region)) - } - - // Load base config. - log.Debug("Loading AWS SDK config", "num_options", len(cfgOpts)) - baseCfg, err := config.LoadDefaultConfig(ctx, cfgOpts...) - if err != nil { - log.Debug("Failed to load AWS config", "error", err) - return aws.Config{}, fmt.Errorf("%w: %w", errUtils.ErrLoadAWSConfig, err) - } - log.Debug("Successfully loaded AWS SDK config", "region", baseCfg.Region) - - // Conditionally assume role if specified. - if roleArn != "" { - log.Debug("Assuming role", "ARN", roleArn) - stsClient := sts.NewFromConfig(baseCfg) - - creds := stscreds.NewAssumeRoleProvider(stsClient, roleArn, func(o *stscreds.AssumeRoleOptions) { - o.Duration = assumeRoleDuration - }) - - cfgOpts = append(cfgOpts, config.WithCredentialsProvider(aws.NewCredentialsCache(creds))) - - // Reload full config with assumed role credentials. - return config.LoadDefaultConfig(ctx, cfgOpts...) - } - - return baseCfg, nil -} - -// LoadAWSConfig loads AWS config using standard AWS SDK credential resolution. -// This is a wrapper around LoadAWSConfigWithAuth for backward compatibility. -// For new code that needs Atmos auth support, use LoadAWSConfigWithAuth instead. -func LoadAWSConfig(ctx context.Context, region string, roleArn string, assumeRoleDuration time.Duration) (aws.Config, error) { - defer perf.Track(nil, "aws_utils.LoadAWSConfig")() - - return LoadAWSConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, nil) -} - -// AWSCallerIdentityResult holds the result of GetAWSCallerIdentity. -type AWSCallerIdentityResult struct { - Account string - Arn string - UserID string - Region string -} - -// GetAWSCallerIdentity retrieves AWS caller identity using STS GetCallerIdentity API. -// Returns account ID, ARN, user ID, and region. -// This function keeps AWS SDK STS imports contained within aws_utils package. 
-func GetAWSCallerIdentity( - ctx context.Context, - region string, - roleArn string, - assumeRoleDuration time.Duration, - authContext *schema.AWSAuthContext, -) (*AWSCallerIdentityResult, error) { - defer perf.Track(nil, "aws_utils.GetAWSCallerIdentity")() - - // Load AWS config. - cfg, err := LoadAWSConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, authContext) - if err != nil { - return nil, err - } - - // Create STS client and get caller identity. - stsClient := sts.NewFromConfig(cfg) - output, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) - if err != nil { - return nil, fmt.Errorf("%w: %w", errUtils.ErrAwsGetCallerIdentity, err) - } - - result := &AWSCallerIdentityResult{ - Region: cfg.Region, - } - - // Extract values from pointers. - if output.Account != nil { - result.Account = *output.Account - } - if output.Arn != nil { - result.Arn = *output.Arn - } - if output.UserId != nil { - result.UserID = *output.UserId - } - - return result, nil -} diff --git a/internal/aws_utils/aws_utils_test.go b/internal/aws_utils/aws_utils_test.go deleted file mode 100644 index 21c5544782..0000000000 --- a/internal/aws_utils/aws_utils_test.go +++ /dev/null @@ -1,263 +0,0 @@ -package aws_utils - -import ( - "context" - "os" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/cloudposse/atmos/pkg/schema" - "github.com/cloudposse/atmos/tests" -) - -func TestLoadAWSConfig(t *testing.T) { - // Check for AWS profile precondition - tests.RequireAWSProfile(t, "cplive-core-gbl-identity") - tests := []struct { - name string - region string - roleArn string - setupEnv func() - cleanupEnv func() - wantErr bool - }{ - { - name: "basic config without region or role", - region: "", - roleArn: "", - setupEnv: func() { - t.Setenv("AWS_ACCESS_KEY_ID", "test-key") - t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") - }, - cleanupEnv: func() { - os.Unsetenv("AWS_ACCESS_KEY_ID") - 
os.Unsetenv("AWS_SECRET_ACCESS_KEY") - }, - wantErr: false, - }, - { - name: "config with custom region", - region: "us-east-2", - roleArn: "", - setupEnv: func() { - t.Setenv("AWS_ACCESS_KEY_ID", "test-key") - t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") - }, - cleanupEnv: func() { - os.Unsetenv("AWS_ACCESS_KEY_ID") - os.Unsetenv("AWS_SECRET_ACCESS_KEY") - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Clear AWS_PROFILE to prevent conflicts with local AWS configuration. - t.Setenv("AWS_PROFILE", "") - - // Setup - if tt.setupEnv != nil { - tt.setupEnv() - } - - // Cleanup - if tt.cleanupEnv != nil { - defer tt.cleanupEnv() - } - - // Execute - cfg, err := LoadAWSConfig(context.Background(), tt.region, tt.roleArn, time.Minute*15) - - // Assert - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - if tt.region != "" { - assert.Equal(t, tt.region, cfg.Region) - } - } - }) - } -} - -func TestLoadAWSConfigWithAuth(t *testing.T) { - tests := []struct { - name string - region string - authContext *schema.AWSAuthContext - scenario string // Test scenario for setup logic: "mismatched-profile", "explicit-files", or "" - wantRegion string - wantErr bool - }{ - { - name: "without auth context", - region: "us-east-1", - authContext: nil, - wantRegion: "us-east-1", - wantErr: false, - }, - { - name: "with auth context and explicit region", - region: "us-west-2", - authContext: &schema.AWSAuthContext{ - Profile: "test-profile", - Region: "eu-west-1", - }, - wantRegion: "us-west-2", // Explicit region takes precedence. - wantErr: false, - }, - { - name: "with auth context using context region", - region: "", - authContext: &schema.AWSAuthContext{ - Profile: "test-profile", - Region: "ap-southeast-1", - }, - wantRegion: "ap-southeast-1", // Uses auth context region. 
- wantErr: false, - }, - { - name: "with auth context without region", - region: "", - authContext: &schema.AWSAuthContext{ - Profile: "test-profile", - Region: "", - }, - wantRegion: "", // No region specified. - wantErr: false, - }, - { - name: "non-existent credentials file", - region: "us-east-1", - authContext: &schema.AWSAuthContext{ - Profile: "test-profile", - Region: "us-east-1", - CredentialsFile: "/non/existent/credentials", - ConfigFile: "/non/existent/config", - }, - wantRegion: "", - wantErr: true, - }, - { - name: "invalid profile name in auth context", - region: "us-east-1", - authContext: &schema.AWSAuthContext{ - Profile: "nonexistent-profile", - Region: "us-east-1", - }, - scenario: "mismatched-profile", - wantRegion: "", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Clear AWS environment variables to avoid conflicts. - t.Setenv("AWS_PROFILE", "") - t.Setenv("AWS_REGION", "") - t.Setenv("AWS_DEFAULT_REGION", "") - t.Setenv("AWS_ACCESS_KEY_ID", "test-key") - t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") - - // Create a local copy of authContext to avoid mutating test table. - var authContextCopy *schema.AWSAuthContext - if tt.authContext != nil { - // Copy the struct to avoid race conditions. - authContextCopy = &schema.AWSAuthContext{ - Profile: tt.authContext.Profile, - Region: tt.authContext.Region, - } - - // Handle different test scenarios. - switch { - case tt.authContext.CredentialsFile != "": - // For error test cases with explicit file paths, use them. - authContextCopy.CredentialsFile = tt.authContext.CredentialsFile - authContextCopy.ConfigFile = tt.authContext.ConfigFile - case tt.scenario == "mismatched-profile": - // Create valid files but with a different profile name. - tempDir := t.TempDir() - credFile := filepath.Join(tempDir, "credentials") - configFile := filepath.Join(tempDir, "config") - - // Write credential file with different profile. 
- credContent := "[different-profile]\n" - credContent += "aws_access_key_id = test-key\n" - credContent += "aws_secret_access_key = test-secret\n" - require.NoError(t, os.WriteFile(credFile, []byte(credContent), 0o600)) - - // Write config file with different profile. - cfgContent := "[profile different-profile]\n" - cfgContent += "region = us-east-1\n" - require.NoError(t, os.WriteFile(configFile, []byte(cfgContent), 0o600)) - - authContextCopy.CredentialsFile = credFile - authContextCopy.ConfigFile = configFile - case !tt.wantErr: - // Create valid credentials for happy-path tests. - tempDir := t.TempDir() - credFile := filepath.Join(tempDir, "credentials") - configFile := filepath.Join(tempDir, "config") - - // Write minimal credential file. - credContent := "[" + authContextCopy.Profile + "]\n" - credContent += "aws_access_key_id = test-key\n" - credContent += "aws_secret_access_key = test-secret\n" - require.NoError(t, os.WriteFile(credFile, []byte(credContent), 0o600)) - - // Write minimal config file. - cfgContent := "[profile " + authContextCopy.Profile + "]\n" - if authContextCopy.Region != "" { - cfgContent += "region = " + authContextCopy.Region + "\n" - } - require.NoError(t, os.WriteFile(configFile, []byte(cfgContent), 0o600)) - - // Set file paths on the copy. - authContextCopy.CredentialsFile = credFile - authContextCopy.ConfigFile = configFile - } - } - - // Execute. - cfg, err := LoadAWSConfigWithAuth( - context.Background(), - tt.region, - "", // No role ARN for these tests. 
- time.Minute*15, - authContextCopy, - ) - - // Assert - if tt.wantErr { - assert.Error(t, err) - } else { - require.NoError(t, err) - assert.Equal(t, tt.wantRegion, cfg.Region) - } - }) - } -} - -func TestLoadAWSConfig_BackwardCompatibility(t *testing.T) { - // Test that LoadAWSConfig is equivalent to LoadAWSConfigWithAuth(nil) - t.Setenv("AWS_PROFILE", "") - t.Setenv("AWS_ACCESS_KEY_ID", "test-key") - t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret") - - region := "us-east-1" - - cfg1, err1 := LoadAWSConfig(context.Background(), region, "", time.Minute*15) - cfg2, err2 := LoadAWSConfigWithAuth(context.Background(), region, "", time.Minute*15, nil) - - assert.Equal(t, err1 == nil, err2 == nil, "Both functions should have same error state") - if err1 == nil && err2 == nil { - assert.Equal(t, cfg1.Region, cfg2.Region, "Both functions should return same region") - } -} diff --git a/internal/exec/aws_getter.go b/internal/exec/aws_getter.go index 27109f8ca8..739c95918b 100644 --- a/internal/exec/aws_getter.go +++ b/internal/exec/aws_getter.go @@ -2,113 +2,29 @@ package exec import ( "context" - "fmt" - "sync" - awsUtils "github.com/cloudposse/atmos/internal/aws_utils" - log "github.com/cloudposse/atmos/pkg/logger" + awsIdentity "github.com/cloudposse/atmos/pkg/aws/identity" "github.com/cloudposse/atmos/pkg/perf" "github.com/cloudposse/atmos/pkg/schema" ) // AWSCallerIdentity holds the information returned by AWS STS GetCallerIdentity. -type AWSCallerIdentity struct { - Account string - Arn string - UserID string - Region string // The AWS region from the loaded config. -} +// This is a type alias that delegates to pkg/aws/identity.CallerIdentity. +type AWSCallerIdentity = awsIdentity.CallerIdentity // AWSGetter provides an interface for retrieving AWS caller identity information. // This interface enables dependency injection and testability. +// This is a type alias that delegates to pkg/aws/identity.Getter. 
// //go:generate go run go.uber.org/mock/mockgen@v0.6.0 -source=$GOFILE -destination=mock_aws_getter_test.go -package=exec -type AWSGetter interface { - // GetCallerIdentity retrieves the AWS caller identity for the current credentials. - // Returns the account ID, ARN, and user ID of the calling identity. - GetCallerIdentity( - ctx context.Context, - atmosConfig *schema.AtmosConfiguration, - authContext *schema.AWSAuthContext, - ) (*AWSCallerIdentity, error) -} - -// defaultAWSGetter is the production implementation that uses real AWS SDK calls. -type defaultAWSGetter struct{} - -// GetCallerIdentity retrieves the AWS caller identity using the STS GetCallerIdentity API. -func (d *defaultAWSGetter) GetCallerIdentity( - ctx context.Context, - atmosConfig *schema.AtmosConfiguration, - authContext *schema.AWSAuthContext, -) (*AWSCallerIdentity, error) { - defer perf.Track(atmosConfig, "exec.AWSGetter.GetCallerIdentity")() - - log.Debug("Getting AWS caller identity") - - // Use the aws_utils helper to get caller identity (keeps AWS SDK imports in aws_utils). - result, err := awsUtils.GetAWSCallerIdentity(ctx, "", "", 0, authContext) - if err != nil { - return nil, err // Error already wrapped by aws_utils. - } - - identity := &AWSCallerIdentity{ - Account: result.Account, - Arn: result.Arn, - UserID: result.UserID, - Region: result.Region, - } - - log.Debug("Retrieved AWS caller identity", - "account", identity.Account, - "arn", identity.Arn, - "user_id", identity.UserID, - "region", identity.Region, - ) - - return identity, nil -} - -// awsGetter is the global instance used by YAML functions. -// This allows test code to replace it with a mock. -var awsGetter AWSGetter = &defaultAWSGetter{} +type AWSGetter = awsIdentity.Getter // SetAWSGetter allows tests to inject a mock AWSGetter. // Returns a function to restore the original getter. 
func SetAWSGetter(getter AWSGetter) func() { defer perf.Track(nil, "exec.SetAWSGetter")() - original := awsGetter - awsGetter = getter - return func() { - awsGetter = original - } -} - -// cachedAWSIdentity holds the cached AWS caller identity. -// The cache is per-CLI-invocation (stored in memory) to avoid repeated STS calls. -type cachedAWSIdentity struct { - identity *AWSCallerIdentity - err error -} - -var ( - awsIdentityCache map[string]*cachedAWSIdentity - awsIdentityCacheMu sync.RWMutex -) - -func init() { - awsIdentityCache = make(map[string]*cachedAWSIdentity) -} - -// getCacheKey generates a cache key based on the auth context. -// Different auth contexts (different credentials) get different cache entries. -// Includes Profile, CredentialsFile, and ConfigFile since all three affect AWS config loading. -func getCacheKey(authContext *schema.AWSAuthContext) string { - if authContext == nil { - return "default" - } - return fmt.Sprintf("%s:%s:%s", authContext.Profile, authContext.CredentialsFile, authContext.ConfigFile) + return awsIdentity.SetGetter(getter) } // getAWSCallerIdentityCached retrieves the AWS caller identity with caching. @@ -120,37 +36,7 @@ func getAWSCallerIdentityCached( ) (*AWSCallerIdentity, error) { defer perf.Track(atmosConfig, "exec.getAWSCallerIdentityCached")() - cacheKey := getCacheKey(authContext) - - // Check cache first (read lock). - awsIdentityCacheMu.RLock() - if cached, ok := awsIdentityCache[cacheKey]; ok { - awsIdentityCacheMu.RUnlock() - log.Debug("Using cached AWS caller identity", "cache_key", cacheKey) - return cached.identity, cached.err - } - awsIdentityCacheMu.RUnlock() - - // Cache miss - acquire write lock and fetch. - awsIdentityCacheMu.Lock() - defer awsIdentityCacheMu.Unlock() - - // Double-check after acquiring write lock. 
- if cached, ok := awsIdentityCache[cacheKey]; ok { - log.Debug("Using cached AWS caller identity (double-check)", "cache_key", cacheKey) - return cached.identity, cached.err - } - - // Fetch from AWS. - identity, err := awsGetter.GetCallerIdentity(ctx, atmosConfig, authContext) - - // Cache the result (including errors to avoid repeated failed calls). - awsIdentityCache[cacheKey] = &cachedAWSIdentity{ - identity: identity, - err: err, - } - - return identity, err + return awsIdentity.GetCallerIdentityCached(ctx, atmosConfig, authContext) } // ClearAWSIdentityCache clears the AWS identity cache. @@ -158,7 +44,5 @@ func getAWSCallerIdentityCached( func ClearAWSIdentityCache() { defer perf.Track(nil, "exec.ClearAWSIdentityCache")() - awsIdentityCacheMu.Lock() - defer awsIdentityCacheMu.Unlock() - awsIdentityCache = make(map[string]*cachedAWSIdentity) + awsIdentity.ClearIdentityCache() } diff --git a/internal/exec/packer_test.go b/internal/exec/packer_test.go index d68a9a27ae..011c2475fe 100644 --- a/internal/exec/packer_test.go +++ b/internal/exec/packer_test.go @@ -55,23 +55,28 @@ func TestExecutePacker_Validate(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w + // Ensure stdout is restored even if test fails. + defer func() { + os.Stdout = oldStd + }() + log.SetOutput(w) err = ExecutePacker(&info, &packerFlags) - assert.NoError(t, err) - // Restore std - err = w.Close() - assert.NoError(t, err) + // Restore stdout before assertions. + w.Close() os.Stdout = oldStd - // Read the captured output + assert.NoError(t, err) + + // Read the captured output. var buf bytes.Buffer _, err = buf.ReadFrom(r) assert.NoError(t, err) output := buf.String() - // Check the output + // Check the output. expected := "The configuration is valid" if !strings.Contains(output, expected) { @@ -103,24 +108,29 @@ func TestExecutePacker_Inspect(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w + // Ensure stdout is restored even if test fails. 
+ defer func() { + os.Stdout = oldStd + }() + log.SetOutput(w) packerFlags := PackerFlags{} err := ExecutePacker(&info, &packerFlags) - assert.NoError(t, err) - // Restore std - err = w.Close() - assert.NoError(t, err) + // Restore stdout before assertions. + w.Close() os.Stdout = oldStd - // Read the captured output + assert.NoError(t, err) + + // Read the captured output. var buf bytes.Buffer _, err = buf.ReadFrom(r) assert.NoError(t, err) output := buf.String() - // Check the output + // Check the output. expected := "var.source_ami: \"ami-0013ceeff668b979b\"" if !strings.Contains(output, expected) { @@ -146,24 +156,29 @@ func TestExecutePacker_Version(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w + // Ensure stdout is restored even if test fails. + defer func() { + os.Stdout = oldStd + }() + log.SetOutput(w) packerFlags := PackerFlags{} err := ExecutePacker(&info, &packerFlags) - assert.NoError(t, err) - // Restore std - err = w.Close() - assert.NoError(t, err) + // Restore stdout before assertions. + w.Close() os.Stdout = oldStd - // Read the captured output + assert.NoError(t, err) + + // Read the captured output. var buf bytes.Buffer _, err = buf.ReadFrom(r) assert.NoError(t, err) output := buf.String() - // Check the output + // Check the output. expected := "Packer v" if !strings.Contains(output, expected) { diff --git a/internal/exec/terraform_test.go b/internal/exec/terraform_test.go index 5630c19a9a..47d4898683 100644 --- a/internal/exec/terraform_test.go +++ b/internal/exec/terraform_test.go @@ -75,11 +75,16 @@ func TestExecuteTerraform_ExportEnvVar(t *testing.T) { SubCommand: "apply", } - // Create a pipe to capture stdout to check if terraform is executed correctly + // Create a pipe to capture stdout to check if terraform is executed correctly. oldStdout := os.Stdout r, w, _ := os.Pipe() os.Stdout = w + // Ensure stdout is restored even if test fails. 
+ defer func() { + os.Stdout = oldStdout + }() + // Read from pipe concurrently to avoid deadlock when output exceeds pipe buffer. var buf bytes.Buffer done := make(chan struct{}) @@ -89,19 +94,21 @@ func TestExecuteTerraform_ExportEnvVar(t *testing.T) { }() err = ExecuteTerraform(info) + + // Close writer and restore stdout before checking error. + w.Close() + os.Stdout = oldStdout + + // Wait for the reader goroutine to finish. + <-done + if err != nil { t.Fatalf("Failed to execute 'ExecuteTerraform': %v", err) } - // Restore stdout and close writer to signal EOF to the reader goroutine - err = w.Close() - assert.NoError(t, err) - os.Stdout = oldStdout - // Wait for the reader goroutine to finish - <-done output := buf.String() - // Check the output ATMOS_CLI_CONFIG_PATH ATMOS_BASE_PATH exists + // Check the output ATMOS_CLI_CONFIG_PATH ATMOS_BASE_PATH exists. if !strings.Contains(output, "ATMOS_BASE_PATH") { t.Errorf("ATMOS_BASE_PATH not found in the output") } @@ -160,11 +167,16 @@ func TestExecuteTerraform_TerraformPlanWithProcessingTemplates(t *testing.T) { ProcessFunctions: true, } - // Create a pipe to capture stdout to check if terraform is executed correctly + // Create a pipe to capture stdout to check if terraform is executed correctly. oldStdout := os.Stdout r, w, _ := os.Pipe() os.Stdout = w + // Ensure stdout is restored even if test fails. + defer func() { + os.Stdout = oldStdout + }() + // Read from pipe concurrently to avoid deadlock when output exceeds pipe buffer. var buf bytes.Buffer done := make(chan struct{}) @@ -174,19 +186,21 @@ func TestExecuteTerraform_TerraformPlanWithProcessingTemplates(t *testing.T) { }() err := ExecuteTerraform(info) + + // Close writer and restore stdout before checking error. + w.Close() + os.Stdout = oldStdout + + // Wait for the reader goroutine to finish. 
+ <-done + if err != nil { t.Fatalf("Failed to execute 'ExecuteTerraform': %v", err) } - // Restore stdout and close writer to signal EOF to the reader goroutine - err = w.Close() - assert.NoError(t, err) - os.Stdout = oldStdout - // Wait for the reader goroutine to finish - <-done output := buf.String() - // Check the output + // Check the output. if !strings.Contains(output, "component-1-a") { t.Errorf("'foo' variable should be 'component-1-a'") } @@ -217,11 +231,16 @@ func TestExecuteTerraform_TerraformPlanWithoutProcessingTemplates(t *testing.T) ProcessFunctions: true, } - // Create a pipe to capture stdout to check if terraform is executed correctly + // Create a pipe to capture stdout to check if terraform is executed correctly. oldStdout := os.Stdout r, w, _ := os.Pipe() os.Stdout = w + // Ensure stdout is restored even if test fails. + defer func() { + os.Stdout = oldStdout + }() + // Read from pipe concurrently to avoid deadlock when output exceeds pipe buffer. var buf bytes.Buffer done := make(chan struct{}) @@ -231,16 +250,18 @@ func TestExecuteTerraform_TerraformPlanWithoutProcessingTemplates(t *testing.T) }() err := ExecuteTerraform(info) + + // Close writer and restore stdout before checking error. + w.Close() + os.Stdout = oldStdout + + // Wait for the reader goroutine to finish. + <-done + if err != nil { t.Fatalf("Failed to execute 'ExecuteTerraform': %v", err) } - // Restore stdout and close writer to signal EOF to the reader goroutine - err = w.Close() - assert.NoError(t, err) - os.Stdout = oldStdout - // Wait for the reader goroutine to finish - <-done output := buf.String() t.Cleanup(func() { @@ -281,20 +302,27 @@ func TestExecuteTerraform_TerraformWorkspace(t *testing.T) { ProcessFunctions: true, } - // Create a pipe to capture stdout to check if terraform is executed correctly + // Create a pipe to capture stdout to check if terraform is executed correctly. 
oldStdout := os.Stdout r, w, _ := os.Pipe() os.Stdout = w + + // Ensure stdout is restored even if test fails. + defer func() { + os.Stdout = oldStdout + }() + err := ExecuteTerraform(info) + + // Close writer and restore stdout before checking error. + w.Close() + os.Stdout = oldStdout + if err != nil { t.Fatalf("Failed to execute 'ExecuteTerraform': %v", err) } - // Restore stdout - err = w.Close() - assert.NoError(t, err) - os.Stdout = oldStdout - // Read the captured output + // Read the captured output. var buf bytes.Buffer _, err = buf.ReadFrom(r) if err != nil { @@ -302,7 +330,7 @@ func TestExecuteTerraform_TerraformWorkspace(t *testing.T) { } output := buf.String() - // Check the output + // Check the output. if !strings.Contains(output, "workspace \"nonprod-component-1\"") { t.Errorf("The output should contain 'nonprod-component-1'") } @@ -356,6 +384,11 @@ func TestExecuteTerraform_TerraformInitWithVarfile(t *testing.T) { r, w, _ := os.Pipe() os.Stderr = w + // Ensure stderr is restored even if test fails. + defer func() { + os.Stderr = oldStderr + }() + log.SetLevel(log.DebugLevel) log.SetOutput(w) @@ -370,20 +403,21 @@ func TestExecuteTerraform_TerraformInitWithVarfile(t *testing.T) { }() err := ExecuteTerraform(info) + + // Close writer and restore stderr before checking error. + w.Close() + os.Stderr = oldStderr + if err != nil { t.Fatalf("Failed to execute 'ExecuteTerraform': %v", err) } - // Restore stderr and close writer to signal EOF to the reader goroutine - err = w.Close() - assert.NoError(t, err) - os.Stderr = oldStderr - // Wait for the reader goroutine to finish + // Wait for the reader goroutine to finish. <-done output := buf.String() - // Check the output + // Check the output. 
expected := "init -reconfigure -var-file nonprod-component-1.terraform.tfvars.json" if !strings.Contains(output, expected) { t.Logf("TestExecuteTerraform_TerraformInitWithVarfile output:\n%s", output) @@ -478,29 +512,35 @@ func TestExecuteTerraform_TerraformPlanWithSkipPlanfile(t *testing.T) { r, w, _ := os.Pipe() os.Stderr = w + // Ensure stderr is restored even if test fails. + defer func() { + os.Stderr = oldStderr + }() + log.SetLevel(log.DebugLevel) - // Create a buffer to capture the output + // Create a buffer to capture the output. var buf bytes.Buffer log.SetOutput(&buf) err := ExecuteTerraform(info) + + // Close writer and restore stderr before checking error. + w.Close() + os.Stderr = oldStderr + if err != nil { t.Fatalf("Failed to execute 'ExecuteTerraform': %v", err) } - // Restore stderr - err = w.Close() - assert.NoError(t, err) - os.Stderr = oldStderr - // Read the captured output + // Read the captured output. _, err = buf.ReadFrom(r) if err != nil { t.Fatalf("Failed to read from pipe: %v", err) } output := buf.String() - // Check the output + // Check the output. expected := "plan -var-file nonprod-cmp-1.terraform.tfvars.json" notExpected := "-out nonprod-cmp-1.planfile" @@ -639,21 +679,27 @@ func TestExecuteTerraform_DeploymentStatus(t *testing.T) { info.AdditionalArgsAndFlags = append(info.AdditionalArgsAndFlags, "--upload-status=false") } - // Create a pipe to capture stdout and stderr + // Create a pipe to capture stdout and stderr. oldStdout := os.Stdout oldStderr := os.Stderr r, w, _ := os.Pipe() os.Stdout = w os.Stderr = w - // Save original logger and set up test logger + // Ensure stdout/stderr are restored even if test fails. + defer func() { + os.Stdout = oldStdout + os.Stderr = oldStderr + }() + + // Save original logger and set up test logger. 
originalLogger := log.Default() logger := log.New() logger.SetOutput(w) log.SetDefault(logger) defer log.SetDefault(originalLogger) - // Create a channel to signal when the pipe is closed + // Create a channel to signal when the pipe is closed. done := make(chan struct{}) go func() { defer close(done) @@ -661,7 +707,7 @@ func TestExecuteTerraform_DeploymentStatus(t *testing.T) { _ = ExecuteTerraform(info) }() - // Read the output + // Read the output. var buf bytes.Buffer _, err := buf.ReadFrom(r) if err != nil { @@ -669,7 +715,7 @@ func TestExecuteTerraform_DeploymentStatus(t *testing.T) { } output := buf.String() - // Restore stdout, stderr, and logger + // Restore stdout, stderr, and logger. os.Stdout = oldStdout os.Stderr = oldStderr log.SetDefault(log.Default()) @@ -869,6 +915,11 @@ components: r, w, _ := os.Pipe() os.Stderr = w + // Ensure stderr is restored even if test fails. + defer func() { + os.Stderr = oldStderr + }() + err = ExecuteTerraform(info) // Restore stderr. diff --git a/internal/exec/terraform_utils_test.go b/internal/exec/terraform_utils_test.go index fedbfb38ce..2f25eeb3a0 100644 --- a/internal/exec/terraform_utils_test.go +++ b/internal/exec/terraform_utils_test.go @@ -110,6 +110,11 @@ func TestExecuteTerraformAffectedWithDependents(t *testing.T) { _, w, _ := os.Pipe() os.Stderr = w + // Ensure stderr is restored even if test fails. + defer func() { + os.Stderr = oldStd + }() + stack := "prod" info := schema.ConfigAndStacksInfo{ @@ -133,24 +138,25 @@ func TestExecuteTerraformAffectedWithDependents(t *testing.T) { } err = ExecuteTerraformAffected(&a, &info) + + // Restore stderr before checking error. + w.Close() + os.Stderr = oldStd + if err != nil { // This test may fail in environments where Git operations or terraform execution // encounter issues. Skip instead of failing to avoid blocking CI. 
t.Skipf("Test failed (environment issue or missing preconditions): %v", err) } - - err = w.Close() - assert.NoError(t, err) - os.Stderr = oldStd } func TestExecuteTerraformQuery(t *testing.T) { - // Check if terraform is installed + // Check if terraform is installed. tests.RequireExecutable(t, "terraform", "running Terraform query tests") os.Unsetenv("ATMOS_BASE_PATH") os.Unsetenv("ATMOS_CLI_CONFIG_PATH") - // Define the work directory and change to it + // Define the work directory and change to it. workDir := "../../tests/fixtures/scenarios/terraform-apply-affected" t.Chdir(workDir) @@ -158,6 +164,11 @@ func TestExecuteTerraformQuery(t *testing.T) { _, w, _ := os.Pipe() os.Stderr = w + // Ensure stderr is restored even if test fails. + defer func() { + os.Stderr = oldStd + }() + stack := "prod" info := schema.ConfigAndStacksInfo{ @@ -169,13 +180,14 @@ func TestExecuteTerraformQuery(t *testing.T) { } err := ExecuteTerraformQuery(&info) + + // Restore stderr before checking error. + w.Close() + os.Stderr = oldStd + if err != nil { t.Fatalf("Failed to execute 'ExecuteTerraformQuery': %v", err) } - - err = w.Close() - assert.NoError(t, err) - os.Stderr = oldStd } // TestWalkTerraformComponents verifies that walkTerraformComponents iterates over all components. diff --git a/internal/exec/workflow_adapters_test.go b/internal/exec/workflow_adapters_test.go index 31ae36adf8..3497f66506 100644 --- a/internal/exec/workflow_adapters_test.go +++ b/internal/exec/workflow_adapters_test.go @@ -340,6 +340,11 @@ func TestWorkflowUIProvider_PrintMessage(t *testing.T) { r, w, _ := os.Pipe() os.Stderr = w + // Ensure stderr is restored even if test fails. + defer func() { + os.Stderr = oldStderr + }() + provider.PrintMessage("Hello %s!", "World") w.Close() @@ -360,6 +365,11 @@ func TestWorkflowUIProvider_PrintMessage_NoArgs(t *testing.T) { r, w, _ := os.Pipe() os.Stderr = w + // Ensure stderr is restored even if test fails. 
+ defer func() { + os.Stderr = oldStderr + }() + provider.PrintMessage("Simple message") w.Close() @@ -500,6 +510,11 @@ func TestWorkflowUIProvider_PrintMessage_MultipleArgs(t *testing.T) { r, w, _ := os.Pipe() os.Stderr = w + // Ensure stderr is restored even if test fails. + defer func() { + os.Stderr = oldStderr + }() + provider.PrintMessage("Step %d of %d: %s", 1, 3, "running") w.Close() diff --git a/internal/exec/yaml_func_aws_test.go b/internal/exec/yaml_func_aws_test.go index cebb4b82f4..b02e2441c6 100644 --- a/internal/exec/yaml_func_aws_test.go +++ b/internal/exec/yaml_func_aws_test.go @@ -334,58 +334,9 @@ func TestAWSCacheWithErrors(t *testing.T) { assert.Equal(t, 1, callCount, "Errors should be cached too") } -func TestGetCacheKey(t *testing.T) { - tests := []struct { - name string - authContext *schema.AWSAuthContext - expected string - }{ - { - name: "nil auth context", - authContext: nil, - expected: "default", - }, - { - name: "with profile credentials and config file", - authContext: &schema.AWSAuthContext{ - Profile: "my-profile", - CredentialsFile: "/home/user/.aws/credentials", - ConfigFile: "/home/user/.aws/config", - }, - expected: "my-profile:/home/user/.aws/credentials:/home/user/.aws/config", - }, - { - name: "empty profile", - authContext: &schema.AWSAuthContext{ - Profile: "", - CredentialsFile: "/some/path", - ConfigFile: "/some/config", - }, - expected: ":/some/path:/some/config", - }, - { - name: "empty config file", - authContext: &schema.AWSAuthContext{ - Profile: "prod", - CredentialsFile: "/creds", - ConfigFile: "", - }, - expected: "prod:/creds:", - }, - } +// Note: TestGetCacheKey moved to pkg/aws/identity/identity_test.go. - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := getCacheKey(tt.authContext) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestAWSGetterInterface(t *testing.T) { - // Ensure defaultAWSGetter implements AWSGetter. 
- var _ AWSGetter = &defaultAWSGetter{} -} +// Note: TestAWSGetterInterface moved to pkg/aws/identity/identity_test.go. func TestProcessTagAwsWithAuthContext(t *testing.T) { // Clear cache before test. @@ -540,35 +491,9 @@ func TestErrorWrapping(t *testing.T) { assert.ErrorIs(t, err, underlyingErr) } -// TestDefaultAWSGetterExists verifies the default getter exists. -func TestDefaultAWSGetterExists(t *testing.T) { - // The awsGetter variable should be initialized. - assert.NotNil(t, awsGetter) - - // It should be a *defaultAWSGetter. - _, ok := awsGetter.(*defaultAWSGetter) - assert.True(t, ok, "Default awsGetter should be *defaultAWSGetter") -} - -// TestSetAWSGetterRestore verifies the restore function works. -func TestSetAWSGetterRestore(t *testing.T) { - originalGetter := awsGetter - - mockGetter := &mockAWSGetter{ - identity: &AWSCallerIdentity{Account: "444444444444"}, - } - - restore := SetAWSGetter(mockGetter) - - // Verify getter was replaced. - assert.Equal(t, mockGetter, awsGetter) +// Note: TestDefaultAWSGetterExists moved to pkg/aws/identity/identity_test.go. - // Restore original. - restore() - - // Verify original was restored. - assert.Equal(t, originalGetter, awsGetter) -} +// Note: TestSetAWSGetterRestore moved to pkg/aws/identity/identity_test.go. // TestErrAwsGetCallerIdentity verifies the error constant exists. func TestErrAwsGetCallerIdentity(t *testing.T) { @@ -712,42 +637,8 @@ func TestCacheConcurrency(t *testing.T) { assert.Equal(t, 1, callCount, "Concurrent access should result in only one getter call") } -// TestCacheKeyWithRegion verifies cache key includes all relevant auth context fields. 
-func TestCacheKeyWithRegion(t *testing.T) { - tests := []struct { - name string - authContext *schema.AWSAuthContext - expected string - }{ - { - name: "full auth context with region", - authContext: &schema.AWSAuthContext{ - Profile: "prod", - CredentialsFile: "/prod/creds", - ConfigFile: "/prod/config", - Region: "us-east-1", // Region is in auth context but not in cache key. - }, - expected: "prod:/prod/creds:/prod/config", - }, - { - name: "same profile different region should have same cache key", - authContext: &schema.AWSAuthContext{ - Profile: "prod", - CredentialsFile: "/prod/creds", - ConfigFile: "/prod/config", - Region: "eu-west-1", // Different region, same cache key. - }, - expected: "prod:/prod/creds:/prod/config", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := getCacheKey(tt.authContext) - assert.Equal(t, tt.expected, result) - }) - } -} +// NOTE: TestCacheKeyWithRegion was removed because getCacheKey is now an internal +// implementation detail in pkg/aws/identity and has its own tests there. // TestAllAWSFunctionsShareCache verifies all four functions share the same cache. func TestAllAWSFunctionsShareCache(t *testing.T) { diff --git a/internal/terraform_backend/terraform_backend_s3.go b/internal/terraform_backend/terraform_backend_s3.go index e1bae8ef48..e2ffa71ad9 100644 --- a/internal/terraform_backend/terraform_backend_s3.go +++ b/internal/terraform_backend/terraform_backend_s3.go @@ -16,7 +16,7 @@ import ( "github.com/aws/smithy-go" errUtils "github.com/cloudposse/atmos/errors" - awsUtils "github.com/cloudposse/atmos/internal/aws_utils" + awsIdentity "github.com/cloudposse/atmos/pkg/aws/identity" log "github.com/cloudposse/atmos/pkg/logger" "github.com/cloudposse/atmos/pkg/perf" "github.com/cloudposse/atmos/pkg/schema" @@ -82,7 +82,7 @@ func getCachedS3Client(backend *map[string]any, authContext *schema.AuthContext) } // The minimum `assume role` duration allowed by AWS is 15 minutes. 
- cfg, err := awsUtils.LoadAWSConfigWithAuth(ctx, region, roleArn, 15*time.Minute, awsAuthContext) + cfg, err := awsIdentity.LoadConfigWithAuth(ctx, region, roleArn, 15*time.Minute, awsAuthContext) if err != nil { return nil, err } diff --git a/pkg/aws/identity/doc.go b/pkg/aws/identity/doc.go new file mode 100644 index 0000000000..174dcb1060 --- /dev/null +++ b/pkg/aws/identity/doc.go @@ -0,0 +1,11 @@ +// Package identity provides AWS caller identity retrieval and caching. +// +// This package consolidates AWS identity-related functionality used by Atmos functions +// (YAML, HCL, etc.) and provides a clean, reusable interface for identity operations. +// +// Key features: +// - AWS config loading with support for auth context +// - Caller identity retrieval via STS GetCallerIdentity +// - Thread-safe caching of identity results per auth context +// - Testable via Getter interface +package identity diff --git a/pkg/aws/identity/identity.go b/pkg/aws/identity/identity.go new file mode 100644 index 0000000000..b4053d8fdd --- /dev/null +++ b/pkg/aws/identity/identity.go @@ -0,0 +1,326 @@ +package identity + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" + + errUtils "github.com/cloudposse/atmos/errors" + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" + "github.com/cloudposse/atmos/pkg/schema" +) + +// CallerIdentity holds the information returned by AWS STS GetCallerIdentity. +type CallerIdentity struct { + Account string + Arn string + UserID string + Region string // The AWS region from the loaded config. +} + +// Getter provides an interface for retrieving AWS caller identity information. +// This interface enables dependency injection and testability. 
+// +//go:generate go run go.uber.org/mock/mockgen@v0.6.0 -source=$GOFILE -destination=mock_identity.go -package=identity +type Getter interface { + // GetCallerIdentity retrieves the AWS caller identity for the current credentials. + // Returns the account ID, ARN, and user ID of the calling identity. + GetCallerIdentity( + ctx context.Context, + atmosConfig *schema.AtmosConfiguration, + authContext *schema.AWSAuthContext, + ) (*CallerIdentity, error) +} + +// defaultGetter is the production implementation that uses real AWS SDK calls. +type defaultGetter struct{} + +// GetCallerIdentity retrieves the AWS caller identity using the STS GetCallerIdentity API. +func (d *defaultGetter) GetCallerIdentity( + ctx context.Context, + atmosConfig *schema.AtmosConfiguration, + authContext *schema.AWSAuthContext, +) (*CallerIdentity, error) { + defer perf.Track(atmosConfig, "identity.Getter.GetCallerIdentity")() + + log.Debug("Getting AWS caller identity") + + // Use the exported function to get caller identity. + result, err := GetCallerIdentity(ctx, "", "", 0, authContext) + if err != nil { + return nil, err + } + + identity := &CallerIdentity{ + Account: result.Account, + Arn: result.Arn, + UserID: result.UserID, + Region: result.Region, + } + + log.Debug("Retrieved AWS caller identity", + "account", identity.Account, + "arn", identity.Arn, + "user_id", identity.UserID, + "region", identity.Region, + ) + + return identity, nil +} + +// getter is the global instance used by functions. +// This allows test code to replace it with a mock. +var getter Getter = &defaultGetter{} + +// SetGetter allows tests to inject a mock Getter. +// Returns a function to restore the original getter. +func SetGetter(g Getter) func() { + defer perf.Track(nil, "identity.SetGetter")() + + original := getter + getter = g + return func() { + getter = original + } +} + +// cachedIdentity holds the cached AWS caller identity. 
+// The cache is per-CLI-invocation (stored in memory) to avoid repeated STS calls. +type cachedIdentity struct { + identity *CallerIdentity + err error +} + +var ( + identityCache map[string]*cachedIdentity + identityCacheMu sync.RWMutex +) + +func init() { + identityCache = make(map[string]*cachedIdentity) +} + +// getCacheKey generates a cache key based on the auth context. +// Different auth contexts (different credentials) get different cache entries. +// Includes Profile, CredentialsFile, and ConfigFile since all three affect AWS config loading. +func getCacheKey(authContext *schema.AWSAuthContext) string { + defer perf.Track(nil, "identity.getCacheKey")() + + if authContext == nil { + return "default" + } + return fmt.Sprintf("%s:%s:%s", authContext.Profile, authContext.CredentialsFile, authContext.ConfigFile) +} + +// GetCallerIdentityCached retrieves the AWS caller identity with caching. +// Results are cached per auth context to avoid repeated STS calls within the same CLI invocation. +func GetCallerIdentityCached( + ctx context.Context, + atmosConfig *schema.AtmosConfiguration, + authContext *schema.AWSAuthContext, +) (*CallerIdentity, error) { + defer perf.Track(atmosConfig, "identity.GetCallerIdentityCached")() + + cacheKey := getCacheKey(authContext) + + // Check cache first (read lock). + identityCacheMu.RLock() + if cached, ok := identityCache[cacheKey]; ok { + identityCacheMu.RUnlock() + log.Debug("Using cached AWS caller identity", "cache_key", cacheKey) + return cached.identity, cached.err + } + identityCacheMu.RUnlock() + + // Cache miss - acquire write lock and fetch. + identityCacheMu.Lock() + defer identityCacheMu.Unlock() + + // Double-check after acquiring write lock. + if cached, ok := identityCache[cacheKey]; ok { + log.Debug("Using cached AWS caller identity (double-check)", "cache_key", cacheKey) + return cached.identity, cached.err + } + + // Fetch from AWS. 
+ identity, err := getter.GetCallerIdentity(ctx, atmosConfig, authContext) + + // Cache the result (including errors to avoid repeated failed calls). + identityCache[cacheKey] = &cachedIdentity{ + identity: identity, + err: err, + } + + return identity, err +} + +// ClearIdentityCache clears the AWS identity cache. +// This is useful in tests or when credentials change during execution. +func ClearIdentityCache() { + defer perf.Track(nil, "identity.ClearIdentityCache")() + + identityCacheMu.Lock() + defer identityCacheMu.Unlock() + identityCache = make(map[string]*cachedIdentity) +} + +// GetCallerIdentity retrieves AWS caller identity using STS GetCallerIdentity API. +// Returns account ID, ARN, user ID, and region. +// This function keeps AWS SDK STS imports contained within this package. +// For caching, use GetCallerIdentityCached instead. +func GetCallerIdentity( + ctx context.Context, + region string, + roleArn string, + assumeRoleDuration time.Duration, + authContext *schema.AWSAuthContext, +) (*CallerIdentity, error) { + defer perf.Track(nil, "identity.GetCallerIdentity")() + + // Load AWS config. + cfg, err := LoadConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, authContext) + if err != nil { + return nil, err + } + + // Create STS client and get caller identity. + stsClient := sts.NewFromConfig(cfg) + output, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return nil, fmt.Errorf("%w: %w", errUtils.ErrAwsGetCallerIdentity, err) + } + + result := &CallerIdentity{ + Region: cfg.Region, + } + + // Extract values from pointers. + if output.Account != nil { + result.Account = *output.Account + } + if output.Arn != nil { + result.Arn = *output.Arn + } + if output.UserId != nil { + result.UserID = *output.UserId + } + + return result, nil +} + +// LoadConfigWithAuth loads AWS config, preferring auth context if available. 
+// +// When authContext is provided, it uses the Atmos-managed credentials files and profile. +// Otherwise, it falls back to standard AWS SDK credential resolution. +// +// Standard AWS SDK credential resolution order: +// +// Environment variables: +// AWS_ACCESS_KEY_ID +// AWS_SECRET_ACCESS_KEY +// AWS_SESSION_TOKEN (optional, for temporary credentials) +// +// Shared credentials file: +// Typically at ~/.aws/credentials +// Controlled by: +// AWS_PROFILE (defaults to default) +// AWS_SHARED_CREDENTIALS_FILE +// +// Shared config file: +// Typically at ~/.aws/config +// Also supports named profiles and region settings +// +// Amazon EC2 Instance Metadata Service (IMDS): +// If running on EC2 or ECS +// Uses IAM roles attached to the instance/task +// +// Web Identity Token credentials: +// When AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are set (e.g., in EKS) +// +// SSO credentials (if configured) +// +// Custom credential sources: +// Provided programmatically using config.WithCredentialsProvider(...) +func LoadConfigWithAuth( + ctx context.Context, + region string, + roleArn string, + assumeRoleDuration time.Duration, + authContext *schema.AWSAuthContext, +) (aws.Config, error) { + defer perf.Track(nil, "identity.LoadConfigWithAuth")() + + var cfgOpts []func(*config.LoadOptions) error + + // If auth context is provided, use Atmos-managed credentials. + if authContext != nil { + log.Debug("Using Atmos auth context for AWS SDK", + "profile", authContext.Profile, + "credentials", authContext.CredentialsFile, + "config", authContext.ConfigFile, + ) + + // Set custom credential and config file paths. + // This overrides the default ~/.aws/credentials and ~/.aws/config. + cfgOpts = append(cfgOpts, + config.WithSharedCredentialsFiles([]string{authContext.CredentialsFile}), + config.WithSharedConfigFiles([]string{authContext.ConfigFile}), + config.WithSharedConfigProfile(authContext.Profile), + ) + + // Use region from auth context if not explicitly provided. 
+ if region == "" && authContext.Region != "" { + region = authContext.Region + } + } else { + log.Debug("Using standard AWS SDK credential resolution (no auth context provided)") + } + + // Set region if provided. + if region != "" { + log.Debug("Using explicit region", "region", region) + cfgOpts = append(cfgOpts, config.WithRegion(region)) + } + + // Load base config. + log.Debug("Loading AWS SDK config", "num_options", len(cfgOpts)) + baseCfg, err := config.LoadDefaultConfig(ctx, cfgOpts...) + if err != nil { + log.Debug("Failed to load AWS config", "error", err) + return aws.Config{}, fmt.Errorf("%w: %w", errUtils.ErrLoadAWSConfig, err) + } + log.Debug("Successfully loaded AWS SDK config", "region", baseCfg.Region) + + // Conditionally assume role if specified. + if roleArn != "" { + log.Debug("Assuming role", "ARN", roleArn) + stsClient := sts.NewFromConfig(baseCfg) + + creds := stscreds.NewAssumeRoleProvider(stsClient, roleArn, func(o *stscreds.AssumeRoleOptions) { + o.Duration = assumeRoleDuration + }) + + cfgOpts = append(cfgOpts, config.WithCredentialsProvider(aws.NewCredentialsCache(creds))) + + // Reload full config with assumed role credentials. + return config.LoadDefaultConfig(ctx, cfgOpts...) + } + + return baseCfg, nil +} + +// LoadConfig loads AWS config using standard AWS SDK credential resolution. +// This is a wrapper around LoadConfigWithAuth for convenience. +// For code that needs Atmos auth support, use LoadConfigWithAuth instead. 
+func LoadConfig(ctx context.Context, region string, roleArn string, assumeRoleDuration time.Duration) (aws.Config, error) { + defer perf.Track(nil, "identity.LoadConfig")() + + return LoadConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, nil) +} diff --git a/pkg/aws/identity/identity_test.go b/pkg/aws/identity/identity_test.go new file mode 100644 index 0000000000..71129f451a --- /dev/null +++ b/pkg/aws/identity/identity_test.go @@ -0,0 +1,264 @@ +package identity + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudposse/atmos/pkg/schema" +) + +// mockGetter is a test implementation of Getter. +type mockGetter struct { + identity *CallerIdentity + err error + calls int +} + +func (m *mockGetter) GetCallerIdentity( + ctx context.Context, + atmosConfig *schema.AtmosConfiguration, + authContext *schema.AWSAuthContext, +) (*CallerIdentity, error) { + m.calls++ + return m.identity, m.err +} + +func TestGetCacheKey(t *testing.T) { + tests := []struct { + name string + authContext *schema.AWSAuthContext + expected string + }{ + { + name: "nil auth context", + authContext: nil, + expected: "default", + }, + { + name: "empty auth context", + authContext: &schema.AWSAuthContext{}, + expected: "::", + }, + { + name: "full auth context", + authContext: &schema.AWSAuthContext{ + Profile: "prod", + CredentialsFile: "/home/user/.aws/credentials", + ConfigFile: "/home/user/.aws/config", + }, + expected: "prod:/home/user/.aws/credentials:/home/user/.aws/config", + }, + { + name: "partial auth context", + authContext: &schema.AWSAuthContext{ + Profile: "dev", + }, + expected: "dev::", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getCacheKey(tt.authContext) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGetCallerIdentityCached(t *testing.T) { + // Clear cache before test. + ClearIdentityCache() + + // Set up mock getter. 
+ mock := &mockGetter{ + identity: &CallerIdentity{ + Account: "123456789012", + Arn: "arn:aws:iam::123456789012:user/test", + UserID: "AIDAEXAMPLE", + Region: "us-west-2", + }, + } + + // Replace the global getter with our mock. + restore := SetGetter(mock) + defer restore() + + ctx := context.Background() + + // First call should hit the mock. + identity, err := GetCallerIdentityCached(ctx, nil, nil) + require.NoError(t, err) + assert.Equal(t, "123456789012", identity.Account) + assert.Equal(t, 1, mock.calls) + + // Second call should use cache. + identity2, err := GetCallerIdentityCached(ctx, nil, nil) + require.NoError(t, err) + assert.Equal(t, identity, identity2) + assert.Equal(t, 1, mock.calls, "should not call mock again due to cache") + + // Call with different auth context should hit the mock again. + authContext := &schema.AWSAuthContext{Profile: "other"} + identity3, err := GetCallerIdentityCached(ctx, nil, authContext) + require.NoError(t, err) + assert.Equal(t, "123456789012", identity3.Account) + assert.Equal(t, 2, mock.calls, "should call mock for different auth context") +} + +func TestClearIdentityCache(t *testing.T) { + // Clear cache before test to ensure isolation. + ClearIdentityCache() + + // Set up mock getter. + mock := &mockGetter{ + identity: &CallerIdentity{ + Account: "123456789012", + }, + } + + restore := SetGetter(mock) + defer restore() + + ctx := context.Background() + + // First call. + _, err := GetCallerIdentityCached(ctx, nil, nil) + require.NoError(t, err) + assert.Equal(t, 1, mock.calls) + + // Second call should use cache. + _, err = GetCallerIdentityCached(ctx, nil, nil) + require.NoError(t, err) + assert.Equal(t, 1, mock.calls) + + // Clear cache. + ClearIdentityCache() + + // Third call should hit the mock again. 
+ _, err = GetCallerIdentityCached(ctx, nil, nil) + require.NoError(t, err) + assert.Equal(t, 2, mock.calls, "should call mock after cache clear") +} + +func TestSetGetter(t *testing.T) { + originalGetter := getter + + mock := &mockGetter{} + restore := SetGetter(mock) + + assert.Same(t, mock, getter) + + restore() + + assert.Same(t, originalGetter, getter) +} + +func TestCallerIdentity(t *testing.T) { + identity := &CallerIdentity{ + Account: "123456789012", + Arn: "arn:aws:iam::123456789012:user/test", + UserID: "AIDAEXAMPLE", + Region: "us-east-1", + } + + assert.Equal(t, "123456789012", identity.Account) + assert.Equal(t, "arn:aws:iam::123456789012:user/test", identity.Arn) + assert.Equal(t, "AIDAEXAMPLE", identity.UserID) + assert.Equal(t, "us-east-1", identity.Region) +} + +func TestGetCallerIdentityCached_Error(t *testing.T) { + // Clear cache before test. + ClearIdentityCache() + + // Set up mock getter that returns an error. + expectedErr := assert.AnError + mock := &mockGetter{ + identity: nil, + err: expectedErr, + } + + restore := SetGetter(mock) + defer restore() + + ctx := context.Background() + + // First call should hit the mock and return error. + _, err := GetCallerIdentityCached(ctx, nil, nil) + require.Error(t, err) + assert.Equal(t, expectedErr, err) + assert.Equal(t, 1, mock.calls) + + // Second call should return cached error without calling mock again. + _, err = GetCallerIdentityCached(ctx, nil, nil) + require.Error(t, err) + assert.Equal(t, expectedErr, err) + assert.Equal(t, 1, mock.calls, "should use cached error") +} + +func TestGetCallerIdentityCached_ConcurrentAccess(t *testing.T) { + // Clear cache before test. + ClearIdentityCache() + + // Set up mock getter. 
+ mock := &mockGetter{ + identity: &CallerIdentity{ + Account: "123456789012", + Arn: "arn:aws:iam::123456789012:user/test", + UserID: "AIDAEXAMPLE", + Region: "us-west-2", + }, + } + + restore := SetGetter(mock) + defer restore() + + ctx := context.Background() + numGoroutines := 10 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + results := make([]*CallerIdentity, numGoroutines) + errors := make([]error, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx], errors[idx] = GetCallerIdentityCached(ctx, nil, nil) + }(i) + } + + wg.Wait() + + // All results should be successful. + for i := 0; i < numGoroutines; i++ { + require.NoError(t, errors[i]) + assert.Equal(t, "123456789012", results[i].Account) + } + + // Mock should have been called only once due to caching. + assert.Equal(t, 1, mock.calls) +} + +func TestGetCacheKey_AuthContextWithRegion(t *testing.T) { + authContext := &schema.AWSAuthContext{ + Profile: "dev", + CredentialsFile: "/path/to/creds", + ConfigFile: "/path/to/config", + Region: "eu-west-1", // Region is not included in cache key. + } + + // Region should not affect the cache key since it's handled separately. + result := getCacheKey(authContext) + assert.Equal(t, "dev:/path/to/creds:/path/to/config", result) +} + +func TestDefaultGetter_GetCallerIdentity(t *testing.T) { + // This test would require actual AWS credentials, so we skip in unit tests. + // It's here for documentation and can be run manually with credentials. + t.Skip("Requires actual AWS credentials - run manually for integration testing") +} diff --git a/pkg/aws/identity/mock_identity.go b/pkg/aws/identity/mock_identity.go new file mode 100644 index 0000000000..f367eedafb --- /dev/null +++ b/pkg/aws/identity/mock_identity.go @@ -0,0 +1,57 @@ +// Code generated by MockGen. DO NOT EDIT. 
+// Source: identity.go +// +// Generated by this command: +// +// mockgen -source=identity.go -destination=mock_identity.go -package=identity +// + +// Package identity is a generated GoMock package. +package identity + +import ( + context "context" + reflect "reflect" + + schema "github.com/cloudposse/atmos/pkg/schema" + gomock "go.uber.org/mock/gomock" +) + +// MockGetter is a mock of Getter interface. +type MockGetter struct { + ctrl *gomock.Controller + recorder *MockGetterMockRecorder + isgomock struct{} +} + +// MockGetterMockRecorder is the mock recorder for MockGetter. +type MockGetterMockRecorder struct { + mock *MockGetter +} + +// NewMockGetter creates a new mock instance. +func NewMockGetter(ctrl *gomock.Controller) *MockGetter { + mock := &MockGetter{ctrl: ctrl} + mock.recorder = &MockGetterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockGetter) EXPECT() *MockGetterMockRecorder { + return m.recorder +} + +// GetCallerIdentity mocks base method. +func (m *MockGetter) GetCallerIdentity(ctx context.Context, atmosConfig *schema.AtmosConfiguration, authContext *schema.AWSAuthContext) (*CallerIdentity, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCallerIdentity", ctx, atmosConfig, authContext) + ret0, _ := ret[0].(*CallerIdentity) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCallerIdentity indicates an expected call of GetCallerIdentity. 
+func (mr *MockGetterMockRecorder) GetCallerIdentity(ctx, atmosConfig, authContext any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCallerIdentity", reflect.TypeOf((*MockGetter)(nil).GetCallerIdentity), ctx, atmosConfig, authContext) +} diff --git a/pkg/config/load.go b/pkg/config/load.go index 618d746bb8..9e69b7745b 100644 --- a/pkg/config/load.go +++ b/pkg/config/load.go @@ -681,13 +681,12 @@ func readEnvAmosConfigPath(v *viper.Viper) error { log.Trace("Checking for atmos.yaml from ATMOS_CLI_CONFIG_PATH", "path", atmosPath) err := mergeConfig(v, atmosPath, CliConfigFileName, true) if err != nil { - switch err.(type) { - case viper.ConfigFileNotFoundError: + var configFileNotFoundError viper.ConfigFileNotFoundError + if errors.As(err, &configFileNotFoundError) { log.Debug("config not found ENV var "+AtmosCliConfigPathEnvVar, "file", atmosPath) return nil - default: - return err } + return err } log.Trace("Found config ENV", AtmosCliConfigPathEnvVar, atmosPath) diff --git a/pkg/function/aws.go b/pkg/function/aws.go new file mode 100644 index 0000000000..f8abbb78c9 --- /dev/null +++ b/pkg/function/aws.go @@ -0,0 +1,182 @@ +package function + +import ( + "context" + + awsIdentity "github.com/cloudposse/atmos/pkg/aws/identity" + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" + "github.com/cloudposse/atmos/pkg/schema" +) + +// errMsgAWSIdentityFailed is a constant for the AWS identity error message. +const errMsgAWSIdentityFailed = "Failed to get AWS caller identity" + +// getAWSIdentity is a helper that retrieves the AWS caller identity from the execution context. +func getAWSIdentity(ctx context.Context, execCtx *ExecutionContext) (*awsIdentity.CallerIdentity, error) { + defer perf.Track(nil, "function.getAWSIdentity")() + + // Get auth context from stack info if available. 
+ var authContext *schema.AWSAuthContext + if execCtx != nil && execCtx.StackInfo != nil && + execCtx.StackInfo.AuthContext != nil && execCtx.StackInfo.AuthContext.AWS != nil { + authContext = execCtx.StackInfo.AuthContext.AWS + } + + // Get AtmosConfig from execution context. + var atmosConfig *schema.AtmosConfiguration + if execCtx != nil { + atmosConfig = execCtx.AtmosConfig + } + + // Get the AWS caller identity (cached). + return awsIdentity.GetCallerIdentityCached(ctx, atmosConfig, authContext) +} + +// AwsAccountIDFunction implements the aws.account_id function. +type AwsAccountIDFunction struct { + BaseFunction +} + +// NewAwsAccountIDFunction creates a new aws.account_id function handler. +func NewAwsAccountIDFunction() *AwsAccountIDFunction { + defer perf.Track(nil, "function.NewAwsAccountIDFunction")() + + return &AwsAccountIDFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagAwsAccountID, + FunctionAliases: nil, + FunctionPhase: PostMerge, + }, + } +} + +// Execute processes the aws.account_id function. +// Usage: +// +// !aws.account_id - Returns the AWS account ID of the current caller identity +func (f *AwsAccountIDFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.AwsAccountIDFunction.Execute")() + + log.Debug("Executing aws.account_id function") + + identity, err := getAWSIdentity(ctx, execCtx) + if err != nil { + log.Error(errMsgAWSIdentityFailed, "error", err) + return nil, err + } + + log.Debug("Resolved !aws.account_id", "account_id", identity.Account) + return identity.Account, nil +} + +// AwsCallerIdentityArnFunction implements the aws.caller_identity_arn function. +type AwsCallerIdentityArnFunction struct { + BaseFunction +} + +// NewAwsCallerIdentityArnFunction creates a new aws.caller_identity_arn function handler. 
+func NewAwsCallerIdentityArnFunction() *AwsCallerIdentityArnFunction { + defer perf.Track(nil, "function.NewAwsCallerIdentityArnFunction")() + + return &AwsCallerIdentityArnFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagAwsCallerIdentityArn, + FunctionAliases: nil, + FunctionPhase: PostMerge, + }, + } +} + +// Execute processes the aws.caller_identity_arn function. +// Usage: +// +// !aws.caller_identity_arn - Returns the ARN of the current caller identity +func (f *AwsCallerIdentityArnFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.AwsCallerIdentityArnFunction.Execute")() + + log.Debug("Executing aws.caller_identity_arn function") + + identity, err := getAWSIdentity(ctx, execCtx) + if err != nil { + log.Error(errMsgAWSIdentityFailed, "error", err) + return nil, err + } + + log.Debug("Resolved !aws.caller_identity_arn", "arn", identity.Arn) + return identity.Arn, nil +} + +// AwsCallerIdentityUserIDFunction implements the aws.caller_identity_user_id function. +type AwsCallerIdentityUserIDFunction struct { + BaseFunction +} + +// NewAwsCallerIdentityUserIDFunction creates a new aws.caller_identity_user_id function handler. +func NewAwsCallerIdentityUserIDFunction() *AwsCallerIdentityUserIDFunction { + defer perf.Track(nil, "function.NewAwsCallerIdentityUserIDFunction")() + + return &AwsCallerIdentityUserIDFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagAwsCallerIdentityUserID, + FunctionAliases: nil, + FunctionPhase: PostMerge, + }, + } +} + +// Execute processes the aws.caller_identity_user_id function. 
+// Usage: +// +// !aws.caller_identity_user_id - Returns the user ID of the current caller identity +func (f *AwsCallerIdentityUserIDFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.AwsCallerIdentityUserIDFunction.Execute")() + + log.Debug("Executing aws.caller_identity_user_id function") + + identity, err := getAWSIdentity(ctx, execCtx) + if err != nil { + log.Error(errMsgAWSIdentityFailed, "error", err) + return nil, err + } + + log.Debug("Resolved !aws.caller_identity_user_id", "user_id", identity.UserID) + return identity.UserID, nil +} + +// AwsRegionFunction implements the aws.region function. +type AwsRegionFunction struct { + BaseFunction +} + +// NewAwsRegionFunction creates a new aws.region function handler. +func NewAwsRegionFunction() *AwsRegionFunction { + defer perf.Track(nil, "function.NewAwsRegionFunction")() + + return &AwsRegionFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagAwsRegion, + FunctionAliases: nil, + FunctionPhase: PostMerge, + }, + } +} + +// Execute processes the aws.region function. 
+// Usage: +// +// !aws.region - Returns the AWS region from the current configuration +func (f *AwsRegionFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.AwsRegionFunction.Execute")() + + log.Debug("Executing aws.region function") + + identity, err := getAWSIdentity(ctx, execCtx) + if err != nil { + log.Error(errMsgAWSIdentityFailed, "error", err) + return nil, err + } + + log.Debug("Resolved !aws.region", "region", identity.Region) + return identity.Region, nil +} diff --git a/pkg/function/aws_test.go b/pkg/function/aws_test.go new file mode 100644 index 0000000000..04c1943ec0 --- /dev/null +++ b/pkg/function/aws_test.go @@ -0,0 +1,311 @@ +package function + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + awsIdentity "github.com/cloudposse/atmos/pkg/aws/identity" + "github.com/cloudposse/atmos/pkg/schema" +) + +// mockAWSGetter is a test implementation of identity.Getter. 
+type mockAWSGetter struct { + identity *awsIdentity.CallerIdentity + err error + calls int +} + +func (m *mockAWSGetter) GetCallerIdentity( + ctx context.Context, + atmosConfig *schema.AtmosConfiguration, + authContext *schema.AWSAuthContext, +) (*awsIdentity.CallerIdentity, error) { + m.calls++ + return m.identity, m.err +} + +func TestNewAwsAccountIDFunction(t *testing.T) { + fn := NewAwsAccountIDFunction() + require.NotNil(t, fn) + assert.Equal(t, TagAwsAccountID, fn.Name()) + assert.Equal(t, PostMerge, fn.Phase()) + assert.Nil(t, fn.Aliases()) +} + +func TestNewAwsCallerIdentityArnFunction(t *testing.T) { + fn := NewAwsCallerIdentityArnFunction() + require.NotNil(t, fn) + assert.Equal(t, TagAwsCallerIdentityArn, fn.Name()) + assert.Equal(t, PostMerge, fn.Phase()) + assert.Nil(t, fn.Aliases()) +} + +func TestNewAwsCallerIdentityUserIDFunction(t *testing.T) { + fn := NewAwsCallerIdentityUserIDFunction() + require.NotNil(t, fn) + assert.Equal(t, TagAwsCallerIdentityUserID, fn.Name()) + assert.Equal(t, PostMerge, fn.Phase()) + assert.Nil(t, fn.Aliases()) +} + +func TestNewAwsRegionFunction(t *testing.T) { + fn := NewAwsRegionFunction() + require.NotNil(t, fn) + assert.Equal(t, TagAwsRegion, fn.Name()) + assert.Equal(t, PostMerge, fn.Phase()) + assert.Nil(t, fn.Aliases()) +} + +func TestAwsAccountIDFunction_Execute(t *testing.T) { + // Clear identity cache before test. + awsIdentity.ClearIdentityCache() + + // Set up mock. + mock := &mockAWSGetter{ + identity: &awsIdentity.CallerIdentity{ + Account: "123456789012", + Arn: "arn:aws:iam::123456789012:user/test", + UserID: "AIDAEXAMPLE", + Region: "us-west-2", + }, + } + + restore := awsIdentity.SetGetter(mock) + defer restore() + + fn := NewAwsAccountIDFunction() + result, err := fn.Execute(context.Background(), "", nil) + + require.NoError(t, err) + assert.Equal(t, "123456789012", result) +} + +func TestAwsAccountIDFunction_Execute_Error(t *testing.T) { + // Clear identity cache before test. 
+ awsIdentity.ClearIdentityCache() + + // Set up mock that returns an error. + expectedErr := errors.New("AWS credentials not configured") + mock := &mockAWSGetter{ + err: expectedErr, + } + + restore := awsIdentity.SetGetter(mock) + defer restore() + + fn := NewAwsAccountIDFunction() + _, err := fn.Execute(context.Background(), "", nil) + + require.Error(t, err) + assert.Equal(t, expectedErr, err) +} + +func TestAwsCallerIdentityArnFunction_Execute(t *testing.T) { + // Clear identity cache before test. + awsIdentity.ClearIdentityCache() + + // Set up mock. + mock := &mockAWSGetter{ + identity: &awsIdentity.CallerIdentity{ + Account: "123456789012", + Arn: "arn:aws:iam::123456789012:user/test", + UserID: "AIDAEXAMPLE", + Region: "us-west-2", + }, + } + + restore := awsIdentity.SetGetter(mock) + defer restore() + + fn := NewAwsCallerIdentityArnFunction() + result, err := fn.Execute(context.Background(), "", nil) + + require.NoError(t, err) + assert.Equal(t, "arn:aws:iam::123456789012:user/test", result) +} + +func TestAwsCallerIdentityArnFunction_Execute_Error(t *testing.T) { + // Clear identity cache before test. + awsIdentity.ClearIdentityCache() + + expectedErr := errors.New("STS error") + mock := &mockAWSGetter{err: expectedErr} + + restore := awsIdentity.SetGetter(mock) + defer restore() + + fn := NewAwsCallerIdentityArnFunction() + _, err := fn.Execute(context.Background(), "", nil) + + require.Error(t, err) +} + +func TestAwsCallerIdentityUserIDFunction_Execute(t *testing.T) { + // Clear identity cache before test. 
+ awsIdentity.ClearIdentityCache() + + mock := &mockAWSGetter{ + identity: &awsIdentity.CallerIdentity{ + Account: "123456789012", + Arn: "arn:aws:iam::123456789012:user/test", + UserID: "AIDAEXAMPLE", + Region: "us-west-2", + }, + } + + restore := awsIdentity.SetGetter(mock) + defer restore() + + fn := NewAwsCallerIdentityUserIDFunction() + result, err := fn.Execute(context.Background(), "", nil) + + require.NoError(t, err) + assert.Equal(t, "AIDAEXAMPLE", result) +} + +func TestAwsCallerIdentityUserIDFunction_Execute_Error(t *testing.T) { + // Clear identity cache before test. + awsIdentity.ClearIdentityCache() + + expectedErr := errors.New("STS error") + mock := &mockAWSGetter{err: expectedErr} + + restore := awsIdentity.SetGetter(mock) + defer restore() + + fn := NewAwsCallerIdentityUserIDFunction() + _, err := fn.Execute(context.Background(), "", nil) + + require.Error(t, err) +} + +func TestAwsRegionFunction_Execute(t *testing.T) { + // Clear identity cache before test. + awsIdentity.ClearIdentityCache() + + mock := &mockAWSGetter{ + identity: &awsIdentity.CallerIdentity{ + Account: "123456789012", + Arn: "arn:aws:iam::123456789012:user/test", + UserID: "AIDAEXAMPLE", + Region: "us-west-2", + }, + } + + restore := awsIdentity.SetGetter(mock) + defer restore() + + fn := NewAwsRegionFunction() + result, err := fn.Execute(context.Background(), "", nil) + + require.NoError(t, err) + assert.Equal(t, "us-west-2", result) +} + +func TestAwsRegionFunction_Execute_Error(t *testing.T) { + // Clear identity cache before test. + awsIdentity.ClearIdentityCache() + + expectedErr := errors.New("Region error") + mock := &mockAWSGetter{err: expectedErr} + + restore := awsIdentity.SetGetter(mock) + defer restore() + + fn := NewAwsRegionFunction() + _, err := fn.Execute(context.Background(), "", nil) + + require.Error(t, err) +} + +func TestGetAWSIdentity_WithStackInfo(t *testing.T) { + // Clear identity cache before test. 
+ awsIdentity.ClearIdentityCache() + + mock := &mockAWSGetter{ + identity: &awsIdentity.CallerIdentity{ + Account: "987654321098", + Arn: "arn:aws:iam::987654321098:role/admin", + UserID: "AROA12345", + Region: "eu-west-1", + }, + } + + restore := awsIdentity.SetGetter(mock) + defer restore() + + // Create execution context with stack info and auth context. + execCtx := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + StackInfo: &schema.ConfigAndStacksInfo{ + AuthContext: &schema.AuthContext{ + AWS: &schema.AWSAuthContext{ + Profile: "custom-profile", + Region: "eu-west-1", + }, + }, + }, + } + + fn := NewAwsAccountIDFunction() + result, err := fn.Execute(context.Background(), "", execCtx) + + require.NoError(t, err) + assert.Equal(t, "987654321098", result) +} + +func TestGetAWSIdentity_NilExecutionContext(t *testing.T) { + // Clear identity cache before test. + awsIdentity.ClearIdentityCache() + + mock := &mockAWSGetter{ + identity: &awsIdentity.CallerIdentity{ + Account: "123456789012", + Arn: "arn:aws:iam::123456789012:user/default", + UserID: "DEFAULT", + Region: "us-east-1", + }, + } + + restore := awsIdentity.SetGetter(mock) + defer restore() + + fn := NewAwsAccountIDFunction() + result, err := fn.Execute(context.Background(), "", nil) + + require.NoError(t, err) + assert.Equal(t, "123456789012", result) +} + +func TestGetAWSIdentity_PartialStackInfo(t *testing.T) { + // Clear identity cache before test. + awsIdentity.ClearIdentityCache() + + mock := &mockAWSGetter{ + identity: &awsIdentity.CallerIdentity{ + Account: "111222333444", + Arn: "arn:aws:iam::111222333444:user/test", + UserID: "TEST", + Region: "ap-southeast-1", + }, + } + + restore := awsIdentity.SetGetter(mock) + defer restore() + + // Execution context with StackInfo but nil AuthContext. 
+ execCtx := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + StackInfo: &schema.ConfigAndStacksInfo{}, + } + + fn := NewAwsRegionFunction() + result, err := fn.Execute(context.Background(), "", execCtx) + + require.NoError(t, err) + assert.Equal(t, "ap-southeast-1", result) +} diff --git a/pkg/function/context.go b/pkg/function/context.go new file mode 100644 index 0000000000..b89ad3f6a8 --- /dev/null +++ b/pkg/function/context.go @@ -0,0 +1,60 @@ +package function + +import ( + "github.com/cloudposse/atmos/pkg/perf" + "github.com/cloudposse/atmos/pkg/schema" +) + +// ExecutionContext provides the runtime context for function execution. +// It contains all the information a function might need to resolve values. +type ExecutionContext struct { + // AtmosConfig is the current Atmos configuration. + AtmosConfig *schema.AtmosConfiguration + + // Stack is the current stack name being processed. + Stack string + + // Component is the current component name being processed. + Component string + + // BaseDir is the base directory for relative path resolution. + BaseDir string + + // File is the path to the file being processed. + File string + + // StackInfo contains additional stack and component information. + StackInfo *schema.ConfigAndStacksInfo +} + +// NewExecutionContext creates a new ExecutionContext with the given parameters. +func NewExecutionContext(atmosConfig *schema.AtmosConfiguration, stack, component string) *ExecutionContext { + defer perf.Track(atmosConfig, "function.NewExecutionContext")() + + return &ExecutionContext{ + AtmosConfig: atmosConfig, + Stack: stack, + Component: component, + } +} + +// WithFile returns a copy of the context with the file path set. +func (ctx *ExecutionContext) WithFile(file string) *ExecutionContext { + newCtx := *ctx + newCtx.File = file + return &newCtx +} + +// WithBaseDir returns a copy of the context with the base directory set. 
+func (ctx *ExecutionContext) WithBaseDir(baseDir string) *ExecutionContext { + newCtx := *ctx + newCtx.BaseDir = baseDir + return &newCtx +} + +// WithStackInfo returns a copy of the context with stack info set. +func (ctx *ExecutionContext) WithStackInfo(stackInfo *schema.ConfigAndStacksInfo) *ExecutionContext { + newCtx := *ctx + newCtx.StackInfo = stackInfo + return &newCtx +} diff --git a/pkg/function/context_test.go b/pkg/function/context_test.go new file mode 100644 index 0000000000..8195182d11 --- /dev/null +++ b/pkg/function/context_test.go @@ -0,0 +1,214 @@ +package function + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudposse/atmos/pkg/schema" +) + +func TestNewExecutionContext(t *testing.T) { + atmosConfig := &schema.AtmosConfiguration{} + stack := "tenant1-ue2-dev" + component := "vpc" + + ctx := NewExecutionContext(atmosConfig, stack, component) + + require.NotNil(t, ctx) + assert.Same(t, atmosConfig, ctx.AtmosConfig) + assert.Equal(t, stack, ctx.Stack) + assert.Equal(t, component, ctx.Component) + assert.Empty(t, ctx.BaseDir) + assert.Empty(t, ctx.File) + assert.Nil(t, ctx.StackInfo) +} + +func TestNewExecutionContext_NilConfig(t *testing.T) { + ctx := NewExecutionContext(nil, "stack", "component") + + require.NotNil(t, ctx) + assert.Nil(t, ctx.AtmosConfig) + assert.Equal(t, "stack", ctx.Stack) + assert.Equal(t, "component", ctx.Component) +} + +func TestNewExecutionContext_EmptyValues(t *testing.T) { + ctx := NewExecutionContext(nil, "", "") + + require.NotNil(t, ctx) + assert.Empty(t, ctx.Stack) + assert.Empty(t, ctx.Component) +} + +func TestExecutionContext_WithFile(t *testing.T) { + original := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + Stack: "stack1", + Component: "comp1", + BaseDir: "/base", + } + + newCtx := original.WithFile("/path/to/file.yaml") + + // New context should have the file set. 
+ assert.Equal(t, "/path/to/file.yaml", newCtx.File) + + // Original should be unchanged. + assert.Empty(t, original.File) + + // Other fields should be copied. + assert.Same(t, original.AtmosConfig, newCtx.AtmosConfig) + assert.Equal(t, original.Stack, newCtx.Stack) + assert.Equal(t, original.Component, newCtx.Component) + assert.Equal(t, original.BaseDir, newCtx.BaseDir) +} + +func TestExecutionContext_WithFile_Chaining(t *testing.T) { + ctx := NewExecutionContext(&schema.AtmosConfiguration{}, "stack", "component") + + result := ctx.WithFile("/file1.yaml").WithFile("/file2.yaml") + + assert.Equal(t, "/file2.yaml", result.File) +} + +func TestExecutionContext_WithBaseDir(t *testing.T) { + original := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + Stack: "stack1", + Component: "comp1", + File: "/some/file.yaml", + } + + newCtx := original.WithBaseDir("/new/base/dir") + + // New context should have the base dir set. + assert.Equal(t, "/new/base/dir", newCtx.BaseDir) + + // Original should be unchanged. + assert.Empty(t, original.BaseDir) + + // Other fields should be copied. + assert.Same(t, original.AtmosConfig, newCtx.AtmosConfig) + assert.Equal(t, original.Stack, newCtx.Stack) + assert.Equal(t, original.Component, newCtx.Component) + assert.Equal(t, original.File, newCtx.File) +} + +func TestExecutionContext_WithBaseDir_Chaining(t *testing.T) { + ctx := NewExecutionContext(&schema.AtmosConfiguration{}, "stack", "component") + + result := ctx.WithBaseDir("/dir1").WithBaseDir("/dir2") + + assert.Equal(t, "/dir2", result.BaseDir) +} + +func TestExecutionContext_WithStackInfo(t *testing.T) { + original := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + Stack: "stack1", + Component: "comp1", + } + + stackInfo := &schema.ConfigAndStacksInfo{ + Stack: "test-stack", + Component: "test-component", + } + + newCtx := original.WithStackInfo(stackInfo) + + // New context should have the stack info set. 
+ assert.Same(t, stackInfo, newCtx.StackInfo) + + // Original should be unchanged. + assert.Nil(t, original.StackInfo) + + // Other fields should be copied. + assert.Same(t, original.AtmosConfig, newCtx.AtmosConfig) + assert.Equal(t, original.Stack, newCtx.Stack) + assert.Equal(t, original.Component, newCtx.Component) +} + +func TestExecutionContext_WithStackInfo_Nil(t *testing.T) { + original := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + Stack: "stack1", + StackInfo: &schema.ConfigAndStacksInfo{}, + } + + newCtx := original.WithStackInfo(nil) + + assert.Nil(t, newCtx.StackInfo) +} + +func TestExecutionContext_MethodChaining(t *testing.T) { + atmosConfig := &schema.AtmosConfiguration{} + stackInfo := &schema.ConfigAndStacksInfo{ + Stack: "info-stack", + } + + ctx := NewExecutionContext(atmosConfig, "stack", "component"). + WithFile("/path/to/config.yaml"). + WithBaseDir("/base/dir"). + WithStackInfo(stackInfo) + + assert.Same(t, atmosConfig, ctx.AtmosConfig) + assert.Equal(t, "stack", ctx.Stack) + assert.Equal(t, "component", ctx.Component) + assert.Equal(t, "/path/to/config.yaml", ctx.File) + assert.Equal(t, "/base/dir", ctx.BaseDir) + assert.Same(t, stackInfo, ctx.StackInfo) +} + +func TestExecutionContext_ImmutableCopy(t *testing.T) { + // Verify that With* methods create immutable copies. + original := NewExecutionContext(&schema.AtmosConfiguration{}, "stack", "component") + + copy1 := original.WithFile("/file1.yaml") + copy2 := original.WithFile("/file2.yaml") + + // Copies should be independent. + assert.Equal(t, "/file1.yaml", copy1.File) + assert.Equal(t, "/file2.yaml", copy2.File) + assert.Empty(t, original.File) + + // Copies should not be the same pointer. 
+ assert.NotSame(t, copy1, copy2) + assert.NotSame(t, original, copy1) + assert.NotSame(t, original, copy2) +} + +func TestExecutionContext_FieldAccess(t *testing.T) { + atmosConfig := &schema.AtmosConfiguration{} + stackInfo := &schema.ConfigAndStacksInfo{ + Stack: "info-stack", + Component: "info-component", + AuthContext: &schema.AuthContext{ + AWS: &schema.AWSAuthContext{ + Profile: "prod", + }, + }, + } + + ctx := &ExecutionContext{ + AtmosConfig: atmosConfig, + Stack: "my-stack", + Component: "my-component", + BaseDir: "/home/user/project", + File: "/home/user/project/stacks/config.yaml", + StackInfo: stackInfo, + } + + // Direct field access. + assert.Same(t, atmosConfig, ctx.AtmosConfig) + assert.Equal(t, "my-stack", ctx.Stack) + assert.Equal(t, "my-component", ctx.Component) + assert.Equal(t, "/home/user/project", ctx.BaseDir) + assert.Equal(t, "/home/user/project/stacks/config.yaml", ctx.File) + assert.Same(t, stackInfo, ctx.StackInfo) + + // Nested access. + assert.Equal(t, "info-stack", ctx.StackInfo.Stack) + assert.Equal(t, "prod", ctx.StackInfo.AuthContext.AWS.Profile) +} diff --git a/pkg/function/defaults.go b/pkg/function/defaults.go new file mode 100644 index 0000000000..6e787c3664 --- /dev/null +++ b/pkg/function/defaults.go @@ -0,0 +1,54 @@ +package function + +import ( + "sync" + + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" +) + +var registerOnce sync.Once + +// RegisterDefaults registers all default function handlers with the global registry. +// This is called automatically when DefaultRegistry() is first accessed, +// but can also be called explicitly to ensure functions are registered. +func RegisterDefaults() { + defer perf.Track(nil, "function.RegisterDefaults")() + + registerOnce.Do(func() { + registry := DefaultRegistry() + + // PreMerge functions. 
+ mustRegister(registry, NewEnvFunction()) + mustRegister(registry, NewExecFunction()) + mustRegister(registry, NewRandomFunction()) + mustRegister(registry, NewTemplateFunction()) + mustRegister(registry, NewGitRootFunction()) + mustRegister(registry, NewIncludeFunction()) + mustRegister(registry, NewIncludeRawFunction()) + mustRegister(registry, NewLiteralFunction()) + + // PostMerge functions. + mustRegister(registry, NewStoreFunction()) + mustRegister(registry, NewStoreGetFunction()) + mustRegister(registry, NewTerraformOutputFunction()) + mustRegister(registry, NewTerraformStateFunction()) + mustRegister(registry, NewAwsAccountIDFunction()) + mustRegister(registry, NewAwsCallerIdentityArnFunction()) + mustRegister(registry, NewAwsCallerIdentityUserIDFunction()) + mustRegister(registry, NewAwsRegionFunction()) + }) +} + +// mustRegister registers a function and panics on error. +func mustRegister(registry *Registry, fn Function) { + if err := registry.Register(fn); err != nil { + log.Error("Failed to register function", "name", fn.Name(), "error", err) + panic(err) + } +} + +// init automatically registers defaults when the package is imported. +func init() { + RegisterDefaults() +} diff --git a/pkg/function/doc.go b/pkg/function/doc.go new file mode 100644 index 0000000000..9b1993bebc --- /dev/null +++ b/pkg/function/doc.go @@ -0,0 +1,19 @@ +// Package function provides a format-agnostic function registry for Atmos. +// +// This package implements a plugin-like architecture for YAML/HCL/JSON functions +// (e.g., !env, !exec, !terraform.output) that can be used across different +// configuration formats. +// +// The registry pattern allows functions to be registered, looked up by name or +// alias, and filtered by execution phase (PreMerge or PostMerge). 
+// +// Example usage: +// +// // Register a function +// fn := NewEnvFunction() +// function.DefaultRegistry().Register(fn) +// +// // Look up and execute +// fn, err := function.DefaultRegistry().Get("env") +// result, err := fn.Execute(ctx, "MY_VAR default_value", execCtx) +package function diff --git a/pkg/function/env.go b/pkg/function/env.go new file mode 100644 index 0000000000..10dfd70683 --- /dev/null +++ b/pkg/function/env.go @@ -0,0 +1,99 @@ +package function + +import ( + "context" + "fmt" + "os" + "strings" + + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" + "github.com/cloudposse/atmos/pkg/utils" +) + +// EnvFunction implements the env function for environment variable lookup. +type EnvFunction struct { + BaseFunction +} + +// NewEnvFunction creates a new env function handler. +func NewEnvFunction() *EnvFunction { + defer perf.Track(nil, "function.NewEnvFunction")() + + return &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagEnv, + FunctionAliases: nil, + FunctionPhase: PreMerge, + }, + } +} + +// parseEnvArgs parses the env function arguments into variable name and optional default. +func parseEnvArgs(args string) (envVarName, envVarDefault string, err error) { + args = strings.TrimSpace(args) + if args == "" { + return "", "", fmt.Errorf("%w: env function requires at least one argument", ErrInvalidArguments) + } + + parts, err := utils.SplitStringByDelimiter(args, ' ') + if err != nil { + return "", "", fmt.Errorf("%w: failed to parse args %q: %w", ErrInvalidArguments, args, err) + } + + switch len(parts) { + case 2: + return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]), nil + case 1: + return strings.TrimSpace(parts[0]), "", nil + default: + return "", "", fmt.Errorf("%w: env function accepts 1 or 2 arguments, got %d", ErrInvalidArguments, len(parts)) + } +} + +// lookupEnvFromContext checks the component's env section from stack manifests. 
+func lookupEnvFromContext(execCtx *ExecutionContext, envVarName string) (string, bool) { + if execCtx == nil || execCtx.StackInfo == nil { + return "", false + } + envSection := execCtx.StackInfo.GetComponentEnvSection() + if envSection == nil { + return "", false + } + if val, exists := envSection[envVarName]; exists { + return fmt.Sprintf("%v", val), true + } + return "", false +} + +// Execute processes the env function. +// Usage: +// +// !env VAR_NAME - Get environment variable, return empty string if not set +// !env VAR_NAME default - Get environment variable, return default if not set +func (f *EnvFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.EnvFunction.Execute")() + + log.Debug("Executing env function", "args", args) + + envVarName, envVarDefault, err := parseEnvArgs(args) + if err != nil { + return "", err + } + + // Check the component's env section from stack manifests first. + if val, found := lookupEnvFromContext(execCtx, envVarName); found { + return val, nil + } + + // Fall back to OS environment variables. 
+ if res, exists := os.LookupEnv(envVarName); exists { + return res, nil + } + + if envVarDefault != "" { + return envVarDefault, nil + } + + return "", nil +} diff --git a/pkg/function/env_test.go b/pkg/function/env_test.go new file mode 100644 index 0000000000..b524e5d6d1 --- /dev/null +++ b/pkg/function/env_test.go @@ -0,0 +1,147 @@ +package function + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEnvFunction_Execute_EdgeCases(t *testing.T) { + fn := NewEnvFunction() + + tests := []struct { + name string + args string + setupEnv map[string]string + expected any + expectError bool + }{ + { + name: "empty args returns error", + args: "", + expectError: true, + }, + { + name: "whitespace only returns error", + args: " ", + expectError: true, + }, + { + name: "existing env var", + args: "TEST_ENV_VAR", + setupEnv: map[string]string{"TEST_ENV_VAR": "test_value"}, + expected: "test_value", + }, + { + name: "missing env var returns empty", + args: "NONEXISTENT_VAR_12345", + expected: "", + }, + { + name: "missing env var with default", + args: "NONEXISTENT_VAR_12345 default_value", + expected: "default_value", + }, + { + name: "existing env var ignores default", + args: "TEST_ENV_VAR fallback", + setupEnv: map[string]string{"TEST_ENV_VAR": "actual_value"}, + expected: "actual_value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up environment variables. + for k, v := range tt.setupEnv { + t.Setenv(k, v) + } + + result, err := fn.Execute(context.Background(), tt.args, nil) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestEnvFunction_Execute_TooManyArgs(t *testing.T) { + fn := NewEnvFunction() + + // Test with too many arguments. 
+ _, err := fn.Execute(context.Background(), "VAR default extra_arg", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "accepts 1 or 2 arguments") +} + +func TestParseEnvArgs(t *testing.T) { + tests := []struct { + name string + args string + expectedName string + expectedDefault string + expectError bool + }{ + { + name: "single argument", + args: "VAR_NAME", + expectedName: "VAR_NAME", + expectedDefault: "", + }, + { + name: "two arguments", + args: "VAR_NAME default_value", + expectedName: "VAR_NAME", + expectedDefault: "default_value", + }, + { + name: "with extra whitespace", + args: " VAR_NAME default_value ", + expectedName: "VAR_NAME", + expectedDefault: "default_value", + }, + { + name: "empty args", + args: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name, def, err := parseEnvArgs(tt.args) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedName, name) + assert.Equal(t, tt.expectedDefault, def) + } + }) + } +} + +func TestLookupEnvFromContext(t *testing.T) { + // Nil context. + val, found := lookupEnvFromContext(nil, "TEST") + assert.False(t, found) + assert.Empty(t, val) + + // Nil stack info. + execCtx := &ExecutionContext{} + val, found = lookupEnvFromContext(execCtx, "TEST") + assert.False(t, found) + assert.Empty(t, val) +} + +func TestEnvFunction_Metadata(t *testing.T) { + fn := NewEnvFunction() + require.NotNil(t, fn) + assert.Equal(t, TagEnv, fn.Name()) + assert.Equal(t, PreMerge, fn.Phase()) +} diff --git a/pkg/function/errors.go b/pkg/function/errors.go new file mode 100644 index 0000000000..1a2df6b7a1 --- /dev/null +++ b/pkg/function/errors.go @@ -0,0 +1,24 @@ +package function + +import "errors" + +var ( + // ErrFunctionNotFound is returned when a function is not registered. 
+ ErrFunctionNotFound = errors.New("function not found") + + // ErrFunctionAlreadyRegistered is returned when attempting to register + // a function with a name or alias that already exists. + ErrFunctionAlreadyRegistered = errors.New("function already registered") + + // ErrInvalidArguments is returned when a function receives invalid arguments. + ErrInvalidArguments = errors.New("invalid function arguments") + + // ErrExecutionFailed is returned when a function fails to execute. + ErrExecutionFailed = errors.New("function execution failed") + + // ErrCircularDependency is returned when a circular dependency is detected. + ErrCircularDependency = errors.New("circular dependency detected") + + // ErrSpecialYAMLHandling is returned when a function requires special YAML node handling. + ErrSpecialYAMLHandling = errors.New("function requires special YAML node handling") +) diff --git a/pkg/function/exec.go b/pkg/function/exec.go new file mode 100644 index 0000000000..85f407c252 --- /dev/null +++ b/pkg/function/exec.go @@ -0,0 +1,63 @@ +package function + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" + + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" + "github.com/cloudposse/atmos/pkg/utils" +) + +// ExecFunction implements the exec function for shell command execution. +type ExecFunction struct { + BaseFunction +} + +// NewExecFunction creates a new exec function handler. +func NewExecFunction() *ExecFunction { + defer perf.Track(nil, "function.NewExecFunction")() + + return &ExecFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagExec, + FunctionAliases: nil, + FunctionPhase: PreMerge, + }, + } +} + +// Execute processes the exec function. +// Usage: +// +// !exec command args... - Execute shell command and return output +// +// If the output is valid JSON, it will be parsed and returned as the corresponding type. +// Otherwise, the raw string output is returned. 
+func (f *ExecFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.ExecFunction.Execute")() + + log.Debug("Executing exec function", "args", args) + + args = strings.TrimSpace(args) + if args == "" { + return nil, ErrInvalidArguments + } + + res, err := utils.ExecuteShellAndReturnOutput(args, YAMLTag(TagExec)+" "+args, ".", os.Environ(), false) + if err != nil { + return nil, fmt.Errorf("%w: shell execution failed: %w", ErrExecutionFailed, err) + } + + // Try to parse as JSON. + var decoded any + if err = json.Unmarshal([]byte(res), &decoded); err != nil { + log.Debug("Output is not JSON, returning as string", "error", err) + return res, nil + } + + return decoded, nil +} diff --git a/pkg/function/exec_test.go b/pkg/function/exec_test.go new file mode 100644 index 0000000000..2c92be4f0a --- /dev/null +++ b/pkg/function/exec_test.go @@ -0,0 +1,88 @@ +package function + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExecFunction_Execute_InvalidArgs(t *testing.T) { + fn := NewExecFunction() + + tests := []struct { + name string + args string + }{ + {name: "empty args", args: ""}, + {name: "whitespace only", args: " "}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := fn.Execute(context.Background(), tt.args, nil) + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidArguments) + }) + } +} + +func TestExecFunction_Execute_OutputParsing(t *testing.T) { + fn := NewExecFunction() + + tests := []struct { + name string + command string + checkFunc func(t *testing.T, result any) + }{ + { + name: "simple string output", + command: "echo hello", + checkFunc: func(t *testing.T, result any) { + assert.Equal(t, "hello\n", result) + }, + }, + { + name: "JSON object output", + command: `echo '{"key": "value"}'`, + checkFunc: func(t *testing.T, result any) { + m, ok := 
result.(map[string]any) + require.True(t, ok) + assert.Equal(t, "value", m["key"]) + }, + }, + { + name: "JSON array output", + command: `echo '[1, 2, 3]'`, + checkFunc: func(t *testing.T, result any) { + arr, ok := result.([]any) + require.True(t, ok) + assert.Len(t, arr, 3) + }, + }, + { + name: "non-JSON output", + command: "echo 'not json'", + checkFunc: func(t *testing.T, result any) { + assert.Equal(t, "not json\n", result) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := fn.Execute(context.Background(), tt.command, nil) + require.NoError(t, err) + tt.checkFunc(t, result) + }) + } +} + +func TestExecFunction_Metadata(t *testing.T) { + fn := NewExecFunction() + require.NotNil(t, fn) + assert.Equal(t, TagExec, fn.Name()) + assert.Equal(t, PreMerge, fn.Phase()) + assert.Nil(t, fn.Aliases()) +} diff --git a/pkg/function/function.go b/pkg/function/function.go new file mode 100644 index 0000000000..5949d5091f --- /dev/null +++ b/pkg/function/function.go @@ -0,0 +1,52 @@ +package function + +import ( + "context" + + "github.com/cloudposse/atmos/pkg/perf" +) + +// Function defines the interface for all Atmos configuration functions. +// Functions are format-agnostic and can be used in YAML, HCL, or JSON configurations. +type Function interface { + // Name returns the primary function name (e.g., "env", "terraform.output"). + Name() string + + // Aliases returns alternative names for the function. + Aliases() []string + + // Phase returns when this function should be executed. + Phase() Phase + + // Execute processes the function with the given arguments and context. + Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) +} + +// BaseFunction provides a reusable implementation of the Function interface. +// Embed this struct in concrete function types to inherit common behavior. 
+type BaseFunction struct { + FunctionName string + FunctionAliases []string + FunctionPhase Phase +} + +// Name returns the primary function name. +func (f *BaseFunction) Name() string { + defer perf.Track(nil, "function.BaseFunction.Name")() + + return f.FunctionName +} + +// Aliases returns alternative names for the function. +func (f *BaseFunction) Aliases() []string { + defer perf.Track(nil, "function.BaseFunction.Aliases")() + + return f.FunctionAliases +} + +// Phase returns when this function should be executed. +func (f *BaseFunction) Phase() Phase { + defer perf.Track(nil, "function.BaseFunction.Phase")() + + return f.FunctionPhase +} diff --git a/pkg/function/git_root.go b/pkg/function/git_root.go new file mode 100644 index 0000000000..a3348c5001 --- /dev/null +++ b/pkg/function/git_root.go @@ -0,0 +1,54 @@ +package function + +import ( + "context" + "fmt" + "os/exec" + "strings" + + "github.com/cloudposse/atmos/errors" + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" +) + +// GitRootFunction implements the repo-root function for getting the git repository root. +type GitRootFunction struct { + BaseFunction +} + +// NewGitRootFunction creates a new repo-root function handler. +func NewGitRootFunction() *GitRootFunction { + defer perf.Track(nil, "function.NewGitRootFunction")() + + return &GitRootFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagRepoRoot, + FunctionAliases: []string{"git-root"}, + FunctionPhase: PreMerge, + }, + } +} + +// Execute processes the repo-root function. +// Usage: +// +// !repo-root - Returns the absolute path to the git repository root +// +// Returns an error if not in a git repository. 
+func (f *GitRootFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.GitRootFunction.Execute")() + + log.Debug("Executing repo-root function") + + cmd := exec.CommandContext(ctx, "git", "rev-parse", "--show-toplevel") + + output, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("%w: failed to get git repository root: %w", errors.ErrGitCommandFailed, err) + } + + result := strings.TrimSpace(string(output)) + log.Debug("Resolved repo-root", "path", result) + + return result, nil +} diff --git a/pkg/function/git_root_test.go b/pkg/function/git_root_test.go new file mode 100644 index 0000000000..a54c2c05dc --- /dev/null +++ b/pkg/function/git_root_test.go @@ -0,0 +1,53 @@ +package function + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGitRootFunction_Execute_InGitRepo(t *testing.T) { + // Skip if not in a git repository. + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not available") + } + cmd := exec.Command("git", "rev-parse", "--git-dir") + if err := cmd.Run(); err != nil { + t.Skip("not running inside a git repository") + } + + fn := NewGitRootFunction() + + result, err := fn.Execute(context.Background(), "", nil) + require.NoError(t, err) + + // Result should be a non-empty path. + path, ok := result.(string) + require.True(t, ok) + assert.NotEmpty(t, path) + + // The path should exist. + _, err = os.Stat(path) + assert.NoError(t, err) + + // The path should contain a .git directory or file. 
+ gitDir := filepath.Join(path, ".git") + _, err = os.Stat(gitDir) + assert.NoError(t, err) +} + +func TestGitRootFunction_Metadata(t *testing.T) { + fn := NewGitRootFunction() + require.NotNil(t, fn) + assert.Equal(t, TagRepoRoot, fn.Name()) + assert.Equal(t, PreMerge, fn.Phase()) + + aliases := fn.Aliases() + require.Len(t, aliases, 1) + assert.Equal(t, "git-root", aliases[0]) +} diff --git a/pkg/function/include.go b/pkg/function/include.go new file mode 100644 index 0000000000..83cec88ac6 --- /dev/null +++ b/pkg/function/include.go @@ -0,0 +1,82 @@ +package function + +import ( + "context" + "fmt" + + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" +) + +// IncludeFunction implements the include function for including content from files. +type IncludeFunction struct { + BaseFunction +} + +// NewIncludeFunction creates a new include function handler. +func NewIncludeFunction() *IncludeFunction { + defer perf.Track(nil, "function.NewIncludeFunction")() + + return &IncludeFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagInclude, + FunctionAliases: nil, + FunctionPhase: PreMerge, + }, + } +} + +// Execute processes the include function. +// Usage: +// +// !include path/to/file.yaml +// !include path/to/file.yaml .query.expression +// +// Note: The include function is special - it operates on yaml.Node directly +// and cannot return arbitrary values like other functions. The actual +// implementation remains in pkg/utils/yaml_include_by_extension.go which +// modifies the yaml.Node in-place. +// +// This function serves as a marker for the registry but the actual processing +// is handled specially in the YAML processor. 
func (f *IncludeFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) {
	defer perf.Track(nil, "function.IncludeFunction.Execute")()

	log.Debug("Executing include function", "args", args)

	// The include function requires special handling because it modifies
	// yaml.Node directly (the real implementation lives in
	// pkg/utils/yaml_include_by_extension.go); this registry entry is a
	// marker only. A direct call therefore always fails with
	// ErrSpecialYAMLHandling wrapped with the function name.
	return nil, fmt.Errorf("%w: include", ErrSpecialYAMLHandling)
}
func (f *IncludeRawFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) {
	defer perf.Track(nil, "function.IncludeRawFunction.Execute")()

	log.Debug("Executing include.raw function", "args", args)

	// Like !include, the raw variant is processed directly on the yaml.Node
	// by the YAML processor; this registry entry is only a marker, so a
	// direct call always fails with ErrSpecialYAMLHandling.
	return nil, fmt.Errorf("%w: include.raw", ErrSpecialYAMLHandling)
}
assert.Nil(t, tt.fn.Aliases()) + }) + } +} diff --git a/pkg/function/literal.go b/pkg/function/literal.go new file mode 100644 index 0000000000..649278d6a6 --- /dev/null +++ b/pkg/function/literal.go @@ -0,0 +1,50 @@ +package function + +import ( + "context" + "strings" + + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" +) + +// LiteralFunction implements the literal function for preserving values as-is. +// This bypasses template processing to preserve template-like syntax ({{...}}, ${...}) +// for downstream tools like Terraform, Helm, and ArgoCD. +type LiteralFunction struct { + BaseFunction +} + +// NewLiteralFunction creates a new literal function handler. +func NewLiteralFunction() *LiteralFunction { + defer perf.Track(nil, "function.NewLiteralFunction")() + + return &LiteralFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagLiteral, + FunctionAliases: nil, + FunctionPhase: PreMerge, // Must run before template processing. + }, + } +} + +// Execute processes the literal function. +// Usage: +// +// !literal "{{external.email}}" +// !literal "{{ .Values.ingress.class }}" +// !literal | +// #!/bin/bash +// echo "Hello ${USER}" +// +// The function returns the argument exactly as provided, preserving any +// template-like syntax that would otherwise be processed by Atmos. +func (f *LiteralFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.LiteralFunction.Execute")() + + log.Debug("Executing literal function", "args", args) + + // Return the value as-is, preserving any template syntax. + // The args string contains whatever follows the !literal tag. 
// Phase identifies the point in configuration processing at which a
// function runs.
type Phase int

const (
	// PreMerge functions run while individual files are being loaded,
	// before configuration merging. Examples: !env, !exec, !include, !random.
	PreMerge Phase = iota

	// PostMerge functions run after configuration merging, once the full
	// stack context is available. Examples: !terraform.output, !store.get.
	PostMerge
)

// String returns the human-readable name of the phase:
// "pre-merge", "post-merge", or "unknown" for out-of-range values.
func (p Phase) String() string {
	names := [...]string{
		PreMerge:  "pre-merge",
		PostMerge: "post-merge",
	}
	if p < 0 || int(p) >= len(names) {
		return "unknown"
	}
	return names[p]
}
+ assert.True(t, PreMerge < PostMerge) +} diff --git a/pkg/function/random.go b/pkg/function/random.go new file mode 100644 index 0000000000..e5abb2f20c --- /dev/null +++ b/pkg/function/random.go @@ -0,0 +1,119 @@ +package function + +import ( + "context" + "crypto/rand" + "fmt" + "math/big" + "strconv" + "strings" + + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" + "github.com/cloudposse/atmos/pkg/utils" +) + +const ( + // Default range for random when no arguments provided. + defaultRandomMin = 0 + defaultRandomMax = 65535 +) + +// RandomFunction implements the random function for generating random numbers. +type RandomFunction struct { + BaseFunction +} + +// NewRandomFunction creates a new random function handler. +func NewRandomFunction() *RandomFunction { + defer perf.Track(nil, "function.NewRandomFunction")() + + return &RandomFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagRandom, + FunctionAliases: nil, + FunctionPhase: PreMerge, + }, + } +} + +// Execute processes the random function. +// Usage: +// +// !random - Generate random number between 0 and 65535 +// !random max - Generate random number between 0 and max +// !random min max - Generate random number between min and max +func (f *RandomFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.RandomFunction.Execute")() + + log.Debug("Executing random function", "args", args) + + args = strings.TrimSpace(args) + + // No arguments: use defaults. + if args == "" { + return generateRandom(defaultRandomMin, defaultRandomMax) + } + + parts, err := utils.SplitStringByDelimiter(args, ' ') + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrInvalidArguments, args) + } + + var min, max int + + switch len(parts) { + case 0: + min = defaultRandomMin + max = defaultRandomMax + + case 1: + // One argument: treat as max, min defaults to 0. 
+ min = 0 + maxStr := strings.TrimSpace(parts[0]) + max, err = strconv.Atoi(maxStr) + if err != nil { + return nil, fmt.Errorf("%w: invalid max value '%s', must be an integer", ErrInvalidArguments, maxStr) + } + + case 2: + // Two arguments: min and max. + minStr := strings.TrimSpace(parts[0]) + maxStr := strings.TrimSpace(parts[1]) + + min, err = strconv.Atoi(minStr) + if err != nil { + return nil, fmt.Errorf("%w: invalid min value '%s', must be an integer", ErrInvalidArguments, minStr) + } + + max, err = strconv.Atoi(maxStr) + if err != nil { + return nil, fmt.Errorf("%w: invalid max value '%s', must be an integer", ErrInvalidArguments, maxStr) + } + + default: + return nil, fmt.Errorf("%w: random function accepts 0, 1, or 2 arguments, got %d", ErrInvalidArguments, len(parts)) + } + + return generateRandom(min, max) +} + +// generateRandom generates a cryptographically secure random number in the range [min, max]. +func generateRandom(min, max int) (int, error) { + if min >= max { + return 0, fmt.Errorf("%w: min value (%d) must be less than max value (%d)", ErrInvalidArguments, min, max) + } + + // Generate cryptographically secure random number in range [min, max]. 
+ rangeSize := int64(max - min + 1) + n, err := rand.Int(rand.Reader, big.NewInt(rangeSize)) + if err != nil { + return 0, fmt.Errorf("failed to generate random number: %w", err) + } + + result := int(n.Int64()) + min + + log.Debug("Generated random number", "min", min, "max", max, "result", result) + + return result, nil +} diff --git a/pkg/function/random_test.go b/pkg/function/random_test.go new file mode 100644 index 0000000000..8860089e35 --- /dev/null +++ b/pkg/function/random_test.go @@ -0,0 +1,116 @@ +package function + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRandomFunction_Execute_ErrorCases(t *testing.T) { + fn := NewRandomFunction() + + tests := []struct { + name string + args string + expectError error + errorMsg string + }{ + { + name: "invalid max value", + args: "not-a-number", + expectError: ErrInvalidArguments, + errorMsg: "invalid max value", + }, + { + name: "invalid min value", + args: "not-a-number 100", + expectError: ErrInvalidArguments, + errorMsg: "invalid min value", + }, + { + name: "invalid max value with valid min", + args: "10 not-a-number", + expectError: ErrInvalidArguments, + errorMsg: "invalid max value", + }, + { + name: "too many arguments", + args: "1 2 3", + expectError: ErrInvalidArguments, + errorMsg: "accepts 0, 1, or 2 arguments", + }, + { + name: "min equals max", + args: "10 10", + expectError: ErrInvalidArguments, + errorMsg: "min value", + }, + { + name: "min greater than max", + args: "100 10", + expectError: ErrInvalidArguments, + errorMsg: "min value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := fn.Execute(context.Background(), tt.args, nil) + require.Error(t, err) + assert.True(t, errors.Is(err, tt.expectError)) + assert.Contains(t, err.Error(), tt.errorMsg) + }) + } +} + +func TestRandomFunction_Execute_BoundaryValues(t *testing.T) { + fn := NewRandomFunction() + + // 
// Registry is a thread-safe registry for Function implementations.
// Lookups are case-insensitive: primary names and aliases are lowercased on
// registration and again on lookup (see Register and Get).
// The zero value is not usable (its maps are nil); construct instances with
// NewRegistry.
type Registry struct {
	mu        sync.RWMutex        // guards functions and aliases
	functions map[string]Function // lowercased primary name -> implementation
	aliases   map[string]string   // alias -> primary name
}

// NewRegistry creates a new empty function registry.
// Both internal tables are initialized, so the registry is immediately usable.
func NewRegistry() *Registry {
	defer perf.Track(nil, "function.NewRegistry")()

	return &Registry{
		functions: make(map[string]Function),
		aliases:   make(map[string]string),
	}
}
+var ( + defaultRegistry *Registry + defaultRegistryOnce sync.Once +) + +// DefaultRegistry returns the global function registry. +func DefaultRegistry() *Registry { + defer perf.Track(nil, "function.DefaultRegistry")() + + defaultRegistryOnce.Do(func() { + defaultRegistry = NewRegistry() + }) + return defaultRegistry +} + +// Register adds a function to the registry. +// Returns an error if the name or any alias is already registered. +func (r *Registry) Register(fn Function) error { + defer perf.Track(nil, "function.Registry.Register")() + + r.mu.Lock() + defer r.mu.Unlock() + + name := strings.ToLower(fn.Name()) + + // Check if primary name conflicts. + if _, exists := r.functions[name]; exists { + return ErrFunctionAlreadyRegistered + } + if _, exists := r.aliases[name]; exists { + return ErrFunctionAlreadyRegistered + } + + // Check if any alias conflicts. + for _, alias := range fn.Aliases() { + alias = strings.ToLower(alias) + if _, exists := r.functions[alias]; exists { + return ErrFunctionAlreadyRegistered + } + if _, exists := r.aliases[alias]; exists { + return ErrFunctionAlreadyRegistered + } + } + + // Register the function and its aliases. + r.functions[name] = fn + for _, alias := range fn.Aliases() { + r.aliases[strings.ToLower(alias)] = name + } + + return nil +} + +// Get retrieves a function by name or alias. +// Returns ErrFunctionNotFound if the function is not registered. +func (r *Registry) Get(name string) (Function, error) { + defer perf.Track(nil, "function.Registry.Get")() + + r.mu.RLock() + defer r.mu.RUnlock() + + name = strings.ToLower(name) + + // Check primary names first. + if fn, exists := r.functions[name]; exists { + return fn, nil + } + + // Check aliases. + if primaryName, exists := r.aliases[name]; exists { + if fn, exists := r.functions[primaryName]; exists { + return fn, nil + } + } + + return nil, ErrFunctionNotFound +} + +// Has checks if a function is registered by name or alias. 
+func (r *Registry) Has(name string) bool { + defer perf.Track(nil, "function.Registry.Has")() + + r.mu.RLock() + defer r.mu.RUnlock() + + name = strings.ToLower(name) + + if _, exists := r.functions[name]; exists { + return true + } + if _, exists := r.aliases[name]; exists { + return true + } + return false +} + +// GetByPhase returns all functions that should execute in the given phase. +func (r *Registry) GetByPhase(phase Phase) []Function { + defer perf.Track(nil, "function.Registry.GetByPhase")() + + r.mu.RLock() + defer r.mu.RUnlock() + + var result []Function + for _, fn := range r.functions { + if fn.Phase() == phase { + result = append(result, fn) + } + } + return result +} + +// List returns all registered function names. +func (r *Registry) List() []string { + defer perf.Track(nil, "function.Registry.List")() + + r.mu.RLock() + defer r.mu.RUnlock() + + names := make([]string, 0, len(r.functions)) + for name := range r.functions { + names = append(names, name) + } + return names +} + +// Len returns the number of registered functions. +func (r *Registry) Len() int { + defer perf.Track(nil, "function.Registry.Len")() + + r.mu.RLock() + defer r.mu.RUnlock() + + return len(r.functions) +} + +// Unregister removes a function from the registry. +func (r *Registry) Unregister(name string) { + defer perf.Track(nil, "function.Registry.Unregister")() + + r.mu.Lock() + defer r.mu.Unlock() + + name = strings.ToLower(name) + + fn, exists := r.functions[name] + if !exists { + return + } + + // Remove aliases first. + for _, alias := range fn.Aliases() { + delete(r.aliases, strings.ToLower(alias)) + } + + // Remove the function. + delete(r.functions, name) +} + +// Clear removes all functions from the registry. 
func (r *Registry) Clear() {
	defer perf.Track(nil, "function.Registry.Clear")()

	r.mu.Lock()
	defer r.mu.Unlock()

	// Swap in fresh maps rather than deleting keys in place; the registry
	// returns to its just-constructed state and old entries become eligible
	// for garbage collection.
	r.functions = make(map[string]Function)
	r.aliases = make(map[string]string)
}
+ preMergeFn := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "pre-merge-fn", + FunctionPhase: PreMerge, + }, + } + require.NoError(t, registry.Register(preMergeFn)) + + // Register PostMerge function. + postMergeFn := &StoreFunction{ + BaseFunction: BaseFunction{ + FunctionName: "post-merge-fn", + FunctionPhase: PostMerge, + }, + } + require.NoError(t, registry.Register(postMergeFn)) + + // Get PreMerge functions. + preMergeFns := registry.GetByPhase(PreMerge) + assert.Len(t, preMergeFns, 1) + assert.Equal(t, "pre-merge-fn", preMergeFns[0].Name()) + + // Get PostMerge functions. + postMergeFns := registry.GetByPhase(PostMerge) + assert.Len(t, postMergeFns, 1) + assert.Equal(t, "post-merge-fn", postMergeFns[0].Name()) +} + +func TestRegistry_Has(t *testing.T) { + registry := NewRegistry() + + fn := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "has-test", + FunctionAliases: []string{"has-alias"}, + FunctionPhase: PreMerge, + }, + } + require.NoError(t, registry.Register(fn)) + + assert.True(t, registry.Has("has-test")) + assert.True(t, registry.Has("has-alias")) + assert.False(t, registry.Has("non-existent")) +} + +func TestRegistry_Unregister(t *testing.T) { + registry := NewRegistry() + + fn := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "unreg-test", + FunctionAliases: []string{"unreg-alias"}, + FunctionPhase: PreMerge, + }, + } + require.NoError(t, registry.Register(fn)) + + // Verify it exists. + assert.True(t, registry.Has("unreg-test")) + assert.True(t, registry.Has("unreg-alias")) + + // Unregister. + registry.Unregister("unreg-test") + + // Verify it's gone. + assert.False(t, registry.Has("unreg-test")) + assert.False(t, registry.Has("unreg-alias")) +} + +func TestDefaultRegistry_HasAllDefaults(t *testing.T) { + // Force re-initialization. + registry := DefaultRegistry() + + // Check that default functions are registered. 
+ expectedFunctions := []string{ + TagEnv, + TagExec, + TagRandom, + TagTemplate, + TagRepoRoot, + TagInclude, + TagIncludeRaw, + TagStore, + TagStoreGet, + TagTerraformOutput, + TagTerraformState, + TagAwsAccountID, + TagAwsCallerIdentityArn, + TagAwsCallerIdentityUserID, + TagAwsRegion, + } + + for _, name := range expectedFunctions { + assert.True(t, registry.Has(name), "expected function %s to be registered", name) + } +} + +func TestEnvFunction_Execute(t *testing.T) { + fn := NewEnvFunction() + + // Set up test environment variable. + t.Setenv("TEST_VAR", "test_value") + + // Test basic env lookup. + result, err := fn.Execute(context.Background(), "TEST_VAR", nil) + require.NoError(t, err) + assert.Equal(t, "test_value", result) + + // Test with default value for missing variable. + result, err = fn.Execute(context.Background(), "MISSING_VAR default_value", nil) + require.NoError(t, err) + assert.Equal(t, "default_value", result) + + // Test missing variable without default. + result, err = fn.Execute(context.Background(), "MISSING_VAR", nil) + require.NoError(t, err) + assert.Equal(t, "", result) +} + +func TestRandomFunction_Execute(t *testing.T) { + fn := NewRandomFunction() + + // Test with no arguments (default range). + result, err := fn.Execute(context.Background(), "", nil) + require.NoError(t, err) + val, ok := result.(int) + require.True(t, ok) + assert.GreaterOrEqual(t, val, 0) + assert.LessOrEqual(t, val, 65535) + + // Test with max only. + result, err = fn.Execute(context.Background(), "100", nil) + require.NoError(t, err) + val, ok = result.(int) + require.True(t, ok) + assert.GreaterOrEqual(t, val, 0) + assert.LessOrEqual(t, val, 100) + + // Test with min and max. 
+ result, err = fn.Execute(context.Background(), "10 20", nil) + require.NoError(t, err) + val, ok = result.(int) + require.True(t, ok) + assert.GreaterOrEqual(t, val, 10) + assert.LessOrEqual(t, val, 20) +} + +func TestTemplateFunction_Execute(t *testing.T) { + fn := NewTemplateFunction() + + // Test JSON object. + result, err := fn.Execute(context.Background(), `{"key": "value"}`, nil) + require.NoError(t, err) + m, ok := result.(map[string]any) + require.True(t, ok) + assert.Equal(t, "value", m["key"]) + + // Test JSON array. + result, err = fn.Execute(context.Background(), `[1, 2, 3]`, nil) + require.NoError(t, err) + arr, ok := result.([]any) + require.True(t, ok) + assert.Len(t, arr, 3) + + // Test non-JSON string. + result, err = fn.Execute(context.Background(), "not json", nil) + require.NoError(t, err) + assert.Equal(t, "not json", result) +} + +func TestRegistry_List(t *testing.T) { + registry := NewRegistry() + + // Empty registry. + names := registry.List() + assert.Empty(t, names) + + // Register some functions. + fn1 := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "func1", + FunctionPhase: PreMerge, + }, + } + fn2 := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "func2", + FunctionPhase: PreMerge, + }, + } + require.NoError(t, registry.Register(fn1)) + require.NoError(t, registry.Register(fn2)) + + names = registry.List() + assert.Len(t, names, 2) + assert.Contains(t, names, "func1") + assert.Contains(t, names, "func2") +} + +func TestRegistry_Len(t *testing.T) { + registry := NewRegistry() + + // Empty registry. + assert.Equal(t, 0, registry.Len()) + + // Register a function. + fn := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "len-test", + FunctionPhase: PreMerge, + }, + } + require.NoError(t, registry.Register(fn)) + assert.Equal(t, 1, registry.Len()) +} + +func TestRegistry_Clear(t *testing.T) { + registry := NewRegistry() + + // Register some functions. 
+ fn1 := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "clear-test1", + FunctionAliases: []string{"alias1"}, + FunctionPhase: PreMerge, + }, + } + fn2 := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "clear-test2", + FunctionPhase: PreMerge, + }, + } + require.NoError(t, registry.Register(fn1)) + require.NoError(t, registry.Register(fn2)) + assert.Equal(t, 2, registry.Len()) + + // Clear the registry. + registry.Clear() + assert.Equal(t, 0, registry.Len()) + assert.False(t, registry.Has("clear-test1")) + assert.False(t, registry.Has("alias1")) + assert.False(t, registry.Has("clear-test2")) +} + +func TestRegistry_AliasConflicts(t *testing.T) { + registry := NewRegistry() + + // Register a function with alias. + fn1 := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "alias-conflict-1", + FunctionAliases: []string{"shared-alias"}, + FunctionPhase: PreMerge, + }, + } + require.NoError(t, registry.Register(fn1)) + + // Try to register another function with same alias. + fn2 := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "alias-conflict-2", + FunctionAliases: []string{"shared-alias"}, + FunctionPhase: PreMerge, + }, + } + err := registry.Register(fn2) + assert.ErrorIs(t, err, ErrFunctionAlreadyRegistered) +} + +func TestRegistry_NameConflictsWithAlias(t *testing.T) { + registry := NewRegistry() + + // Register a function with alias. + fn1 := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "name-alias-conflict", + FunctionAliases: []string{"will-conflict"}, + FunctionPhase: PreMerge, + }, + } + require.NoError(t, registry.Register(fn1)) + + // Try to register function where name matches existing alias. 
+ fn2 := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "will-conflict", + FunctionPhase: PreMerge, + }, + } + err := registry.Register(fn2) + assert.ErrorIs(t, err, ErrFunctionAlreadyRegistered) +} + +func TestRegistry_AliasConflictsWithName(t *testing.T) { + registry := NewRegistry() + + // Register a function. + fn1 := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "existing-name", + FunctionPhase: PreMerge, + }, + } + require.NoError(t, registry.Register(fn1)) + + // Try to register function where alias matches existing name. + fn2 := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "new-func", + FunctionAliases: []string{"existing-name"}, + FunctionPhase: PreMerge, + }, + } + err := registry.Register(fn2) + assert.ErrorIs(t, err, ErrFunctionAlreadyRegistered) +} + +func TestRegistry_Unregister_NonExistent(t *testing.T) { + registry := NewRegistry() + + // Unregister non-existent function should not panic. + registry.Unregister("non-existent") + assert.Equal(t, 0, registry.Len()) +} + +func TestRegistry_CaseInsensitive(t *testing.T) { + registry := NewRegistry() + + fn := &EnvFunction{ + BaseFunction: BaseFunction{ + FunctionName: "CaseSensitive", + FunctionAliases: []string{"ALIAS"}, + FunctionPhase: PreMerge, + }, + } + require.NoError(t, registry.Register(fn)) + + // Test case-insensitive lookup. + assert.True(t, registry.Has("casesensitive")) + assert.True(t, registry.Has("CASESENSITIVE")) + assert.True(t, registry.Has("alias")) + assert.True(t, registry.Has("ALIAS")) + + // Get returns the function with original name case. + got, err := registry.Get("CASESENSITIVE") + require.NoError(t, err) + assert.Equal(t, "CaseSensitive", got.Name()) +} + +func TestRegistry_GetByPhase_Empty(t *testing.T) { + registry := NewRegistry() + + // Empty registry should return empty slice. 
+ preMerge := registry.GetByPhase(PreMerge) + assert.Empty(t, preMerge) + + postMerge := registry.GetByPhase(PostMerge) + assert.Empty(t, postMerge) +} diff --git a/pkg/function/resolution/context.go b/pkg/function/resolution/context.go new file mode 100644 index 0000000000..761b79d408 --- /dev/null +++ b/pkg/function/resolution/context.go @@ -0,0 +1,218 @@ +package resolution + +import ( + "fmt" + "runtime" + "strings" + "sync" + "sync/atomic" + + errUtils "github.com/cloudposse/atmos/errors" + "github.com/cloudposse/atmos/pkg/perf" + "github.com/cloudposse/atmos/pkg/schema" +) + +// DependencyNode represents a single node in the dependency resolution chain. +type DependencyNode struct { + Component string + Stack string + FunctionType string // "terraform.state", "terraform.output", "atmos.Component". + FunctionCall string // Full function call for error reporting. +} + +// Context tracks the call stack during YAML function resolution to detect circular dependencies. +type Context struct { + CallStack []DependencyNode + Visited map[string]bool // Map of "stack-component" to track visited nodes. +} + +// goroutineContexts maps goroutine IDs to their resolution contexts. +var goroutineContexts sync.Map + +// NewContext creates a new resolution context for cycle detection. +func NewContext() *Context { + defer perf.Track(nil, "resolution.NewContext")() + + return &Context{ + CallStack: make([]DependencyNode, 0), + Visited: make(map[string]bool), + } +} + +const ( + // Initial buffer size for capturing goroutine stack traces. + goroutineStackBufSize = 64 + // Maximum buffer size to prevent unbounded growth. + maxGoroutineStackBufSize = 8192 +) + +// unknownIDCounter is used to generate unique fallback IDs when goroutine ID parsing fails. +var unknownIDCounter uint64 + +// getGoroutineID returns the current goroutine ID. 
+// Returns a unique "unknown-N" identifier if parsing fails to prevent panics +// and avoid metric collisions when multiple goroutines hit the fallback path. +func getGoroutineID() string { + // Allocate buffer and grow it if needed to avoid truncation. + buf := make([]byte, goroutineStackBufSize) + for { + n := runtime.Stack(buf, false) + if n < len(buf) { + // Buffer was large enough. + buf = buf[:n] + break + } + // Buffer was too small, double it and try again. + if len(buf) >= maxGoroutineStackBufSize { + // Safety limit reached, return unique fallback ID. + return fmt.Sprintf("unknown-%d", atomic.AddUint64(&unknownIDCounter, 1)) + } + buf = make([]byte, len(buf)*2) + } + + // Format: "goroutine 123 [running]:\n..." + // Parse defensively to avoid panics. + fields := strings.Fields(string(buf)) + if len(fields) < 2 { + return fmt.Sprintf("unknown-%d", atomic.AddUint64(&unknownIDCounter, 1)) + } + + // Extract the number after "goroutine ". + return fields[1] +} + +// GetOrCreate gets or creates a resolution context for the current goroutine. +func GetOrCreate() *Context { + defer perf.Track(nil, "resolution.GetOrCreate")() + + gid := getGoroutineID() + + if ctx, ok := goroutineContexts.Load(gid); ok { + return ctx.(*Context) + } + + ctx := NewContext() + goroutineContexts.Store(gid, ctx) + return ctx +} + +// Clear clears the resolution context for the current goroutine. +func Clear() { + defer perf.Track(nil, "resolution.Clear")() + + gid := getGoroutineID() + goroutineContexts.Delete(gid) +} + +// Scoped creates a new scoped resolution context and returns a restore function. +// This prevents memory leaks and cross-call contamination by ensuring contexts are cleaned up. +// Usage: +// +// restoreCtx := resolution.Scoped() +// defer restoreCtx() +func Scoped() func() { + defer perf.Track(nil, "resolution.Scoped")() + + gid := getGoroutineID() + + // Save the existing context (if any). 
+	var savedCtx *Context
+	if ctx, ok := goroutineContexts.Load(gid); ok {
+		savedCtx = ctx.(*Context)
+	}
+
+	// Install a fresh context.
+	freshCtx := NewContext()
+	goroutineContexts.Store(gid, freshCtx)
+
+	// Return a restore function that reinstates the saved context or clears it.
+	return func() {
+		if savedCtx != nil {
+			goroutineContexts.Store(gid, savedCtx)
+		} else {
+			goroutineContexts.Delete(gid)
+		}
+	}
+}
+
+// Push adds a node to the call stack and checks for circular dependencies.
+func (ctx *Context) Push(atmosConfig *schema.AtmosConfiguration, node DependencyNode) error {
+	defer perf.Track(atmosConfig, "resolution.Context.Push")()
+
+	key := fmt.Sprintf("%s-%s", node.Stack, node.Component) // NOTE(review): "-" also appears inside stack/component names, so distinct pairs could collide — confirm.
+
+	// Check if we've already visited this node; re-pushing a visited node completes a cycle.
+	if ctx.Visited[key] {
+		return ctx.buildCircularDependencyError(node)
+	}
+
+	// Mark as visited and add to call stack.
+	ctx.Visited[key] = true
+	ctx.CallStack = append(ctx.CallStack, node)
+
+	return nil
+}
+
+// Pop removes the top node from the call stack.
+func (ctx *Context) Pop(atmosConfig *schema.AtmosConfiguration) {
+	defer perf.Track(atmosConfig, "resolution.Context.Pop")()
+
+	if len(ctx.CallStack) > 0 {
+		lastIdx := len(ctx.CallStack) - 1
+		node := ctx.CallStack[lastIdx]
+		key := fmt.Sprintf("%s-%s", node.Stack, node.Component) // Must mirror the key format built in Push.
+
+		// Remove from visited set.
+		delete(ctx.Visited, key)
+
+		// Remove from call stack.
+		ctx.CallStack = ctx.CallStack[:lastIdx]
+	}
+}
+
+// buildCircularDependencyError creates a detailed error message showing the dependency chain.
+func (ctx *Context) buildCircularDependencyError(newNode DependencyNode) error {
+	var builder strings.Builder
+
+	builder.WriteString("Dependency chain:\n")
+
+	// Show the full call stack.
+	for i, node := range ctx.CallStack {
+		builder.WriteString(fmt.Sprintf(" %d. 
Component '%s' in stack '%s'\n", + i+1, node.Component, node.Stack)) + builder.WriteString(fmt.Sprintf(" → %s\n", node.FunctionCall)) + } + + // Show where the cycle completes. + builder.WriteString(fmt.Sprintf(" %d. Component '%s' in stack '%s' (cycle detected)\n", + len(ctx.CallStack)+1, newNode.Component, newNode.Stack)) + builder.WriteString(fmt.Sprintf(" → %s\n\n", newNode.FunctionCall)) + + builder.WriteString("To fix this issue:\n") + builder.WriteString(" - Review your component dependencies and break the circular reference\n") + builder.WriteString(" - Consider using Terraform data sources or direct remote state instead\n") + builder.WriteString(" - Ensure dependencies flow in one direction only\n") + + return fmt.Errorf("%w: %s", errUtils.ErrCircularDependency, builder.String()) +} + +// Clone creates a copy of the resolution context for use in concurrent operations. +func (ctx *Context) Clone() *Context { + defer perf.Track(nil, "resolution.Context.Clone")() + + if ctx == nil { + return nil + } + + newCtx := &Context{ + CallStack: make([]DependencyNode, len(ctx.CallStack)), + Visited: make(map[string]bool, len(ctx.Visited)), + } + + copy(newCtx.CallStack, ctx.CallStack) + for k, v := range ctx.Visited { + newCtx.Visited[k] = v + } + + return newCtx +} diff --git a/pkg/function/resolution/context_test.go b/pkg/function/resolution/context_test.go new file mode 100644 index 0000000000..0990f5a5fb --- /dev/null +++ b/pkg/function/resolution/context_test.go @@ -0,0 +1,396 @@ +package resolution + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + errUtils "github.com/cloudposse/atmos/errors" + "github.com/cloudposse/atmos/pkg/schema" +) + +func TestNewContext(t *testing.T) { + ctx := NewContext() + + require.NotNil(t, ctx) + assert.Empty(t, ctx.CallStack) + assert.NotNil(t, ctx.Visited) + assert.Empty(t, ctx.Visited) +} + +func TestContext_Push_Success(t *testing.T) { + ctx := NewContext() + 
atmosConfig := &schema.AtmosConfiguration{} + + node := DependencyNode{ + Component: "vpc", + Stack: "tenant1-ue2-dev", + FunctionType: "terraform.output", + FunctionCall: "!terraform.output vpc outputs", + } + + err := ctx.Push(atmosConfig, node) + require.NoError(t, err) + + assert.Len(t, ctx.CallStack, 1) + assert.Equal(t, node, ctx.CallStack[0]) + assert.True(t, ctx.Visited["tenant1-ue2-dev-vpc"]) +} + +func TestContext_Push_CircularDependency(t *testing.T) { + ctx := NewContext() + atmosConfig := &schema.AtmosConfiguration{} + + node1 := DependencyNode{ + Component: "vpc", + Stack: "tenant1-ue2-dev", + FunctionType: "terraform.output", + FunctionCall: "!terraform.output vpc outputs", + } + + node2 := DependencyNode{ + Component: "eks", + Stack: "tenant1-ue2-dev", + FunctionType: "terraform.output", + FunctionCall: "!terraform.output eks cluster_name", + } + + // Push first node. + err := ctx.Push(atmosConfig, node1) + require.NoError(t, err) + + // Push second node. + err = ctx.Push(atmosConfig, node2) + require.NoError(t, err) + + // Try to push first node again - should detect cycle. + err = ctx.Push(atmosConfig, node1) + require.Error(t, err) + assert.ErrorIs(t, err, errUtils.ErrCircularDependency) + assert.Contains(t, err.Error(), "vpc") + assert.Contains(t, err.Error(), "tenant1-ue2-dev") + assert.Contains(t, err.Error(), "cycle detected") +} + +func TestContext_Pop(t *testing.T) { + ctx := NewContext() + atmosConfig := &schema.AtmosConfiguration{} + + node1 := DependencyNode{ + Component: "vpc", + Stack: "tenant1-ue2-dev", + FunctionType: "terraform.output", + FunctionCall: "!terraform.output vpc outputs", + } + + node2 := DependencyNode{ + Component: "eks", + Stack: "tenant1-ue2-dev", + FunctionType: "terraform.output", + FunctionCall: "!terraform.output eks cluster_name", + } + + // Push two nodes. 
+ require.NoError(t, ctx.Push(atmosConfig, node1)) + require.NoError(t, ctx.Push(atmosConfig, node2)) + + assert.Len(t, ctx.CallStack, 2) + assert.True(t, ctx.Visited["tenant1-ue2-dev-vpc"]) + assert.True(t, ctx.Visited["tenant1-ue2-dev-eks"]) + + // Pop the second node. + ctx.Pop(atmosConfig) + + assert.Len(t, ctx.CallStack, 1) + assert.True(t, ctx.Visited["tenant1-ue2-dev-vpc"]) + assert.False(t, ctx.Visited["tenant1-ue2-dev-eks"]) + + // Pop the first node. + ctx.Pop(atmosConfig) + + assert.Empty(t, ctx.CallStack) + assert.False(t, ctx.Visited["tenant1-ue2-dev-vpc"]) +} + +func TestContext_Pop_EmptyStack(t *testing.T) { + ctx := NewContext() + atmosConfig := &schema.AtmosConfiguration{} + + // Pop on empty stack should not panic. + ctx.Pop(atmosConfig) + + assert.Empty(t, ctx.CallStack) +} + +func TestContext_Clone(t *testing.T) { + ctx := NewContext() + atmosConfig := &schema.AtmosConfiguration{} + + node := DependencyNode{ + Component: "vpc", + Stack: "tenant1-ue2-dev", + FunctionType: "terraform.output", + FunctionCall: "!terraform.output vpc outputs", + } + + require.NoError(t, ctx.Push(atmosConfig, node)) + + // Clone the context. + cloned := ctx.Clone() + + // Verify cloned has same data. + require.NotNil(t, cloned) + assert.Len(t, cloned.CallStack, 1) + assert.Equal(t, node, cloned.CallStack[0]) + assert.True(t, cloned.Visited["tenant1-ue2-dev-vpc"]) + + // Verify cloned is independent. + cloned.Pop(atmosConfig) + assert.Empty(t, cloned.CallStack) + assert.Len(t, ctx.CallStack, 1) // Original unchanged. +} + +func TestContext_Clone_Nil(t *testing.T) { + var ctx *Context + cloned := ctx.Clone() + assert.Nil(t, cloned) +} + +func TestGetOrCreate(t *testing.T) { + // Clear any existing context. + Clear() + + // First call should create a new context. + ctx1 := GetOrCreate() + require.NotNil(t, ctx1) + assert.Empty(t, ctx1.CallStack) + + // Add a node to the context. 
+ atmosConfig := &schema.AtmosConfiguration{} + node := DependencyNode{ + Component: "vpc", + Stack: "test", + FunctionType: "terraform.output", + FunctionCall: "test", + } + require.NoError(t, ctx1.Push(atmosConfig, node)) + + // Second call should return the same context. + ctx2 := GetOrCreate() + assert.Same(t, ctx1, ctx2) + assert.Len(t, ctx2.CallStack, 1) + + // Cleanup. + Clear() +} + +func TestClear(t *testing.T) { + // Create a context. + ctx := GetOrCreate() + atmosConfig := &schema.AtmosConfiguration{} + + node := DependencyNode{ + Component: "vpc", + Stack: "test", + FunctionType: "terraform.output", + FunctionCall: "test", + } + require.NoError(t, ctx.Push(atmosConfig, node)) + + // Clear the context. + Clear() + + // Next GetOrCreate should return a fresh context. + newCtx := GetOrCreate() + assert.Empty(t, newCtx.CallStack) + + // Cleanup. + Clear() +} + +func TestScoped(t *testing.T) { + // Setup an existing context. + existingCtx := GetOrCreate() + atmosConfig := &schema.AtmosConfiguration{} + + node := DependencyNode{ + Component: "vpc", + Stack: "test", + FunctionType: "terraform.output", + FunctionCall: "test", + } + require.NoError(t, existingCtx.Push(atmosConfig, node)) + + // Create a scoped context. + restore := Scoped() + + // Within the scope, we should have a fresh context. + scopedCtx := GetOrCreate() + assert.Empty(t, scopedCtx.CallStack) + assert.NotSame(t, existingCtx, scopedCtx) + + // Add something to the scoped context. + node2 := DependencyNode{ + Component: "eks", + Stack: "test", + FunctionType: "terraform.output", + FunctionCall: "test2", + } + require.NoError(t, scopedCtx.Push(atmosConfig, node2)) + + // Restore the original context. + restore() + + // After restore, we should have the original context back. + restoredCtx := GetOrCreate() + assert.Len(t, restoredCtx.CallStack, 1) + assert.Equal(t, "vpc", restoredCtx.CallStack[0].Component) + + // Cleanup. 
+ Clear() +} + +func TestScoped_NoExistingContext(t *testing.T) { + // Make sure there's no existing context. + Clear() + + // Create a scoped context. + restore := Scoped() + + // Add something to the scoped context. + ctx := GetOrCreate() + atmosConfig := &schema.AtmosConfiguration{} + node := DependencyNode{ + Component: "test", + Stack: "test", + FunctionType: "test", + FunctionCall: "test", + } + require.NoError(t, ctx.Push(atmosConfig, node)) + + // Restore. + restore() + + // After restore, there should be no context (it was nil before). + // Calling GetOrCreate should create a fresh one. + newCtx := GetOrCreate() + assert.Empty(t, newCtx.CallStack) + + // Cleanup. + Clear() +} + +func TestGetGoroutineID(t *testing.T) { + // Test that we get a consistent ID within the same goroutine. + id1 := getGoroutineID() + id2 := getGoroutineID() + + assert.Equal(t, id1, id2) + assert.NotEmpty(t, id1) +} + +func TestGetGoroutineID_DifferentGoroutines(t *testing.T) { + var wg sync.WaitGroup + ids := make([]string, 2) + + wg.Add(2) + + go func() { + defer wg.Done() + ids[0] = getGoroutineID() + }() + + go func() { + defer wg.Done() + ids[1] = getGoroutineID() + }() + + wg.Wait() + + // Different goroutines should have different IDs. + assert.NotEmpty(t, ids[0]) + assert.NotEmpty(t, ids[1]) + assert.NotEqual(t, ids[0], ids[1]) +} + +func TestBuildCircularDependencyError(t *testing.T) { + ctx := NewContext() + atmosConfig := &schema.AtmosConfiguration{} + + // Build a call stack. + node1 := DependencyNode{ + Component: "vpc", + Stack: "tenant1-ue2-dev", + FunctionType: "terraform.output", + FunctionCall: "!terraform.output vpc vpc_id", + } + node2 := DependencyNode{ + Component: "eks", + Stack: "tenant1-ue2-dev", + FunctionType: "terraform.output", + FunctionCall: "!terraform.output eks cluster_arn", + } + + require.NoError(t, ctx.Push(atmosConfig, node1)) + require.NoError(t, ctx.Push(atmosConfig, node2)) + + // Build the error for a circular dependency back to vpc. 
+ newNode := DependencyNode{ + Component: "vpc", + Stack: "tenant1-ue2-dev", + FunctionType: "terraform.output", + FunctionCall: "!terraform.output vpc subnet_ids", + } + + err := ctx.buildCircularDependencyError(newNode) + + require.Error(t, err) + assert.ErrorIs(t, err, errUtils.ErrCircularDependency) + assert.Contains(t, err.Error(), "Dependency chain") + assert.Contains(t, err.Error(), "vpc") + assert.Contains(t, err.Error(), "eks") + assert.Contains(t, err.Error(), "tenant1-ue2-dev") + assert.Contains(t, err.Error(), "cycle detected") + assert.Contains(t, err.Error(), "To fix this issue") +} + +func TestConcurrentContextAccess(t *testing.T) { + var wg sync.WaitGroup + numGoroutines := 10 + + // Clear any existing contexts. + goroutineContexts = sync.Map{} + + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + + // Each goroutine gets/creates its own context. + ctx := GetOrCreate() + assert.NotNil(t, ctx) + + // Push a unique node. + atmosConfig := &schema.AtmosConfiguration{} + node := DependencyNode{ + Component: "comp", + Stack: "stack", + FunctionType: "terraform.output", + FunctionCall: "test", + } + err := ctx.Push(atmosConfig, node) + assert.NoError(t, err) + + // Verify the context has exactly one node. + assert.Len(t, ctx.CallStack, 1) + + // Clear at the end. + Clear() + }() + } + + wg.Wait() +} diff --git a/pkg/function/store.go b/pkg/function/store.go new file mode 100644 index 0000000000..c995df5f77 --- /dev/null +++ b/pkg/function/store.go @@ -0,0 +1,273 @@ +package function + +import ( + "context" + "fmt" + "strings" + + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" + "github.com/cloudposse/atmos/pkg/store" + "github.com/cloudposse/atmos/pkg/utils" +) + +// StoreFunction implements the store function for retrieving values from configured stores. +type StoreFunction struct { + BaseFunction +} + +// NewStoreFunction creates a new store function handler. 
+func NewStoreFunction() *StoreFunction { + defer perf.Track(nil, "function.NewStoreFunction")() + + return &StoreFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagStore, + FunctionAliases: nil, + FunctionPhase: PostMerge, + }, + } +} + +// Execute processes the store function. +// Usage: +// +// !store store_name stack component key +// !store store_name component key - Uses current stack +// !store store_name stack component key | default "value" +// !store store_name stack component key | query ".foo.bar" +func (f *StoreFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.StoreFunction.Execute")() + + log.Debug("Executing store function", "args", args) + + if execCtx == nil || execCtx.AtmosConfig == nil { + return nil, fmt.Errorf("%w: store function requires AtmosConfig", ErrExecutionFailed) + } + + // Parse parameters. + params, err := parseStoreParams(args, execCtx.Stack) + if err != nil { + return nil, err + } + + // Retrieve the store from atmosConfig. + store := execCtx.AtmosConfig.Stores[params.storeName] + if store == nil { + return nil, fmt.Errorf("%w: store '%s' not found", ErrExecutionFailed, params.storeName) + } + + // Retrieve the value from the store. + value, err := store.Get(params.stack, params.component, params.key) + if err != nil { + if params.defaultValue != nil { + return *params.defaultValue, nil + } + return nil, fmt.Errorf("%w: failed to get key '%s': %w", ErrExecutionFailed, params.key, err) + } + + // Execute the YQ expression if provided. + if params.query != "" { + value, err = utils.EvaluateYqExpression(execCtx.AtmosConfig, value, params.query) + if err != nil { + return nil, err + } + } + + return value, nil +} + +// StoreGetFunction implements the store.get function for retrieving arbitrary keys from stores. +type StoreGetFunction struct { + BaseFunction +} + +// NewStoreGetFunction creates a new store.get function handler. 
+func NewStoreGetFunction() *StoreGetFunction { + defer perf.Track(nil, "function.NewStoreGetFunction")() + + return &StoreGetFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagStoreGet, + FunctionAliases: nil, + FunctionPhase: PostMerge, + }, + } +} + +// retrieveFromStore gets a value from the store and applies defaults if needed. +func retrieveFromStore(s store.Store, params *storeGetParams) (any, error) { + value, err := s.GetKey(params.key) + if err != nil { + if params.defaultValue != nil { + return *params.defaultValue, nil + } + return nil, fmt.Errorf("%w: failed to get key '%s': %w", ErrExecutionFailed, params.key, err) + } + + // Check if the retrieved value is nil and use default if provided. + if value == nil && params.defaultValue != nil { + return *params.defaultValue, nil + } + + return value, nil +} + +// Execute processes the store.get function. +// Usage: +// +// !store.get store_name key +// !store.get store_name key | default "value" +// !store.get store_name key | query ".foo.bar" +func (f *StoreGetFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.StoreGetFunction.Execute")() + + log.Debug("Executing store.get function", "args", args) + + if execCtx == nil || execCtx.AtmosConfig == nil { + return nil, fmt.Errorf("%w: store.get function requires AtmosConfig", ErrExecutionFailed) + } + + // Parse parameters. + params, err := parseStoreGetParams(args) + if err != nil { + return nil, err + } + + // Retrieve the store from atmosConfig. + store := execCtx.AtmosConfig.Stores[params.storeName] + if store == nil { + return nil, fmt.Errorf("%w: store '%s' not found", ErrExecutionFailed, params.storeName) + } + + // Retrieve the value from the store. + value, err := retrieveFromStore(store, params) + if err != nil { + return nil, err + } + + // Execute the YQ expression if provided. 
+	if params.query != "" {
+		return utils.EvaluateYqExpression(execCtx.AtmosConfig, value, params.query)
+	}
+
+	return value, nil
+}
+
+// storeParams holds parsed parameters for the store function.
+type storeParams struct {
+	storeName    string
+	stack        string
+	component    string
+	key          string
+	query        string
+	defaultValue *string
+}
+
+// storeGetParams holds parsed parameters for the store.get function.
+type storeGetParams struct {
+	storeName    string
+	key          string
+	query        string
+	defaultValue *string
+}
+
+// parseStoreParams parses the arguments for the store function.
+func parseStoreParams(args string, currentStack string) (*storeParams, error) {
+	// Split on pipe to separate store parameters and options. NOTE(review): a literal '|' inside a quoted value would also split here — confirm inputs never contain one.
+	parts := strings.Split(args, "|")
+	storePart := strings.TrimSpace(parts[0])
+
+	// Extract default value and query from pipe parts.
+	var defaultValue *string
+	var query string
+	if len(parts) > 1 {
+		var err error
+		defaultValue, query, err = extractPipeOptions(parts[1:])
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	// Process the main store part.
+	storeParts := strings.Fields(storePart) // Fields splits on whitespace, so each element is already trimmed.
+	partsLength := len(storeParts)
+	if partsLength != 3 && partsLength != 4 {
+		return nil, fmt.Errorf("%w: store function requires 3 or 4 parameters, got %d", ErrInvalidArguments, partsLength)
+	}
+
+	params := &storeParams{
+		storeName:    strings.TrimSpace(storeParts[0]),
+		defaultValue: defaultValue,
+		query:        query,
+	}
+
+	if partsLength == 4 {
+		params.stack = strings.TrimSpace(storeParts[1])
+		params.component = strings.TrimSpace(storeParts[2])
+		params.key = strings.TrimSpace(storeParts[3])
+	} else {
+		// 3-parameter form: the stack is implied by the caller's current stack.
+		params.stack = currentStack
+		params.component = strings.TrimSpace(storeParts[1])
+		params.key = strings.TrimSpace(storeParts[2])
+	}
+
+	return params, nil
+}
+
+// parseStoreGetParams parses the arguments for the store.get function.
+func parseStoreGetParams(args string) (*storeGetParams, error) {
+	// Split on pipe to separate store parameters and options.
+	parts := strings.Split(args, "|")
+	storePart := strings.TrimSpace(parts[0])
+
+	// Extract default value and query from pipe parts.
+	var defaultValue *string
+	var query string
+	if len(parts) > 1 {
+		var err error
+		defaultValue, query, err = extractPipeOptions(parts[1:])
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	// Process the main store part.
+	storeParts := strings.Fields(storePart) // Fields splits on whitespace, so each element is already trimmed.
+	if len(storeParts) != 2 {
+		return nil, fmt.Errorf("%w: store.get function requires 2 parameters, got %d", ErrInvalidArguments, len(storeParts))
+	}
+
+	return &storeGetParams{
+		storeName:    strings.TrimSpace(storeParts[0]),
+		key:          strings.TrimSpace(storeParts[1]),
+		defaultValue: defaultValue,
+		query:        query,
+	}, nil
+}
+
+// extractPipeOptions extracts default value and query from pipe-separated parts.
+func extractPipeOptions(parts []string) (*string, string, error) {
+	var defaultValue *string
+	var query string
+
+	for _, p := range parts {
+		// Use SplitN to handle values containing spaces (e.g., query ".foo .bar").
+ pipeParts := strings.SplitN(strings.TrimSpace(p), " ", 2) + if len(pipeParts) != 2 { + return nil, "", fmt.Errorf("%w: invalid pipe parameters", ErrInvalidArguments) + } + key := strings.Trim(pipeParts[0], `"'`) + value := strings.Trim(pipeParts[1], `"'`) + switch key { + case "default": + defaultValue = &value + case "query": + query = value + default: + return nil, "", fmt.Errorf("%w: invalid pipe identifier '%s'", ErrInvalidArguments, key) + } + } + + return defaultValue, query, nil +} diff --git a/pkg/function/store_test.go b/pkg/function/store_test.go new file mode 100644 index 0000000000..edb2f936eb --- /dev/null +++ b/pkg/function/store_test.go @@ -0,0 +1,578 @@ +package function + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/cloudposse/atmos/pkg/schema" + "github.com/cloudposse/atmos/pkg/store" +) + +func TestNewStoreFunction(t *testing.T) { + fn := NewStoreFunction() + require.NotNil(t, fn) + assert.Equal(t, TagStore, fn.Name()) + assert.Equal(t, PostMerge, fn.Phase()) + assert.Nil(t, fn.Aliases()) +} + +func TestNewStoreGetFunction(t *testing.T) { + fn := NewStoreGetFunction() + require.NotNil(t, fn) + assert.Equal(t, TagStoreGet, fn.Name()) + assert.Equal(t, PostMerge, fn.Phase()) + assert.Nil(t, fn.Aliases()) +} + +func TestStoreFunction_Execute(t *testing.T) { + tests := []struct { + name string + args string + currentStack string + storeValue any + storeErr error + want any + wantErr bool + errContains string + }{ + { + name: "basic 4-param usage", + args: "mystore tenant1-ue2-dev vpc outputs", + currentStack: "default-stack", + storeValue: map[string]any{"vpc_id": "vpc-123"}, + want: map[string]any{"vpc_id": "vpc-123"}, + }, + { + name: "basic 3-param usage with current stack", + args: "mystore vpc outputs", + currentStack: "tenant1-ue2-dev", + storeValue: "simple-value", + want: "simple-value", + }, + { + name: "with default 
value on error", + args: "mystore tenant1-ue2-dev vpc outputs | default \"fallback\"", + currentStack: "default", + storeErr: errors.New("key not found"), + want: "fallback", + }, + { + name: "error without default", + args: "mystore tenant1-ue2-dev vpc outputs", + currentStack: "default", + storeErr: errors.New("key not found"), + wantErr: true, + errContains: "failed to get key", + }, + { + name: "store not found", + args: "nonexistent tenant1-ue2-dev vpc outputs", + currentStack: "default", + wantErr: true, + errContains: "store 'nonexistent' not found", + }, + { + name: "invalid argument count - too few", + args: "mystore vpc", + currentStack: "default", + wantErr: true, + errContains: "requires 3 or 4 parameters", + }, + { + name: "invalid argument count - too many", + args: "mystore a b c d e", + currentStack: "default", + wantErr: true, + errContains: "requires 3 or 4 parameters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + + // Only set up expectation if store will be found. + if tt.errContains != "store 'nonexistent' not found" && tt.errContains != "requires 3 or 4 parameters" { + mockStore.EXPECT(). + Get(gomock.Any(), gomock.Any(), gomock.Any()). + Return(tt.storeValue, tt.storeErr). + AnyTimes() + } + + atmosConfig := &schema.AtmosConfiguration{ + Stores: map[string]store.Store{ + "mystore": mockStore, + }, + } + + execCtx := &ExecutionContext{ + AtmosConfig: atmosConfig, + Stack: tt.currentStack, + } + + fn := NewStoreFunction() + result, err := fn.Execute(context.Background(), tt.args, execCtx) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, result) + }) + } +} + +func TestStoreFunction_Execute_NilContext(t *testing.T) { + fn := NewStoreFunction() + + // Test with nil context. 
+ _, err := fn.Execute(context.Background(), "mystore stack comp key", nil) + require.Error(t, err) + assert.ErrorIs(t, err, ErrExecutionFailed) + assert.Contains(t, err.Error(), "requires AtmosConfig") + + // Test with nil AtmosConfig. + execCtx := &ExecutionContext{AtmosConfig: nil} + _, err = fn.Execute(context.Background(), "mystore stack comp key", execCtx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrExecutionFailed) +} + +func TestStoreGetFunction_Execute(t *testing.T) { + tests := []struct { + name string + args string + storeValue any + storeErr error + want any + wantErr bool + errContains string + }{ + { + name: "basic usage", + args: "mystore mykey", + storeValue: "retrieved-value", + want: "retrieved-value", + }, + { + name: "with default value on error", + args: "mystore mykey | default \"fallback\"", + storeErr: errors.New("key not found"), + want: "fallback", + }, + { + name: "with default value on nil", + args: "mystore mykey | default \"fallback\"", + storeValue: nil, + want: "fallback", + }, + { + name: "error without default", + args: "mystore mykey", + storeErr: errors.New("key not found"), + wantErr: true, + errContains: "failed to get key", + }, + { + name: "store not found", + args: "nonexistent mykey", + wantErr: true, + errContains: "store 'nonexistent' not found", + }, + { + name: "invalid argument count - too few", + args: "mystore", + wantErr: true, + errContains: "requires 2 parameters", + }, + { + name: "invalid argument count - too many", + args: "mystore key1 key2", + wantErr: true, + errContains: "requires 2 parameters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + + // Only set up expectation if store will be found. + if tt.errContains != "store 'nonexistent' not found" && tt.errContains != "requires 2 parameters" { + mockStore.EXPECT(). + GetKey(gomock.Any()). 
+ Return(tt.storeValue, tt.storeErr). + AnyTimes() + } + + atmosConfig := &schema.AtmosConfiguration{ + Stores: map[string]store.Store{ + "mystore": mockStore, + }, + } + + execCtx := &ExecutionContext{ + AtmosConfig: atmosConfig, + } + + fn := NewStoreGetFunction() + result, err := fn.Execute(context.Background(), tt.args, execCtx) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, result) + }) + } +} + +func TestStoreGetFunction_Execute_NilContext(t *testing.T) { + fn := NewStoreGetFunction() + + // Test with nil context. + _, err := fn.Execute(context.Background(), "mystore mykey", nil) + require.Error(t, err) + assert.ErrorIs(t, err, ErrExecutionFailed) + assert.Contains(t, err.Error(), "requires AtmosConfig") + + // Test with nil AtmosConfig. + execCtx := &ExecutionContext{AtmosConfig: nil} + _, err = fn.Execute(context.Background(), "mystore mykey", execCtx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrExecutionFailed) +} + +func TestParseStoreParams(t *testing.T) { + tests := []struct { + name string + args string + currentStack string + wantStore string + wantStack string + wantComp string + wantKey string + wantDefault *string + wantQuery string + wantErr bool + }{ + { + name: "4 params", + args: "store1 stack1 comp1 key1", + currentStack: "current", + wantStore: "store1", + wantStack: "stack1", + wantComp: "comp1", + wantKey: "key1", + }, + { + name: "3 params uses current stack", + args: "store1 comp1 key1", + currentStack: "current", + wantStore: "store1", + wantStack: "current", + wantComp: "comp1", + wantKey: "key1", + }, + { + name: "with default value", + args: "store1 stack1 comp1 key1 | default \"mydefault\"", + currentStack: "current", + wantStore: "store1", + wantStack: "stack1", + wantComp: "comp1", + wantKey: "key1", + wantDefault: strPtr("mydefault"), + }, + { + name: "with query", + args: 
"store1 stack1 comp1 key1 | query \".foo.bar\"", + currentStack: "current", + wantStore: "store1", + wantStack: "stack1", + wantComp: "comp1", + wantKey: "key1", + wantQuery: ".foo.bar", + }, + { + name: "with both default and query", + args: "store1 stack1 comp1 key1 | default \"def\" | query \".x\"", + currentStack: "current", + wantStore: "store1", + wantStack: "stack1", + wantComp: "comp1", + wantKey: "key1", + wantDefault: strPtr("def"), + wantQuery: ".x", + }, + { + name: "too few params", + args: "store1 comp1", + currentStack: "current", + wantErr: true, + }, + { + name: "too many params", + args: "store1 a b c d", + currentStack: "current", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params, err := parseStoreParams(tt.args, tt.currentStack) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantStore, params.storeName) + assert.Equal(t, tt.wantStack, params.stack) + assert.Equal(t, tt.wantComp, params.component) + assert.Equal(t, tt.wantKey, params.key) + assert.Equal(t, tt.wantQuery, params.query) + + if tt.wantDefault == nil { + assert.Nil(t, params.defaultValue) + } else { + require.NotNil(t, params.defaultValue) + assert.Equal(t, *tt.wantDefault, *params.defaultValue) + } + }) + } +} + +func TestParseStoreGetParams(t *testing.T) { + tests := []struct { + name string + args string + wantStore string + wantKey string + wantDefault *string + wantQuery string + wantErr bool + }{ + { + name: "basic 2 params", + args: "store1 key1", + wantStore: "store1", + wantKey: "key1", + }, + { + name: "with default", + args: "store1 key1 | default \"fallback\"", + wantStore: "store1", + wantKey: "key1", + wantDefault: strPtr("fallback"), + }, + { + name: "with query", + args: "store1 key1 | query \".path\"", + wantStore: "store1", + wantKey: "key1", + wantQuery: ".path", + }, + { + name: "too few params", + args: "store1", + wantErr: true, + }, + { + name: "too 
many params", + args: "store1 key1 extra", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params, err := parseStoreGetParams(tt.args) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantStore, params.storeName) + assert.Equal(t, tt.wantKey, params.key) + assert.Equal(t, tt.wantQuery, params.query) + + if tt.wantDefault == nil { + assert.Nil(t, params.defaultValue) + } else { + require.NotNil(t, params.defaultValue) + assert.Equal(t, *tt.wantDefault, *params.defaultValue) + } + }) + } +} + +func TestExtractPipeOptions(t *testing.T) { + tests := []struct { + name string + parts []string + wantDefault *string + wantQuery string + wantErr bool + }{ + { + name: "empty parts", + parts: []string{}, + }, + { + name: "default only", + parts: []string{"default \"value\""}, + wantDefault: strPtr("value"), + }, + { + name: "query only", + parts: []string{"query \".foo\""}, + wantQuery: ".foo", + }, + { + name: "both default and query", + parts: []string{"default \"val\"", "query \".bar\""}, + wantDefault: strPtr("val"), + wantQuery: ".bar", + }, + { + name: "invalid - no value", + parts: []string{"default"}, + wantErr: true, + }, + { + name: "invalid - unknown key", + parts: []string{"unknown \"value\""}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defVal, query, err := extractPipeOptions(tt.parts) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantQuery, query) + + if tt.wantDefault == nil { + assert.Nil(t, defVal) + } else { + require.NotNil(t, defVal) + assert.Equal(t, *tt.wantDefault, *defVal) + } + }) + } +} + +// strPtr is a helper to create a pointer to a string. 
+func strPtr(s string) *string { + return &s +} + +func TestRetrieveFromStore(t *testing.T) { + tests := []struct { + name string + storeValue any + storeErr error + hasDefault bool + defaultVal string + want any + wantErr bool + errContains string + }{ + { + name: "successful retrieval", + storeValue: "stored-value", + want: "stored-value", + }, + { + name: "error with default", + storeErr: errors.New("not found"), + hasDefault: true, + defaultVal: "default-val", + want: "default-val", + }, + { + name: "error without default", + storeErr: errors.New("not found"), + wantErr: true, + errContains: "failed to get key", + }, + { + name: "nil value with default", + storeValue: nil, + hasDefault: true, + defaultVal: "default-for-nil", + want: "default-for-nil", + }, + { + name: "nil value without default", + storeValue: nil, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetKey(gomock.Any()). + Return(tt.storeValue, tt.storeErr). + AnyTimes() + + params := &storeGetParams{ + storeName: "test", + key: "testkey", + } + if tt.hasDefault { + params.defaultValue = &tt.defaultVal + } + + result, err := retrieveFromStore(mockStore, params) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, result) + }) + } +} diff --git a/pkg/function/tags.go b/pkg/function/tags.go new file mode 100644 index 0000000000..e0cc8a8d6d --- /dev/null +++ b/pkg/function/tags.go @@ -0,0 +1,129 @@ +package function + +import ( + "github.com/cloudposse/atmos/pkg/perf" +) + +// Tag constants for Atmos configuration functions. +// These are the canonical function names used across all formats. +// In YAML they appear as !tag, in HCL as tag(), etc. 
+const ( + // TagExec executes a shell command and returns the output. + TagExec = "exec" + + // TagStore retrieves a value from a configured store. + TagStore = "store" + + // TagStoreGet retrieves a value from a configured store (alternative syntax). + TagStoreGet = "store.get" + + // TagTemplate processes a JSON template. + TagTemplate = "template" + + // TagTerraformOutput retrieves a Terraform output value. + TagTerraformOutput = "terraform.output" + + // TagTerraformState retrieves a value from Terraform state. + TagTerraformState = "terraform.state" + + // TagEnv retrieves an environment variable value. + TagEnv = "env" + + // TagInclude includes content from another file. + TagInclude = "include" + + // TagIncludeRaw includes raw content from another file. + TagIncludeRaw = "include.raw" + + // TagRepoRoot returns the git repository root path. + TagRepoRoot = "repo-root" + + // TagRandom generates a random number. + TagRandom = "random" + + // TagLiteral preserves values exactly as written, bypassing template processing. + TagLiteral = "literal" + + // TagAwsAccountID returns the AWS account ID. + TagAwsAccountID = "aws.account_id" + + // TagAwsCallerIdentityArn returns the AWS caller identity ARN. + TagAwsCallerIdentityArn = "aws.caller_identity_arn" + + // TagAwsCallerIdentityUserID returns the AWS caller identity user ID. + TagAwsCallerIdentityUserID = "aws.caller_identity_user_id" + + // TagAwsRegion returns the AWS region. + TagAwsRegion = "aws.region" +) + +// YAMLTagPrefix is the prefix used for YAML custom tags. +const YAMLTagPrefix = "!" + +// AllTags returns all registered tag names. 
+func AllTags() []string { + defer perf.Track(nil, "function.AllTags")() + + return []string{ + TagExec, + TagStore, + TagStoreGet, + TagTemplate, + TagTerraformOutput, + TagTerraformState, + TagEnv, + TagInclude, + TagIncludeRaw, + TagRepoRoot, + TagRandom, + TagLiteral, + TagAwsAccountID, + TagAwsCallerIdentityArn, + TagAwsCallerIdentityUserID, + TagAwsRegion, + } +} + +// tagsMap provides O(1) lookup for tag names. +var tagsMap = map[string]bool{ + TagExec: true, + TagStore: true, + TagStoreGet: true, + TagTemplate: true, + TagTerraformOutput: true, + TagTerraformState: true, + TagEnv: true, + TagInclude: true, + TagIncludeRaw: true, + TagRepoRoot: true, + TagRandom: true, + TagLiteral: true, + TagAwsAccountID: true, + TagAwsCallerIdentityArn: true, + TagAwsCallerIdentityUserID: true, + TagAwsRegion: true, +} + +// IsValidTag checks if the given tag name is registered. +func IsValidTag(tag string) bool { + defer perf.Track(nil, "function.IsValidTag")() + + return tagsMap[tag] +} + +// YAMLTag returns the YAML tag format for a function name (e.g., "env" -> "!env"). +func YAMLTag(name string) string { + defer perf.Track(nil, "function.YAMLTag")() + + return YAMLTagPrefix + name +} + +// FromYAMLTag extracts the function name from a YAML tag (e.g., "!env" -> "env"). +func FromYAMLTag(tag string) string { + defer perf.Track(nil, "function.FromYAMLTag")() + + if len(tag) > 0 && tag[0] == '!' { + return tag[1:] + } + return tag +} diff --git a/pkg/function/tags_test.go b/pkg/function/tags_test.go new file mode 100644 index 0000000000..a46de564ae --- /dev/null +++ b/pkg/function/tags_test.go @@ -0,0 +1,180 @@ +package function + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAllTags(t *testing.T) { + tags := AllTags() + + // Verify all expected tags are present. 
+ expectedTags := []string{ + TagExec, + TagStore, + TagStoreGet, + TagTemplate, + TagTerraformOutput, + TagTerraformState, + TagEnv, + TagInclude, + TagIncludeRaw, + TagRepoRoot, + TagRandom, + TagLiteral, + TagAwsAccountID, + TagAwsCallerIdentityArn, + TagAwsCallerIdentityUserID, + TagAwsRegion, + } + + assert.Equal(t, len(expectedTags), len(tags)) + + for _, expected := range expectedTags { + assert.Contains(t, tags, expected, "expected tag %s to be in AllTags()", expected) + } +} + +func TestIsValidTag(t *testing.T) { + // Verify all expected tags are valid. + expectedTags := []string{ + TagExec, + TagStore, + TagStoreGet, + TagTemplate, + TagTerraformOutput, + TagTerraformState, + TagEnv, + TagInclude, + TagIncludeRaw, + TagRepoRoot, + TagRandom, + TagLiteral, + TagAwsAccountID, + TagAwsCallerIdentityArn, + TagAwsCallerIdentityUserID, + TagAwsRegion, + } + + for _, tag := range expectedTags { + assert.True(t, IsValidTag(tag), "expected tag %s to be valid", tag) + } + + // Verify non-existent tag returns false. 
+ assert.False(t, IsValidTag("non-existent-tag")) +} + +func TestYAMLTag(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {TagEnv, "!env"}, + {TagExec, "!exec"}, + {TagStore, "!store"}, + {TagStoreGet, "!store.get"}, + {TagTemplate, "!template"}, + {TagTerraformOutput, "!terraform.output"}, + {TagTerraformState, "!terraform.state"}, + {TagInclude, "!include"}, + {TagIncludeRaw, "!include.raw"}, + {TagRepoRoot, "!repo-root"}, + {TagRandom, "!random"}, + {TagLiteral, "!literal"}, + {TagAwsAccountID, "!aws.account_id"}, + {TagAwsCallerIdentityArn, "!aws.caller_identity_arn"}, + {TagAwsCallerIdentityUserID, "!aws.caller_identity_user_id"}, + {TagAwsRegion, "!aws.region"}, + {"custom", "!custom"}, + {"", "!"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := YAMLTag(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestFromYAMLTag(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"!env", "env"}, + {"!exec", "exec"}, + {"!store", "store"}, + {"!store.get", "store.get"}, + {"!template", "template"}, + {"!terraform.output", "terraform.output"}, + {"!terraform.state", "terraform.state"}, + {"!include", "include"}, + {"!include.raw", "include.raw"}, + {"!repo-root", "repo-root"}, + {"!random", "random"}, + {"!literal", "literal"}, + {"!aws.account_id", "aws.account_id"}, + {"!aws.caller_identity_arn", "aws.caller_identity_arn"}, + {"!aws.caller_identity_user_id", "aws.caller_identity_user_id"}, + {"!aws.region", "aws.region"}, + {"!custom", "custom"}, + // Without prefix - returns as-is. + {"env", "env"}, + {"store", "store"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := FromYAMLTag(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestYAMLTagPrefix(t *testing.T) { + assert.Equal(t, "!", YAMLTagPrefix) +} + +func TestTagConstants(t *testing.T) { + // Verify tag constant values. 
+ assert.Equal(t, "exec", TagExec) + assert.Equal(t, "store", TagStore) + assert.Equal(t, "store.get", TagStoreGet) + assert.Equal(t, "template", TagTemplate) + assert.Equal(t, "terraform.output", TagTerraformOutput) + assert.Equal(t, "terraform.state", TagTerraformState) + assert.Equal(t, "env", TagEnv) + assert.Equal(t, "include", TagInclude) + assert.Equal(t, "include.raw", TagIncludeRaw) + assert.Equal(t, "repo-root", TagRepoRoot) + assert.Equal(t, "random", TagRandom) + assert.Equal(t, "literal", TagLiteral) + assert.Equal(t, "aws.account_id", TagAwsAccountID) + assert.Equal(t, "aws.caller_identity_arn", TagAwsCallerIdentityArn) + assert.Equal(t, "aws.caller_identity_user_id", TagAwsCallerIdentityUserID) + assert.Equal(t, "aws.region", TagAwsRegion) +} + +func TestYAMLTag_RoundTrip(t *testing.T) { + // Test that YAMLTag and FromYAMLTag are inverse operations. + tags := AllTags() + + for _, tag := range tags { + yamlTag := YAMLTag(tag) + recovered := FromYAMLTag(yamlTag) + assert.Equal(t, tag, recovered, "round-trip failed for tag %s", tag) + } +} + +func TestIsValidTag_Consistency(t *testing.T) { + // Verify IsValidTag is consistent with AllTags. + tags := AllTags() + + // All tags in AllTags should be valid. + for _, tag := range tags { + assert.True(t, IsValidTag(tag), "tag %s in AllTags() but IsValidTag returns false", tag) + } +} diff --git a/pkg/function/template.go b/pkg/function/template.go new file mode 100644 index 0000000000..1b3b94088e --- /dev/null +++ b/pkg/function/template.go @@ -0,0 +1,118 @@ +package function + +import ( + "context" + "encoding/json" + "strings" + + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" +) + +// TemplateFunction implements the template function for JSON template processing. +type TemplateFunction struct { + BaseFunction +} + +// NewTemplateFunction creates a new template function handler. 
+func NewTemplateFunction() *TemplateFunction { + defer perf.Track(nil, "function.NewTemplateFunction")() + + return &TemplateFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagTemplate, + FunctionAliases: nil, + FunctionPhase: PreMerge, + }, + } +} + +// Execute processes the template function. +// Usage: +// +// !template {"key": "value"} - Parse JSON and return as native type +// !template [1, 2, 3] - Parse JSON array and return as native slice +// +// If the input is valid JSON, it will be parsed and returned as the corresponding type. +// Otherwise, the raw string is returned. +func (f *TemplateFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.TemplateFunction.Execute")() + + log.Debug("Executing template function", "args", args) + + args = strings.TrimSpace(args) + if args == "" { + return "", nil + } + + // Try to parse as JSON. + var decoded any + if err := json.Unmarshal([]byte(args), &decoded); err != nil { + // Not valid JSON, return as-is. + return args, nil + } + + return decoded, nil +} + +// processTemplateString processes a string that may contain a !template tag. +func processTemplateString(s string, templatePrefix string) any { + if !strings.HasPrefix(s, templatePrefix) { + return s + } + // Extract args after the tag. + args := strings.TrimPrefix(s, templatePrefix) + args = strings.TrimSpace(args) + + // Parse as JSON if possible. + var decoded any + if err := json.Unmarshal([]byte(args), &decoded); err != nil { + return args + } + return decoded +} + +// ProcessTemplateTagsOnly processes only !template tags in a data structure, recursively. +// It is used before merging to ensure !template strings are decoded to their actual types. +// This avoids type conflicts during merge (e.g., string vs list). 
+func ProcessTemplateTagsOnly(input map[string]any) map[string]any { + defer perf.Track(nil, "function.ProcessTemplateTagsOnly")() + + if input == nil { + return nil + } + + templatePrefix := YAMLTag(TagTemplate) + result := make(map[string]any, len(input)) + + var recurse func(any) any + recurse = func(node any) any { + switch v := node.(type) { + case string: + return processTemplateString(v, templatePrefix) + + case map[string]any: + newMap := make(map[string]any, len(v)) + for k, val := range v { + newMap[k] = recurse(val) + } + return newMap + + case []any: + newSlice := make([]any, len(v)) + for i, val := range v { + newSlice[i] = recurse(val) + } + return newSlice + + default: + return v + } + } + + for k, v := range input { + result[k] = recurse(v) + } + + return result +} diff --git a/pkg/function/template_test.go b/pkg/function/template_test.go new file mode 100644 index 0000000000..6a593c22ac --- /dev/null +++ b/pkg/function/template_test.go @@ -0,0 +1,298 @@ +package function + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewTemplateFunction(t *testing.T) { + fn := NewTemplateFunction() + require.NotNil(t, fn) + assert.Equal(t, TagTemplate, fn.Name()) + assert.Equal(t, PreMerge, fn.Phase()) + assert.Nil(t, fn.Aliases()) +} + +func TestTemplateFunction_Execute_EmptyString(t *testing.T) { + fn := NewTemplateFunction() + + result, err := fn.Execute(context.Background(), "", nil) + require.NoError(t, err) + assert.Equal(t, "", result) +} + +func TestTemplateFunction_Execute_WhitespaceOnly(t *testing.T) { + fn := NewTemplateFunction() + + result, err := fn.Execute(context.Background(), " ", nil) + require.NoError(t, err) + assert.Equal(t, "", result) +} + +func TestTemplateFunction_Execute_JSONString(t *testing.T) { + fn := NewTemplateFunction() + + // JSON string value. 
+ result, err := fn.Execute(context.Background(), `"hello"`, nil) + require.NoError(t, err) + assert.Equal(t, "hello", result) +} + +func TestTemplateFunction_Execute_JSONNumber(t *testing.T) { + fn := NewTemplateFunction() + + // JSON number. + result, err := fn.Execute(context.Background(), `42`, nil) + require.NoError(t, err) + assert.Equal(t, float64(42), result) +} + +func TestTemplateFunction_Execute_JSONBoolean(t *testing.T) { + fn := NewTemplateFunction() + + // JSON boolean true. + result, err := fn.Execute(context.Background(), `true`, nil) + require.NoError(t, err) + assert.Equal(t, true, result) + + // JSON boolean false. + result, err = fn.Execute(context.Background(), `false`, nil) + require.NoError(t, err) + assert.Equal(t, false, result) +} + +func TestTemplateFunction_Execute_JSONNull(t *testing.T) { + fn := NewTemplateFunction() + + // JSON null. + result, err := fn.Execute(context.Background(), `null`, nil) + require.NoError(t, err) + assert.Nil(t, result) +} + +func TestTemplateFunction_Execute_NestedJSON(t *testing.T) { + fn := NewTemplateFunction() + + input := `{"outer": {"inner": {"value": 123}}}` + result, err := fn.Execute(context.Background(), input, nil) + require.NoError(t, err) + + m, ok := result.(map[string]any) + require.True(t, ok) + + outer, ok := m["outer"].(map[string]any) + require.True(t, ok) + + inner, ok := outer["inner"].(map[string]any) + require.True(t, ok) + + assert.Equal(t, float64(123), inner["value"]) +} + +func TestTemplateFunction_Execute_InvalidJSON(t *testing.T) { + fn := NewTemplateFunction() + + // Invalid JSON should be returned as-is. 
+ result, err := fn.Execute(context.Background(), `{invalid json}`, nil) + require.NoError(t, err) + assert.Equal(t, `{invalid json}`, result) +} + +func TestProcessTemplateTagsOnly_Nil(t *testing.T) { + result := ProcessTemplateTagsOnly(nil) + assert.Nil(t, result) +} + +func TestProcessTemplateTagsOnly_Empty(t *testing.T) { + result := ProcessTemplateTagsOnly(map[string]any{}) + assert.Empty(t, result) +} + +func TestProcessTemplateTagsOnly_NoTemplates(t *testing.T) { + input := map[string]any{ + "key1": "value1", + "key2": 42, + "key3": true, + } + + result := ProcessTemplateTagsOnly(input) + + assert.Equal(t, "value1", result["key1"]) + assert.Equal(t, 42, result["key2"]) + assert.Equal(t, true, result["key3"]) +} + +func TestProcessTemplateTagsOnly_WithTemplateTag(t *testing.T) { + input := map[string]any{ + "regular": "value", + "templated": "!template [1, 2, 3]", + } + + result := ProcessTemplateTagsOnly(input) + + assert.Equal(t, "value", result["regular"]) + + arr, ok := result["templated"].([]any) + require.True(t, ok) + assert.Len(t, arr, 3) + assert.Equal(t, float64(1), arr[0]) + assert.Equal(t, float64(2), arr[1]) + assert.Equal(t, float64(3), arr[2]) +} + +func TestProcessTemplateTagsOnly_NestedMap(t *testing.T) { + input := map[string]any{ + "parent": map[string]any{ + "child": "!template {\"nested\": true}", + }, + } + + result := ProcessTemplateTagsOnly(input) + + parent, ok := result["parent"].(map[string]any) + require.True(t, ok) + + child, ok := parent["child"].(map[string]any) + require.True(t, ok) + + assert.Equal(t, true, child["nested"]) +} + +func TestProcessTemplateTagsOnly_NestedSlice(t *testing.T) { + input := map[string]any{ + "items": []any{ + "!template [\"a\", \"b\"]", + "regular", + }, + } + + result := ProcessTemplateTagsOnly(input) + + items, ok := result["items"].([]any) + require.True(t, ok) + require.Len(t, items, 2) + + // First item should be parsed as array. 
+ arr, ok := items[0].([]any) + require.True(t, ok) + assert.Equal(t, "a", arr[0]) + assert.Equal(t, "b", arr[1]) + + // Second item should remain as string. + assert.Equal(t, "regular", items[1]) +} + +func TestProcessTemplateTagsOnly_InvalidTemplateJSON(t *testing.T) { + input := map[string]any{ + "invalid": "!template {not valid json}", + } + + result := ProcessTemplateTagsOnly(input) + + // Invalid JSON should return the args portion as-is. + assert.Equal(t, "{not valid json}", result["invalid"]) +} + +func TestProcessTemplateTagsOnly_OtherTags(t *testing.T) { + input := map[string]any{ + "env_var": "!env MY_VAR", + "exec_cmd": "!exec echo hello", + "store": "!store mystore stack comp key", + } + + result := ProcessTemplateTagsOnly(input) + + // Other tags should remain untouched. + assert.Equal(t, "!env MY_VAR", result["env_var"]) + assert.Equal(t, "!exec echo hello", result["exec_cmd"]) + assert.Equal(t, "!store mystore stack comp key", result["store"]) +} + +func TestProcessTemplateTagsOnly_MixedContent(t *testing.T) { + input := map[string]any{ + "string": "plain string", + "number": 42, + "bool": true, + "nil_value": nil, + "template_obj": "!template {\"key\": \"value\"}", + "template_arr": "!template [1, 2, 3]", + "env_tag": "!env HOME", + "nested": map[string]any{ + "deep_template": "!template {\"deep\": true}", + "deep_string": "just a string", + }, + "list": []any{ + "!template {\"in_list\": true}", + "regular item", + 42, + }, + } + + result := ProcessTemplateTagsOnly(input) + + // Regular values unchanged. + assert.Equal(t, "plain string", result["string"]) + assert.Equal(t, 42, result["number"]) + assert.Equal(t, true, result["bool"]) + assert.Nil(t, result["nil_value"]) + + // Template tags processed. + obj, ok := result["template_obj"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "value", obj["key"]) + + arr, ok := result["template_arr"].([]any) + require.True(t, ok) + assert.Len(t, arr, 3) + + // Other tags preserved. 
+ assert.Equal(t, "!env HOME", result["env_tag"]) + + // Nested map processed. + nested, ok := result["nested"].(map[string]any) + require.True(t, ok) + deepTemplate, ok := nested["deep_template"].(map[string]any) + require.True(t, ok) + assert.Equal(t, true, deepTemplate["deep"]) + assert.Equal(t, "just a string", nested["deep_string"]) + + // List processed. + list, ok := result["list"].([]any) + require.True(t, ok) + listItem, ok := list[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, true, listItem["in_list"]) + assert.Equal(t, "regular item", list[1]) + assert.Equal(t, 42, list[2]) +} + +func TestProcessTemplateTagsOnly_EmptyTemplate(t *testing.T) { + input := map[string]any{ + "empty": "!template ", + } + + result := ProcessTemplateTagsOnly(input) + + // Empty template args should be returned as empty string. + assert.Equal(t, "", result["empty"]) +} + +func TestProcessTemplateTagsOnly_WhitespaceTemplate(t *testing.T) { + input := map[string]any{ + "whitespace": "!template ", + } + + result := ProcessTemplateTagsOnly(input) + + // Whitespace-only template args should be trimmed to empty string. + assert.Equal(t, "", result["whitespace"]) +} + +func TestYAMLTag_Template(t *testing.T) { + // Verify the tag format is correct. + expected := "!template" + assert.Equal(t, expected, YAMLTag(TagTemplate)) +} diff --git a/pkg/function/terraform.go b/pkg/function/terraform.go new file mode 100644 index 0000000000..e994c22e27 --- /dev/null +++ b/pkg/function/terraform.go @@ -0,0 +1,153 @@ +package function + +import ( + "context" + "fmt" + "strings" + + log "github.com/cloudposse/atmos/pkg/logger" + "github.com/cloudposse/atmos/pkg/perf" + "github.com/cloudposse/atmos/pkg/utils" +) + +// terraformArgs holds parsed terraform function arguments. +type terraformArgs struct { + component string + stack string + output string +} + +// parseTerraformArgs parses terraform function arguments (component, stack, output). 
+// Arguments can be either 2 or 3 parts: +// - 2 parts: component output_name (stack from context) +// - 3 parts: component stack output_name +func parseTerraformArgs(args string, execCtx *ExecutionContext) (*terraformArgs, error) { + parts, err := utils.SplitStringByDelimiter(args, ' ') + if err != nil { + return nil, err + } + + var component, stack, output string + + switch len(parts) { + case 3: + component = strings.TrimSpace(parts[0]) + stack = strings.TrimSpace(parts[1]) + output = strings.TrimSpace(parts[2]) + case 2: + component = strings.TrimSpace(parts[0]) + stack = execCtx.Stack + output = strings.TrimSpace(parts[1]) + default: + return nil, fmt.Errorf("%w: terraform function requires 2 or 3 arguments, got %d", ErrInvalidArguments, len(parts)) + } + + return &terraformArgs{component: component, stack: stack, output: output}, nil +} + +// TerraformOutputFunction implements the terraform.output function. +type TerraformOutputFunction struct { + BaseFunction +} + +// NewTerraformOutputFunction creates a new terraform.output function handler. +func NewTerraformOutputFunction() *TerraformOutputFunction { + defer perf.Track(nil, "function.NewTerraformOutputFunction")() + + return &TerraformOutputFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagTerraformOutput, + FunctionAliases: nil, + FunctionPhase: PostMerge, + }, + } +} + +// Execute processes the terraform.output function. +// Usage: +// +// !terraform.output component output_name +// !terraform.output component stack output_name +// +// Note: This is a placeholder that parses args. The actual terraform output +// retrieval is handled by internal/exec which has the full implementation. 
+func (f *TerraformOutputFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.TerraformOutputFunction.Execute")() + + log.Debug("Executing terraform.output function", "args", args) + + if execCtx == nil || execCtx.AtmosConfig == nil { + return nil, fmt.Errorf("%w: terraform.output function requires AtmosConfig", ErrExecutionFailed) + } + + args = strings.TrimSpace(args) + if args == "" { + return nil, fmt.Errorf("%w: terraform.output function requires arguments", ErrInvalidArguments) + } + + // Parse arguments. + parsed, err := parseTerraformArgs(args, execCtx) + if err != nil { + return nil, err + } + + log.Debug("Parsed terraform.output args", "component", parsed.component, "stack", parsed.stack, "output", parsed.output) + + // TODO: The actual implementation requires outputGetter.GetOutput and other + // helpers from internal/exec. For now, return a placeholder error. + // The migration will update internal/exec to call this function. + return nil, fmt.Errorf("%w: terraform.output not yet fully migrated: component=%s stack=%s output=%s", ErrExecutionFailed, parsed.component, parsed.stack, parsed.output) +} + +// TerraformStateFunction implements the terraform.state function. +type TerraformStateFunction struct { + BaseFunction +} + +// NewTerraformStateFunction creates a new terraform.state function handler. +func NewTerraformStateFunction() *TerraformStateFunction { + defer perf.Track(nil, "function.NewTerraformStateFunction")() + + return &TerraformStateFunction{ + BaseFunction: BaseFunction{ + FunctionName: TagTerraformState, + FunctionAliases: nil, + FunctionPhase: PostMerge, + }, + } +} + +// Execute processes the terraform.state function. +// Usage: +// +// !terraform.state component output_name +// !terraform.state component stack output_name +// +// Note: This is a placeholder that parses args. 
The actual terraform state +// retrieval is handled by internal/exec which has the full implementation. +func (f *TerraformStateFunction) Execute(ctx context.Context, args string, execCtx *ExecutionContext) (any, error) { + defer perf.Track(nil, "function.TerraformStateFunction.Execute")() + + log.Debug("Executing terraform.state function", "args", args) + + if execCtx == nil || execCtx.AtmosConfig == nil { + return nil, fmt.Errorf("%w: terraform.state function requires AtmosConfig", ErrExecutionFailed) + } + + args = strings.TrimSpace(args) + if args == "" { + return nil, fmt.Errorf("%w: terraform.state function requires arguments", ErrInvalidArguments) + } + + // Parse arguments. + parsed, err := parseTerraformArgs(args, execCtx) + if err != nil { + return nil, err + } + + log.Debug("Parsed terraform.state args", "component", parsed.component, "stack", parsed.stack, "output", parsed.output) + + // TODO: The actual implementation requires state retrieval helpers from internal/exec. + // For now, return a placeholder error. 
+ return nil, fmt.Errorf("%w: terraform.state not yet fully migrated: component=%s stack=%s output=%s", ErrExecutionFailed, parsed.component, parsed.stack, parsed.output) +} diff --git a/pkg/function/terraform_test.go b/pkg/function/terraform_test.go new file mode 100644 index 0000000000..21334a2335 --- /dev/null +++ b/pkg/function/terraform_test.go @@ -0,0 +1,244 @@ +package function + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudposse/atmos/pkg/schema" +) + +func TestNewTerraformOutputFunction(t *testing.T) { + fn := NewTerraformOutputFunction() + require.NotNil(t, fn) + assert.Equal(t, TagTerraformOutput, fn.Name()) + assert.Equal(t, PostMerge, fn.Phase()) + assert.Nil(t, fn.Aliases()) +} + +func TestNewTerraformStateFunction(t *testing.T) { + fn := NewTerraformStateFunction() + require.NotNil(t, fn) + assert.Equal(t, TagTerraformState, fn.Name()) + assert.Equal(t, PostMerge, fn.Phase()) + assert.Nil(t, fn.Aliases()) +} + +func TestParseTerraformArgs(t *testing.T) { + tests := []struct { + name string + args string + contextStack string + wantComponent string + wantStack string + wantOutput string + wantErr bool + errContains string + }{ + { + name: "three parts - component stack output", + args: "vpc tenant1-ue2-dev vpc_id", + contextStack: "default", + wantComponent: "vpc", + wantStack: "tenant1-ue2-dev", + wantOutput: "vpc_id", + }, + { + name: "two parts - component output uses context stack", + args: "vpc vpc_id", + contextStack: "tenant1-ue2-prod", + wantComponent: "vpc", + wantStack: "tenant1-ue2-prod", + wantOutput: "vpc_id", + }, + { + name: "with extra whitespace", + args: " vpc tenant1-ue2-dev vpc_id ", + contextStack: "default", + wantComponent: "vpc", + wantStack: "tenant1-ue2-dev", + wantOutput: "vpc_id", + }, + { + name: "too few arguments", + args: "vpc", + contextStack: "default", + wantErr: true, + errContains: "requires 2 or 3 arguments", + }, + { + 
name: "too many arguments", + args: "vpc stack output extra", + contextStack: "default", + wantErr: true, + errContains: "requires 2 or 3 arguments", + }, + { + name: "empty args", + args: "", + contextStack: "default", + wantErr: true, + errContains: "", // Empty args return EOF error from parser. + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + execCtx := &ExecutionContext{ + Stack: tt.contextStack, + } + + parsed, err := parseTerraformArgs(tt.args, execCtx) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + require.NotNil(t, parsed) + assert.Equal(t, tt.wantComponent, parsed.component) + assert.Equal(t, tt.wantStack, parsed.stack) + assert.Equal(t, tt.wantOutput, parsed.output) + }) + } +} + +func TestTerraformOutputFunction_Execute_NilContext(t *testing.T) { + fn := NewTerraformOutputFunction() + + // Test with nil execution context. + _, err := fn.Execute(context.Background(), "vpc vpc_id", nil) + require.Error(t, err) + assert.ErrorIs(t, err, ErrExecutionFailed) + assert.Contains(t, err.Error(), "requires AtmosConfig") + + // Test with nil AtmosConfig. + execCtx := &ExecutionContext{AtmosConfig: nil} + _, err = fn.Execute(context.Background(), "vpc vpc_id", execCtx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrExecutionFailed) +} + +func TestTerraformOutputFunction_Execute_EmptyArgs(t *testing.T) { + fn := NewTerraformOutputFunction() + execCtx := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + Stack: "test-stack", + } + + // Test with empty args. 
+ _, err := fn.Execute(context.Background(), "", execCtx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidArguments) + assert.Contains(t, err.Error(), "requires arguments") +} + +func TestTerraformOutputFunction_Execute_InvalidArgs(t *testing.T) { + fn := NewTerraformOutputFunction() + execCtx := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + Stack: "test-stack", + } + + // Test with single argument. + _, err := fn.Execute(context.Background(), "vpc", execCtx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidArguments) +} + +func TestTerraformOutputFunction_Execute_NotMigrated(t *testing.T) { + // This tests the current placeholder implementation. + fn := NewTerraformOutputFunction() + execCtx := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + Stack: "tenant1-ue2-dev", + } + + // Execute with valid args - should return "not yet migrated" error. + _, err := fn.Execute(context.Background(), "vpc tenant1-ue2-dev vpc_id", execCtx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrExecutionFailed) + assert.Contains(t, err.Error(), "not yet fully migrated") + assert.Contains(t, err.Error(), "component=vpc") + assert.Contains(t, err.Error(), "stack=tenant1-ue2-dev") + assert.Contains(t, err.Error(), "output=vpc_id") +} + +func TestTerraformStateFunction_Execute_NilContext(t *testing.T) { + fn := NewTerraformStateFunction() + + // Test with nil execution context. + _, err := fn.Execute(context.Background(), "vpc vpc_id", nil) + require.Error(t, err) + assert.ErrorIs(t, err, ErrExecutionFailed) + assert.Contains(t, err.Error(), "requires AtmosConfig") + + // Test with nil AtmosConfig. 
+ execCtx := &ExecutionContext{AtmosConfig: nil} + _, err = fn.Execute(context.Background(), "vpc vpc_id", execCtx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrExecutionFailed) +} + +func TestTerraformStateFunction_Execute_EmptyArgs(t *testing.T) { + fn := NewTerraformStateFunction() + execCtx := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + Stack: "test-stack", + } + + // Test with empty args. + _, err := fn.Execute(context.Background(), "", execCtx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidArguments) + assert.Contains(t, err.Error(), "requires arguments") +} + +func TestTerraformStateFunction_Execute_InvalidArgs(t *testing.T) { + fn := NewTerraformStateFunction() + execCtx := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + Stack: "test-stack", + } + + // Test with single argument. + _, err := fn.Execute(context.Background(), "vpc", execCtx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidArguments) +} + +func TestTerraformStateFunction_Execute_NotMigrated(t *testing.T) { + // This tests the current placeholder implementation. + fn := NewTerraformStateFunction() + execCtx := &ExecutionContext{ + AtmosConfig: &schema.AtmosConfiguration{}, + Stack: "tenant1-ue2-prod", + } + + // Execute with valid args - should return "not yet migrated" error. 
+ _, err := fn.Execute(context.Background(), "eks cluster_name", execCtx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrExecutionFailed) + assert.Contains(t, err.Error(), "not yet fully migrated") + assert.Contains(t, err.Error(), "component=eks") + assert.Contains(t, err.Error(), "stack=tenant1-ue2-prod") + assert.Contains(t, err.Error(), "output=cluster_name") +} + +func TestTerraformArgs_Struct(t *testing.T) { + args := &terraformArgs{ + component: "vpc", + stack: "tenant1-ue2-dev", + output: "vpc_id", + } + + assert.Equal(t, "vpc", args.component) + assert.Equal(t, "tenant1-ue2-dev", args.stack) + assert.Equal(t, "vpc_id", args.output) +} diff --git a/pkg/provisioner/backend/s3.go b/pkg/provisioner/backend/s3.go index 0466356c14..eba25c4ad4 100644 --- a/pkg/provisioner/backend/s3.go +++ b/pkg/provisioner/backend/s3.go @@ -13,7 +13,7 @@ import ( "github.com/aws/smithy-go" errUtils "github.com/cloudposse/atmos/errors" - "github.com/cloudposse/atmos/internal/aws_utils" + awsIdentity "github.com/cloudposse/atmos/pkg/aws/identity" "github.com/cloudposse/atmos/pkg/perf" "github.com/cloudposse/atmos/pkg/schema" "github.com/cloudposse/atmos/pkg/ui" @@ -171,7 +171,7 @@ func loadAWSConfigWithAuth(ctx context.Context, region, roleArn string, authCont assumeRoleDuration := 1 * time.Hour // Load AWS config with auth context and optional role assumption. - return aws_utils.LoadAWSConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, awsAuthContext) + return awsIdentity.LoadConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, awsAuthContext) } // bucketExists checks if an S3 bucket exists. diff --git a/pkg/store/mock_store.go b/pkg/store/mock_store.go new file mode 100644 index 0000000000..b44f8343d6 --- /dev/null +++ b/pkg/store/mock_store.go @@ -0,0 +1,84 @@ +// Code generated by MockGen. DO NOT EDIT. 
+// Source: store.go +// +// Generated by this command: +// +// mockgen -source=store.go -destination=mock_store.go -package=store +// + +// Package store is a generated GoMock package. +package store + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStore is a mock of Store interface. +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder + isgomock struct{} +} + +// MockStoreMockRecorder is the mock recorder for MockStore. +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance. +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockStore) Get(stack, component, key string) (any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", stack, component, key) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockStoreMockRecorder) Get(stack, component, key any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStore)(nil).Get), stack, component, key) +} + +// GetKey mocks base method. +func (m *MockStore) GetKey(key string) (any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKey", key) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKey indicates an expected call of GetKey. +func (mr *MockStoreMockRecorder) GetKey(key any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKey", reflect.TypeOf((*MockStore)(nil).GetKey), key) +} + +// Set mocks base method. 
+func (m *MockStore) Set(stack, component, key string, value any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", stack, component, key, value) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set. +func (mr *MockStoreMockRecorder) Set(stack, component, key, value any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockStore)(nil).Set), stack, component, key, value) +} diff --git a/pkg/store/store.go b/pkg/store/store.go index add2ca0b9a..8a5487cf20 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -3,9 +3,14 @@ package store import "strings" // Store defines the common interface for all store implementations. +// +//go:generate go run go.uber.org/mock/mockgen@v0.6.0 -source=$GOFILE -destination=mock_store.go -package=store type Store interface { + // Set stores a value for a specific stack, component, and key combination. Set(stack string, component string, key string, value any) error + // Get retrieves a value for a specific stack, component, and key combination. Get(stack string, component string, key string) (any, error) + // GetKey retrieves a value directly by key without stack or component context. GetKey(key string) (any, error) } diff --git a/pkg/yaml/doc.go b/pkg/yaml/doc.go new file mode 100644 index 0000000000..15c669998f --- /dev/null +++ b/pkg/yaml/doc.go @@ -0,0 +1,18 @@ +// Package yaml provides YAML parsing, caching, and utility functions for Atmos. +// +// This package contains YAML-specific functionality including: +// - Parsing and unmarshaling YAML with custom tag processing +// - Content-aware caching for parsed YAML documents +// - Position tracking for provenance +// - Output formatting and highlighting +// +// The custom tag processing uses the function registry from pkg/function to +// handle tags like !env, !exec, !terraform.output, etc. 
+// +// Example usage: +// +// data, err := yaml.UnmarshalYAML[map[string]any](content) +// if err != nil { +// return err +// } +package yaml diff --git a/pkg/yaml/errors.go b/pkg/yaml/errors.go new file mode 100644 index 0000000000..bf000ac5be --- /dev/null +++ b/pkg/yaml/errors.go @@ -0,0 +1,23 @@ +package yaml + +import "errors" + +var ( + // ErrNilAtmosConfig is returned when atmosConfig is nil. + ErrNilAtmosConfig = errors.New("atmosConfig cannot be nil") + + // ErrIncludeInvalidArguments is returned when !include has invalid arguments. + ErrIncludeInvalidArguments = errors.New("invalid number of arguments in the !include function") + + // ErrIncludeFileNotFound is returned when !include references a non-existent file. + ErrIncludeFileNotFound = errors.New("the !include function references a file that does not exist") + + // ErrIncludeAbsPath is returned when converting to absolute path fails. + ErrIncludeAbsPath = errors.New("failed to convert the file path to an absolute path in the !include function") + + // ErrIncludeProcessFailed is returned when processing stack manifest fails. + ErrIncludeProcessFailed = errors.New("failed to process the stack manifest with the !include function") + + // ErrInvalidYAMLFunction is returned when a YAML function has invalid syntax. + ErrInvalidYAMLFunction = errors.New("invalid Atmos YAML function") +) diff --git a/pkg/yaml/position.go b/pkg/yaml/position.go new file mode 100644 index 0000000000..ad37ec8798 --- /dev/null +++ b/pkg/yaml/position.go @@ -0,0 +1,134 @@ +package yaml + +import ( + "github.com/cloudposse/atmos/pkg/perf" + "github.com/cloudposse/atmos/pkg/utils" + goyaml "gopkg.in/yaml.v3" +) + +// Position represents a line and column position in a YAML file. +type Position struct { + Line int // 1-indexed line number. + Column int // 1-indexed column number. +} + +// PositionMap maps JSONPath-style paths to their positions in a YAML file. 
+type PositionMap map[string]Position + +// ExtractPositions extracts line/column positions from a YAML node tree. +// Returns a map of JSONPath -> Position for all values in the YAML. +// If enabled is false, returns an empty map immediately (zero overhead). +func ExtractPositions(node *goyaml.Node, enabled bool) PositionMap { + defer perf.Track(nil, "yaml.ExtractPositions")() + + if !enabled || node == nil { + return make(PositionMap) + } + + positions := make(PositionMap) + extractPositionsRecursive(node, "", positions) + return positions +} + +// extractPositionsRecursive recursively walks the YAML node tree and records positions. +// +//nolint:gocognit,revive // YAML node traversal requires multiple cases for different node types. +func extractPositionsRecursive(node *goyaml.Node, currentPath string, positions PositionMap) { + if node == nil { + return + } + + switch node.Kind { + case goyaml.DocumentNode: + // Document node wraps the actual content. + if len(node.Content) > 0 { + extractPositionsRecursive(node.Content[0], currentPath, positions) + } + + case goyaml.MappingNode: + // Map: pairs of key-value nodes. + for i := 0; i < len(node.Content); i += 2 { + if i+1 >= len(node.Content) { + break + } + + keyNode := node.Content[i] + valueNode := node.Content[i+1] + + // Get the key as a string. + key := keyNode.Value + + // Build the path for this key. + var path string + if currentPath == "" { + path = key + } else { + path = utils.AppendJSONPathKey(currentPath, key) + } + + // Record position for this value. + positions[path] = Position{ + Line: valueNode.Line, + Column: valueNode.Column, + } + + // Recurse into the value. + extractPositionsRecursive(valueNode, path, positions) + } + + case goyaml.SequenceNode: + // Array: list of nodes. + for i, itemNode := range node.Content { + // Build the path with array index. + path := utils.AppendJSONPathIndex(currentPath, i) + + // Record position for this item. 
+ positions[path] = Position{ + Line: itemNode.Line, + Column: itemNode.Column, + } + + // Recurse into the item. + extractPositionsRecursive(itemNode, path, positions) + } + + case goyaml.ScalarNode: + // Leaf value - position already recorded by parent. + // Nothing to do here. + + case goyaml.AliasNode: + // YAML alias (*anchor) - recurse into the aliased node. + if node.Alias != nil { + extractPositionsRecursive(node.Alias, currentPath, positions) + } + } +} + +// GetPosition gets the position for a specific JSONPath from the position map. +// Returns Position{0, 0} if not found. +func GetPosition(positions PositionMap, path string) Position { + defer perf.Track(nil, "yaml.GetPosition")() + + if positions == nil { + return Position{} + } + + pos, exists := positions[path] + if !exists { + return Position{} + } + + return pos +} + +// HasPosition checks if a position exists for a specific JSONPath. +func HasPosition(positions PositionMap, path string) bool { + defer perf.Track(nil, "yaml.HasPosition")() + + if positions == nil { + return false + } + + _, exists := positions[path] + return exists +} diff --git a/pkg/yaml/position_test.go b/pkg/yaml/position_test.go new file mode 100644 index 0000000000..282ed69743 --- /dev/null +++ b/pkg/yaml/position_test.go @@ -0,0 +1,253 @@ +package yaml + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + goyaml "gopkg.in/yaml.v3" +) + +func TestExtractPositions_Disabled(t *testing.T) { + // When disabled, should return empty map. + node := &goyaml.Node{} + positions := ExtractPositions(node, false) + assert.Empty(t, positions) +} + +func TestExtractPositions_NilNode(t *testing.T) { + // Should handle nil node gracefully. 
+ positions := ExtractPositions(nil, true) + assert.Empty(t, positions) +} + +func TestExtractPositions_ScalarValue(t *testing.T) { + yamlContent := `key: value` + var node goyaml.Node + err := goyaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err) + + positions := ExtractPositions(&node, true) + + // Should have position for 'key'. + assert.True(t, HasPosition(positions, "key")) + pos := GetPosition(positions, "key") + assert.Equal(t, 1, pos.Line) +} + +func TestExtractPositions_NestedMapping(t *testing.T) { + yamlContent := ` +parent: + child1: value1 + child2: value2 +` + var node goyaml.Node + err := goyaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err) + + positions := ExtractPositions(&node, true) + + // Should have positions for all keys. + assert.True(t, HasPosition(positions, "parent")) + assert.True(t, HasPosition(positions, "parent.child1")) + assert.True(t, HasPosition(positions, "parent.child2")) +} + +func TestExtractPositions_Sequence(t *testing.T) { + yamlContent := ` +items: + - first + - second + - third +` + var node goyaml.Node + err := goyaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err) + + positions := ExtractPositions(&node, true) + + // Should have positions for array items. + assert.True(t, HasPosition(positions, "items")) + assert.True(t, HasPosition(positions, "items[0]")) + assert.True(t, HasPosition(positions, "items[1]")) + assert.True(t, HasPosition(positions, "items[2]")) +} + +func TestExtractPositions_MixedContent(t *testing.T) { + yamlContent := ` +metadata: + name: test + labels: + app: myapp + version: v1 +servers: + - host: server1 + port: 8080 + - host: server2 + port: 9090 +` + var node goyaml.Node + err := goyaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err) + + positions := ExtractPositions(&node, true) + + // Test various paths. 
+ assert.True(t, HasPosition(positions, "metadata")) + assert.True(t, HasPosition(positions, "metadata.name")) + assert.True(t, HasPosition(positions, "metadata.labels")) + assert.True(t, HasPosition(positions, "metadata.labels.app")) + assert.True(t, HasPosition(positions, "metadata.labels.version")) + assert.True(t, HasPosition(positions, "servers")) + assert.True(t, HasPosition(positions, "servers[0]")) + assert.True(t, HasPosition(positions, "servers[0].host")) + assert.True(t, HasPosition(positions, "servers[0].port")) + assert.True(t, HasPosition(positions, "servers[1]")) + assert.True(t, HasPosition(positions, "servers[1].host")) + assert.True(t, HasPosition(positions, "servers[1].port")) +} + +func TestExtractPositions_AliasNode(t *testing.T) { + yamlContent := ` +defaults: &defaults + adapter: postgres + host: localhost + +development: + database: dev_db + <<: *defaults +` + var node goyaml.Node + err := goyaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err) + + positions := ExtractPositions(&node, true) + + // Should have positions for anchor and alias content. + assert.True(t, HasPosition(positions, "defaults")) + assert.True(t, HasPosition(positions, "defaults.adapter")) + assert.True(t, HasPosition(positions, "development")) +} + +func TestGetPosition_NotFound(t *testing.T) { + positions := PositionMap{ + "exists": Position{Line: 1, Column: 1}, + } + + // Existing path. + pos := GetPosition(positions, "exists") + assert.Equal(t, 1, pos.Line) + assert.Equal(t, 1, pos.Column) + + // Non-existing path returns zero value. 
+ pos = GetPosition(positions, "not-exists") + assert.Equal(t, 0, pos.Line) + assert.Equal(t, 0, pos.Column) +} + +func TestGetPosition_NilMap(t *testing.T) { + pos := GetPosition(nil, "any") + assert.Equal(t, 0, pos.Line) + assert.Equal(t, 0, pos.Column) +} + +func TestHasPosition_NilMap(t *testing.T) { + assert.False(t, HasPosition(nil, "any")) +} + +func TestHasPosition_ExistsAndNotExists(t *testing.T) { + positions := PositionMap{ + "exists": Position{Line: 5, Column: 10}, + } + + assert.True(t, HasPosition(positions, "exists")) + assert.False(t, HasPosition(positions, "not-exists")) +} + +func TestExtractPositions_EmptyDocument(t *testing.T) { + yamlContent := `` + var node goyaml.Node + err := goyaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err) + + positions := ExtractPositions(&node, true) + assert.Empty(t, positions) +} + +func TestExtractPositions_DeeplyNested(t *testing.T) { + yamlContent := ` +a: + b: + c: + d: + e: deep +` + var node goyaml.Node + err := goyaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err) + + positions := ExtractPositions(&node, true) + + assert.True(t, HasPosition(positions, "a")) + assert.True(t, HasPosition(positions, "a.b")) + assert.True(t, HasPosition(positions, "a.b.c")) + assert.True(t, HasPosition(positions, "a.b.c.d")) + assert.True(t, HasPosition(positions, "a.b.c.d.e")) +} + +func TestExtractPositions_SequenceOfMappings(t *testing.T) { + yamlContent := ` +list: + - name: item1 + value: 100 + - name: item2 + value: 200 +` + var node goyaml.Node + err := goyaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err) + + positions := ExtractPositions(&node, true) + + assert.True(t, HasPosition(positions, "list")) + assert.True(t, HasPosition(positions, "list[0]")) + assert.True(t, HasPosition(positions, "list[0].name")) + assert.True(t, HasPosition(positions, "list[0].value")) + assert.True(t, HasPosition(positions, "list[1]")) + assert.True(t, HasPosition(positions, 
"list[1].name")) + assert.True(t, HasPosition(positions, "list[1].value")) +} + +func TestExtractPositions_LineNumbers(t *testing.T) { + yamlContent := `first: value1 +second: value2 +third: value3` + var node goyaml.Node + err := goyaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err) + + positions := ExtractPositions(&node, true) + + // Verify line numbers are correct. + assert.Equal(t, 1, GetPosition(positions, "first").Line) + assert.Equal(t, 2, GetPosition(positions, "second").Line) + assert.Equal(t, 3, GetPosition(positions, "third").Line) +} + +func TestPosition_Struct(t *testing.T) { + pos := Position{ + Line: 10, + Column: 5, + } + + assert.Equal(t, 10, pos.Line) + assert.Equal(t, 5, pos.Column) +} + +func TestPositionMap_Type(t *testing.T) { + // Test that PositionMap works as a map. + pm := make(PositionMap) + pm["test"] = Position{Line: 1, Column: 2} + + assert.Equal(t, Position{Line: 1, Column: 2}, pm["test"]) +} diff --git a/pkg/yaml/types.go b/pkg/yaml/types.go new file mode 100644 index 0000000000..fbdf2b5a57 --- /dev/null +++ b/pkg/yaml/types.go @@ -0,0 +1,30 @@ +package yaml + +import ( + "github.com/cloudposse/atmos/pkg/perf" + goyaml "gopkg.in/yaml.v3" +) + +// DefaultIndent is the default indentation for YAML output. +const DefaultIndent = 2 + +// Options configures YAML encoding behavior. +type Options struct { + Indent int +} + +// LongString is a string type that encodes as a YAML folded scalar (>). +// This is used to wrap long strings across multiple lines for better readability. +type LongString string + +// MarshalYAML implements yaml.Marshaler to encode as a folded scalar. +func (s LongString) MarshalYAML() (interface{}, error) { + defer perf.Track(nil, "yaml.LongString.MarshalYAML")() + + node := &goyaml.Node{ + Kind: goyaml.ScalarNode, + Style: goyaml.FoldedStyle, // Use > style for folded scalar. 
+ Value: string(s), + } + return node, nil +} diff --git a/pkg/yaml/types_test.go b/pkg/yaml/types_test.go new file mode 100644 index 0000000000..1717275afb --- /dev/null +++ b/pkg/yaml/types_test.go @@ -0,0 +1,68 @@ +package yaml + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + goyaml "gopkg.in/yaml.v3" +) + +func TestLongString_MarshalYAML(t *testing.T) { + tests := []struct { + name string + input LongString + wantKind goyaml.Kind + }{ + { + name: "simple string", + input: LongString("hello world"), + wantKind: goyaml.ScalarNode, + }, + { + name: "empty string", + input: LongString(""), + wantKind: goyaml.ScalarNode, + }, + { + name: "multiline string", + input: LongString("line1\nline2\nline3"), + wantKind: goyaml.ScalarNode, + }, + { + name: "long string", + input: LongString("This is a very long string that should be wrapped using the folded scalar style in YAML output"), + wantKind: goyaml.ScalarNode, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.input.MarshalYAML() + require.NoError(t, err) + + node, ok := result.(*goyaml.Node) + require.True(t, ok, "result should be a *goyaml.Node") + assert.Equal(t, tt.wantKind, node.Kind) + assert.Equal(t, goyaml.FoldedStyle, node.Style) + assert.Equal(t, string(tt.input), node.Value) + }) + } +} + +func TestLongString_MarshalYAML_Integration(t *testing.T) { + // Test that LongString marshals correctly when embedded in a struct. + type testStruct struct { + Description LongString `yaml:"description"` + } + + ts := testStruct{ + Description: LongString("This is a long description that should be output as a folded scalar"), + } + + out, err := goyaml.Marshal(&ts) + require.NoError(t, err) + + // The output should use folded style (>). 
+ assert.Contains(t, string(out), "description: >") +} diff --git a/tools/lintroller/rule_perf_track.go b/tools/lintroller/rule_perf_track.go index c1b843d1ed..4f10b0b205 100644 --- a/tools/lintroller/rule_perf_track.go +++ b/tools/lintroller/rule_perf_track.go @@ -73,6 +73,8 @@ var excludedReceivers = []string{ "DescribeConfigFormatError", // Error types. "DefaultStacksProcessor", // Processor implementations. "AtmosFuncs", // Template function wrappers (high-frequency). + "ExecutionContext", // Trivial With* mutators in pkg/function. + "Phase", // Trivial String() method in pkg/function. } // Functions to exclude from perf.Track() checks (by name). @@ -113,6 +115,9 @@ func (r *PerfTrackRule) Check(pass *analysis.Pass, file *ast.File) error { } // Check if package is in exclusion list. + if pass.Pkg == nil { + return nil + } pkgPath := pass.Pkg.Path() for _, excluded := range excludedPackages { // Match only complete path segments to avoid false positives. diff --git a/website/blog/2025-12-18-function-registry-package.mdx b/website/blog/2025-12-18-function-registry-package.mdx new file mode 100644 index 0000000000..62ffea57dd --- /dev/null +++ b/website/blog/2025-12-18-function-registry-package.mdx @@ -0,0 +1,39 @@ +--- +slug: function-registry-package +title: "New pkg/function Package for Format-Agnostic Function Registry" +authors: [osterman] +tags: [core] +--- + +Introduces `pkg/function/`, a new format-agnostic function registry that consolidates YAML function handlers into a reusable package. + + + +## What Changed + +This release adds foundational packages for modular function handling: + +- **`pkg/function/`**: Format-agnostic function registry with handlers for all YAML functions (`!env`, `!exec`, `!terraform.output`, `!store.get`, `!literal`, etc.) 
+- **`pkg/yaml/`**: YAML-specific utilities for position tracking and error handling +- **`pkg/aws/identity/`**: Consolidated AWS identity caching (moved from `internal/aws_utils`) + +## Why This Matters + +The function registry separates concerns between format-specific parsing (YAML, HCL, JSON) and format-agnostic function execution. This enables: + +- **Code Reuse**: Single registry used across all configuration formats +- **Extensibility**: New functions can be added without modifying core parsing logic +- **Testing**: Interface-driven design with dependency injection for better testability +- **Plugin Architecture**: Foundation for future plugin support + +## Technical Details + +Functions are organized by execution phase: +- **PreMerge**: `!env`, `!exec`, `!random`, `!template`, `!include`, `!literal` +- **PostMerge**: `!terraform.output`, `!terraform.state`, `!store.get`, `!aws.*` + +The registry provides thread-safe registration, lookup by name or alias, and phase-based filtering. + +## Get Involved + +This is preparatory work for broader YAML processing refactoring. Contributions and feedback are welcome at [github.com/cloudposse/atmos](https://github.com/cloudposse/atmos).