Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion pkg/auth/manager_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"

errUtils "github.com/cloudposse/atmos/errors"
Expand All @@ -18,8 +20,50 @@ import (
// logKeyExpirationChain is the log key for expiration values in chain operations.
const logKeyExpirationChain = "expiration"

// processCredentialCache is a process-level in-memory cache for authenticated credentials.
// Unlike keyring/file caches which persist across processes and may contain stale data,
// this cache only holds credentials authenticated during the current process, so they are
// guaranteed to be correct. This avoids redundant AssumeRole API calls when multiple
// components in the same command share the same authentication chain (e.g., during
// `atmos describe affected` which resolves many `!terraform.state` YAML functions).
var processCredentialCache sync.Map // key: "realm:chain" string, value: *processCachedCreds

// processCachedCreds holds credentials cached in-memory for the current process.
type processCachedCreds struct {
credentials types.ICredentials
}

// resetProcessCredentialCache clears the process-level credential cache.
// This is intended for use in tests to ensure isolation between test cases.
func resetProcessCredentialCache() {
processCredentialCache.Range(func(key, _ any) bool {
processCredentialCache.Delete(key)
return true
})
}

// chainCacheKey returns a unique cache key for the current chain and realm.
func (m *manager) chainCacheKey() string {
return m.realm.Value + ":" + strings.Join(m.chain, "->")
}

// authenticateChain performs credential chain authentication with bottom-up validation.
func (m *manager) authenticateChain(ctx context.Context, _ string) (types.ICredentials, error) {
// Fast path: check process-level in-memory cache.
// Credentials authenticated during this process are guaranteed correct, unlike
// keyring/file caches which may hold stale data from previous runs.
cacheKey := m.chainCacheKey()
if entry, ok := processCredentialCache.Load(cacheKey); ok {
cached := entry.(*processCachedCreds)
if valid, _ := m.isCredentialValid("process-cache", cached.credentials); valid {
log.Debug("Using process-cached credentials for chain", "chain", m.chain)
return cached.credentials, nil
}
// Expired — remove stale entry.
processCredentialCache.Delete(cacheKey)
log.Debug("Process-cached credentials expired, re-authenticating", "chain", m.chain)
}

// Step 1: Bottom-up validation - check cached credentials from target to root.
validFromIndex := m.findFirstValidCachedCredentials()

Expand All @@ -32,7 +76,17 @@ func (m *manager) authenticateChain(ctx context.Context, _ string) (types.ICrede
// has cached credentials. This ensures assume-role identities perform the actual
// AssumeRole API call rather than using potentially incorrect cached credentials
// (e.g., permission set creds incorrectly cached as assume-role creds).
return m.authenticateFromIndex(ctx, validFromIndex)
creds, err := m.authenticateFromIndex(ctx, validFromIndex)
if err != nil {
return nil, err
}

// Cache the successfully authenticated credentials for this process.
processCredentialCache.Store(cacheKey, &processCachedCreds{
credentials: creds,
})

return creds, nil
}

// findFirstValidCachedCredentials checks cached credentials from bottom to top of chain.
Expand Down
258 changes: 258 additions & 0 deletions pkg/auth/manager_chain_process_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
package auth

import (
"context"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/cloudposse/atmos/pkg/auth/realm"
"github.com/cloudposse/atmos/pkg/auth/types"
"github.com/cloudposse/atmos/pkg/schema"
)

// countingIdentity tracks how many times Authenticate is called.
type countingIdentity struct {
provider string
callCount atomic.Int32
creds types.ICredentials
}

func (c *countingIdentity) Kind() string { return "aws/assume-role" }
func (c *countingIdentity) GetProviderName() (string, error) { return c.provider, nil }
func (c *countingIdentity) Authenticate(_ context.Context, _ types.ICredentials) (types.ICredentials, error) {
c.callCount.Add(1)
return c.creds, nil
}
func (c *countingIdentity) Validate() error { return nil }
func (c *countingIdentity) Environment() (map[string]string, error) { return nil, nil }
func (c *countingIdentity) Paths() ([]types.Path, error) { return []types.Path{}, nil }
func (c *countingIdentity) PostAuthenticate(_ context.Context, _ *types.PostAuthenticateParams) error {
return nil
}
func (c *countingIdentity) Logout(_ context.Context) error { return nil }
func (c *countingIdentity) CredentialsExist() (bool, error) { return false, nil }
func (c *countingIdentity) LoadCredentials(_ context.Context) (types.ICredentials, error) {
return nil, nil
}

func (c *countingIdentity) PrepareEnvironment(_ context.Context, environ map[string]string) (map[string]string, error) {
return environ, nil
}
func (c *countingIdentity) SetRealm(_ string) {}

func TestProcessCredentialCache_AvoidsDuplicateAuth(t *testing.T) {
resetProcessCredentialCache()
t.Cleanup(resetProcessCredentialCache)

exp := time.Now().UTC().Add(1 * time.Hour)
identityCreds := &testCreds{exp: &exp}
identity := &countingIdentity{provider: "prov", creds: identityCreds}

providerCreds := &testCreds{}
provider := &testProvider{name: "prov", creds: providerCreds}

store := &testStore{data: map[string]any{}}

// Create first manager and authenticate.
m1 := &manager{
config: &schema.AuthConfig{
Identities: map[string]schema.Identity{
"role": {Kind: "aws/assume-role", Via: &schema.IdentityVia{Provider: "prov"}},
},
},
providers: map[string]types.Provider{"prov": provider},
identities: map[string]types.Identity{"role": identity},
credentialStore: store,
chain: []string{"prov", "role"},
realm: realm.RealmInfo{Value: "test-realm"},
}

creds1, err := m1.authenticateChain(context.Background(), "role")
require.NoError(t, err)
assert.Equal(t, identityCreds, creds1)
assert.Equal(t, int32(1), identity.callCount.Load(), "identity.Authenticate should be called once")

// Create second manager with same chain and realm (simulates new manager for nested component).
m2 := &manager{
config: &schema.AuthConfig{
Identities: map[string]schema.Identity{
"role": {Kind: "aws/assume-role", Via: &schema.IdentityVia{Provider: "prov"}},
},
},
providers: map[string]types.Provider{"prov": provider},
identities: map[string]types.Identity{"role": identity},
credentialStore: store,
chain: []string{"prov", "role"},
realm: realm.RealmInfo{Value: "test-realm"},
}

creds2, err := m2.authenticateChain(context.Background(), "role")
require.NoError(t, err)
assert.Equal(t, identityCreds, creds2)
assert.Equal(t, int32(1), identity.callCount.Load(), "identity.Authenticate should NOT be called again (cache hit)")
}

func TestProcessCredentialCache_DifferentChainMisses(t *testing.T) {
resetProcessCredentialCache()
t.Cleanup(resetProcessCredentialCache)

exp := time.Now().UTC().Add(1 * time.Hour)
identity1Creds := &testCreds{exp: &exp}
identity1 := &countingIdentity{provider: "prov", creds: identity1Creds}

identity2Creds := &testCreds{exp: &exp}
identity2 := &countingIdentity{provider: "prov", creds: identity2Creds}

providerCreds := &testCreds{}
provider := &testProvider{name: "prov", creds: providerCreds}

store := &testStore{data: map[string]any{}}

// Authenticate chain ["prov", "role1"].
m1 := &manager{
config: &schema.AuthConfig{
Identities: map[string]schema.Identity{
"role1": {Kind: "aws/assume-role", Via: &schema.IdentityVia{Provider: "prov"}},
},
},
providers: map[string]types.Provider{"prov": provider},
identities: map[string]types.Identity{"role1": identity1},
credentialStore: store,
chain: []string{"prov", "role1"},
realm: realm.RealmInfo{Value: "test-realm"},
}

_, err := m1.authenticateChain(context.Background(), "role1")
require.NoError(t, err)
assert.Equal(t, int32(1), identity1.callCount.Load())

// Authenticate different chain ["prov", "role2"] - should NOT use cache.
m2 := &manager{
config: &schema.AuthConfig{
Identities: map[string]schema.Identity{
"role2": {Kind: "aws/assume-role", Via: &schema.IdentityVia{Provider: "prov"}},
},
},
providers: map[string]types.Provider{"prov": provider},
identities: map[string]types.Identity{"role2": identity2},
credentialStore: store,
chain: []string{"prov", "role2"},
realm: realm.RealmInfo{Value: "test-realm"},
}

_, err = m2.authenticateChain(context.Background(), "role2")
require.NoError(t, err)
assert.Equal(t, int32(1), identity2.callCount.Load(), "different chain should authenticate independently")
}

func TestProcessCredentialCache_ExpiredCredsReauthenticate(t *testing.T) {
resetProcessCredentialCache()
t.Cleanup(resetProcessCredentialCache)

// Seed the cache with expired credentials.
expiredTime := time.Now().UTC().Add(-1 * time.Hour)
expiredCreds := &testCreds{exp: &expiredTime}
processCredentialCache.Store("test-realm:prov->role", &processCachedCreds{
credentials: expiredCreds,
})

freshExp := time.Now().UTC().Add(1 * time.Hour)
freshCreds := &testCreds{exp: &freshExp}
identity := &countingIdentity{provider: "prov", creds: freshCreds}

providerCreds := &testCreds{}
provider := &testProvider{name: "prov", creds: providerCreds}

store := &testStore{data: map[string]any{}}

m := &manager{
config: &schema.AuthConfig{
Identities: map[string]schema.Identity{
"role": {Kind: "aws/assume-role", Via: &schema.IdentityVia{Provider: "prov"}},
},
},
providers: map[string]types.Provider{"prov": provider},
identities: map[string]types.Identity{"role": identity},
credentialStore: store,
chain: []string{"prov", "role"},
realm: realm.RealmInfo{Value: "test-realm"},
}

creds, err := m.authenticateChain(context.Background(), "role")
require.NoError(t, err)
assert.Equal(t, freshCreds, creds)
assert.Equal(t, int32(1), identity.callCount.Load(), "should re-authenticate when cache is expired")
}

func TestProcessCredentialCache_DifferentRealmMisses(t *testing.T) {
resetProcessCredentialCache()
t.Cleanup(resetProcessCredentialCache)

exp := time.Now().UTC().Add(1 * time.Hour)
identity1Creds := &testCreds{exp: &exp}
identity1 := &countingIdentity{provider: "prov", creds: identity1Creds}

identity2Creds := &testCreds{exp: &exp}
identity2 := &countingIdentity{provider: "prov", creds: identity2Creds}

providerCreds := &testCreds{}
provider := &testProvider{name: "prov", creds: providerCreds}

store := &testStore{data: map[string]any{}}

// Authenticate with realm "realm-a".
m1 := &manager{
config: &schema.AuthConfig{
Identities: map[string]schema.Identity{
"role": {Kind: "aws/assume-role", Via: &schema.IdentityVia{Provider: "prov"}},
},
},
providers: map[string]types.Provider{"prov": provider},
identities: map[string]types.Identity{"role": identity1},
credentialStore: store,
chain: []string{"prov", "role"},
realm: realm.RealmInfo{Value: "realm-a"},
}

_, err := m1.authenticateChain(context.Background(), "role")
require.NoError(t, err)
assert.Equal(t, int32(1), identity1.callCount.Load())

// Same chain but different realm - should NOT use cache.
m2 := &manager{
config: &schema.AuthConfig{
Identities: map[string]schema.Identity{
"role": {Kind: "aws/assume-role", Via: &schema.IdentityVia{Provider: "prov"}},
},
},
providers: map[string]types.Provider{"prov": provider},
identities: map[string]types.Identity{"role": identity2},
credentialStore: store,
chain: []string{"prov", "role"},
realm: realm.RealmInfo{Value: "realm-b"},
}

_, err = m2.authenticateChain(context.Background(), "role")
require.NoError(t, err)
assert.Equal(t, int32(1), identity2.callCount.Load(), "different realm should authenticate independently")
}

func Test_resetProcessCredentialCache(t *testing.T) {
// Store something in the cache.
processCredentialCache.Store("test-key", &processCachedCreds{})

// Verify it exists.
_, ok := processCredentialCache.Load("test-key")
require.True(t, ok)

// Reset.
resetProcessCredentialCache()

// Verify it's gone.
_, ok = processCredentialCache.Load("test-key")
assert.False(t, ok)
}
Loading