Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 34 additions & 44 deletions cmd/lakectl/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"reflect"
"slices"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -184,17 +183,12 @@ const (
)

const (
CacheFileName = "lakectl_token_cache.json"
LakectlDirName = ".lakectl"
CacheDirName = "cache"
cacheFileName = "lakectl_token_cache.json"
)

var (
cachedToken *apigen.AuthenticationToken
tokenLoadOnce sync.Once
tokenCache *awsiam.JWTCache
tokenCacheOnce sync.Once
ErrTokenUnavailable = fmt.Errorf("token is not available")
ErrCacheUnavailable = fmt.Errorf("cache is not available")
)

func withRecursiveFlag(cmd *cobra.Command, usage string) {
Expand Down Expand Up @@ -654,15 +648,17 @@ func getClient() *apigen.ClientWithResponses {

func CreateTokenCacheCallback() awsiam.TokenCacheCallback {
return func(newToken *apigen.AuthenticationToken) {
cachedToken = newToken
if err := SaveTokenToCache(); err != nil {
if err := SaveTokenToCache(newToken); err != nil {
logging.ContextUnavailable().Debugf("error saving token to cache: %w", err)
}
}
}

func getClientOptions(awsIAMparams *awsiam.IAMAuthParams, serverEndpoint string) []apigen.ClientOption {
token := getTokenOnce()
token, err := getToken()
if err != nil {
logging.ContextUnavailable().Debugf("no token available in cache: %w", err)
}

tokenCacheCallback := CreateTokenCacheCallback()

Expand All @@ -680,7 +676,6 @@ func getClientOptions(awsIAMparams *awsiam.IAMAuthParams, serverEndpoint string)
DieErr(err)
}
loginClient := &awsiam.ExternalPrincipalLoginClient{Client: noAuthClient}

awsAuthProvider := awsiam.WithAWSIAMRoleAuthProviderOption(
awsIAMparams,
logging.ContextUnavailable(),
Expand All @@ -692,47 +687,42 @@ func getClientOptions(awsIAMparams *awsiam.IAMAuthParams, serverEndpoint string)
return []apigen.ClientOption{awsAuthProvider}
}

func getTokenOnce() *apigen.AuthenticationToken {
tokenLoadOnce.Do(func() {
cache := getTokenCacheOnce()
var err error
if cache != nil {
if token, err := cache.GetToken(); err == nil {
cachedToken = token
return
}
logging.ContextUnavailable().Debugf("Error loading token from cache: %w", err)
func getToken() (*apigen.AuthenticationToken, error) {
cache := getTokenCache()
if cache != nil {
token, err := cache.GetToken()
if err != nil {
return nil, err
}
})
return cachedToken
return token, nil
}
return nil, ErrCacheUnavailable
}

func getTokenCacheOnce() *awsiam.JWTCache {
tokenCacheOnce.Do(func() {
homeDir, err := os.UserHomeDir()
if err != nil {
logging.ContextUnavailable().Debugf("Error getting user homedir: %w", err)
}
cache, err := awsiam.NewJWTCache(homeDir, LakectlDirName, CacheDirName, CacheFileName)
if err != nil {
logging.ContextUnavailable().Debugf("Error creating token cache: %w", err)
tokenCache = nil
} else {
tokenCache = cache
}
})
return tokenCache
func getTokenCache() *awsiam.JWTCache {
homeDir, err := os.UserHomeDir()
if err != nil {
logging.ContextUnavailable().Debugf("Error getting user homedir: %w", err)
}
cache, err := awsiam.NewJWTCache(homeDir, cacheFileName)
if err != nil {
logging.ContextUnavailable().Debugf("Error creating token cache: %w", err)
return nil
}
return cache
}

func SaveTokenToCache() error {
cache := getTokenCacheOnce()
if cache == nil || cachedToken == nil {
func SaveTokenToCache(newToken *apigen.AuthenticationToken) error {
cache := getTokenCache()
if cache == nil {
return ErrCacheUnavailable
}
if newToken == nil {
return ErrTokenUnavailable
}
if err := cache.SaveToken(cachedToken); err != nil {
if err := cache.SaveToken(newToken); err != nil {
return err
}
tokenLoadOnce = sync.Once{}
return nil
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/authentication/externalidp/awsiam/token_caching.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ var ErrInvalidTokenFormat = fmt.Errorf("token format is invalid")

const (
ReadWriteExecuteOwnerOnly = 0700
LakectlDirName = ".lakectl"
CacheDirName = "cache"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. What's the purpose of those directories? Why not just ~/.lakectl_token_cache.json? Or just ~/.lakectl/token_cache.json (If you really need a dir)
  2. This is not the right location for such things. It belongs to lakectl code parts, not here.

)

type TokenCache struct {
Expand All @@ -25,15 +27,15 @@ type JWTCache struct {
FilePath string
}

func NewJWTCache(baseDir, lakectlDir, cacheDir, fileName string) (*JWTCache, error) {
func NewJWTCache(baseDir, fileName string) (*JWTCache, error) {
if baseDir == "" {
var err error
baseDir, err = os.UserHomeDir()
if err != nil {
return nil, err
}
}
cachePath := filepath.Join(baseDir, lakectlDir, cacheDir)
cachePath := filepath.Join(baseDir, LakectlDirName, CacheDirName)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those sub diretories are irrelevant for this component.
It doesn't need to force or even know about LakectlDirName, CacheDirName

if err := os.MkdirAll(cachePath, ReadWriteExecuteOwnerOnly); err != nil {
return nil, ErrFailedToCreateCacheDir
}
Expand Down
20 changes: 10 additions & 10 deletions pkg/authentication/externalidp/awsiam/token_caching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ import (
func TestNewJWTCache(t *testing.T) {
t.Run("with custom cache dir", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, "lakectl_token_cache.json")
require.NoError(t, err)
require.NotEmpty(t, cache)
require.Equal(t, filepath.Join(tempDir, ".lakectl", "cache", "lakectl_token_cache.json"), cache.FilePath)
})

t.Run("with empty cache dir uses home dir", func(t *testing.T) {
cache, err := awsiam.NewJWTCache("", ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache("", "lakectl_token_cache.json")
require.NoError(t, err)
require.NotEmpty(t, cache)
homeDir, _ := os.UserHomeDir()
Expand All @@ -34,7 +34,7 @@ func TestNewJWTCache(t *testing.T) {
func TestJWTCacheSaveToken(t *testing.T) {
t.Run("saves valid token successfully", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, "lakectl_token_cache.json")
require.NoError(t, err)

expirationTime := time.Now().Add(1 * time.Hour).Unix()
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestJWTCacheSaveToken(t *testing.T) {

t.Run("handles nil token", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, "lakectl_token_cache.json")
require.NoError(t, err)

err = cache.SaveToken(nil)
Expand All @@ -76,7 +76,7 @@ func TestJWTCacheSaveToken(t *testing.T) {

t.Run("handles token with empty string", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, "lakectl_token_cache.json")
require.NoError(t, err)

token := &apigen.AuthenticationToken{
Expand All @@ -93,7 +93,7 @@ func TestJWTCacheSaveToken(t *testing.T) {

t.Run("handles token without expiration", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, "lakectl_token_cache.json")
require.NoError(t, err)

token := &apigen.AuthenticationToken{
Expand All @@ -113,7 +113,7 @@ func TestJWTCacheSaveToken(t *testing.T) {
func TestJWTCacheGetToken(t *testing.T) {
t.Run("loads valid non-expired token", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, "lakectl_token_cache.json")
require.NoError(t, err)

expirationTime := time.Now().Add(1 * time.Hour).Unix()
Expand All @@ -135,7 +135,7 @@ func TestJWTCacheGetToken(t *testing.T) {

t.Run("returns nil when cache file doesn't exist", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, "lakectl_token_cache.json")
require.NoError(t, err)

loadedToken, err := cache.GetToken()
Expand All @@ -145,7 +145,7 @@ func TestJWTCacheGetToken(t *testing.T) {

t.Run("returns error for corrupted cache file", func(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, "lakectl_token_cache.json")
require.NoError(t, err)

// Write invalid JSON
Expand All @@ -160,7 +160,7 @@ func TestJWTCacheGetToken(t *testing.T) {

func TestJWTCacheSaveAndLoad(t *testing.T) {
tempDir := t.TempDir()
cache, err := awsiam.NewJWTCache(tempDir, ".lakectl", "cache", "lakectl_token_cache.json")
cache, err := awsiam.NewJWTCache(tempDir, "lakectl_token_cache.json")
require.NoError(t, err)

expirationTime := time.Now().Add(30 * time.Minute).Unix()
Expand Down
Loading