diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..eaff96a86a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,40 @@ +# Changelog + +All notable changes to this project will be documented in this file. +See [Conventional Commits](https://www.conventionalcommits.org/) for commit message format. + +--- + +## [Unreleased] + +### Changed + +- **`pkg/filesystem.GetGlobMatches`**: always returns a non-nil `[]string{}` (never `nil`). + Callers must use `len(result) == 0` to detect no matches instead of `result == nil`. + The cache is now bounded and configurable via three environment variables: + - `ATMOS_FS_GLOB_CACHE_MAX_ENTRIES` (default `1024`, minimum `16`) — maximum number of cached glob patterns. + - `ATMOS_FS_GLOB_CACHE_TTL` (default `5m`, minimum `1s`) — time-to-live for each cache entry. + See `applyGlobCacheConfig` in `pkg/filesystem/glob.go` for the exact logic: only positive values below the minimums are clamped up to the floor; zero or negative values cause the code to fall back to the default values. + - `ATMOS_FS_GLOB_CACHE_EMPTY` (default `1`) — set to `0` to skip caching patterns that match no files. +- **`pkg/http.normalizeHost`**: now strips default ports (`:443`, `:80`) in addition to + lower-casing and removing trailing dots, so that `api.github.com:443` is treated + identically to `api.github.com` for allowlist matching. + +### Added + +- **`pkg/filesystem`**: expvar observability counters (`atmos_glob_cache_hits`, + `atmos_glob_cache_misses`, `atmos_glob_cache_evictions`, `atmos_glob_cache_len`) published + via `RegisterGlobCacheExpvars()` and accessible at `/debug/vars` when the HTTP debug + server is enabled. +- **`pkg/http`**: host-matcher three-level precedence documented and tested: + 1. `WithGitHubHostMatcher` — custom predicate always wins. + 2. `GITHUB_API_URL` — GHES hostname added to allowlist when set. + 3. Built-in allowlist — `api.github.com`, `raw.githubusercontent.com`, `uploads.github.com`. 
+ Authorization is only injected over HTTPS and stripped on cross-host redirects + (301 / 302 / 303 / 307 / 308) to prevent token leakage. + +### Security + +- Cross-host HTTP redirects (all five status codes: 301, 302, 303, 307, 308) no longer + forward the `Authorization` header to the redirect target, preventing accidental token + leakage via open redirects. diff --git a/CLAUDE.md b/CLAUDE.md index fc4fae917e..9daa48a0b2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -444,6 +444,8 @@ Auto-enabled via `RootCmd.ExecuteC()`. Non-standard paths use `telemetry.Capture **Prerequisites**: Go 1.26+, golangci-lint, Make. See `.cursor/rules/atmos-rules.mdc`. +> **Minimum Go version**: `go.mod` requires Go 1.26. Test helpers use `sync.Map.Clear` (added in Go 1.23) and range-over-int (Go 1.22). CI pins the Go version via `go-version-file: go.mod`. Local development with an older toolchain will fail to compile test-only files. + **Build**: CGO disabled, cross-platform, version via ldflags, output to `./build/` ### Compilation (MANDATORY) diff --git a/Makefile b/Makefile index 92601d44a8..8670454926 100644 --- a/Makefile +++ b/Makefile @@ -124,4 +124,11 @@ link-check: @command -v lychee >/dev/null 2>&1 || { echo "Install lychee: brew install lychee"; exit 1; } lychee --config lychee.toml --root-dir "$(CURDIR)" '**/*.md' -.PHONY: readme lint lintroller gomodcheck build version build-linux build-windows build-macos deps version-linux version-windows version-macos testacc testacc-cover testacc-coverage test-short test-short-cover generate-mocks link-check +# Run quick tests with race detector and shuffled order. +# This target is recommended for CI to catch data races and order-dependent failures. 
+# Usage: make test-race +test-race: deps + @echo "Running tests with -race -shuffle=on" + go test -race -shuffle=on $(TEST) $(TESTARGS) -timeout 10m + +.PHONY: readme lint lintroller gomodcheck build version build-linux build-windows build-macos deps version-linux version-windows version-macos testacc testacc-cover testacc-coverage test-short test-short-cover test-race generate-mocks link-check diff --git a/errors/errors.go b/errors/errors.go index db918a12fc..a6baf7355d 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -245,12 +245,13 @@ var ( ErrInvalidFlagValue = errors.New("invalid value for flag") // File and URL handling errors. - ErrInvalidPagerCommand = errors.New("invalid pager command") - ErrEmptyURL = errors.New("empty URL provided") - ErrFailedToFindImport = errors.New("failed to find import") - ErrInvalidFilePath = errors.New("invalid file path") - ErrRelPath = errors.New("error determining relative path") - ErrHTTPRequestFailed = errors.New("HTTP request failed") + ErrInvalidPagerCommand = errors.New("invalid pager command") + ErrEmptyURL = errors.New("empty URL provided") + ErrFailedToFindImport = errors.New("failed to find import") + ErrInvalidFilePath = errors.New("invalid file path") + ErrRelPath = errors.New("error determining relative path") + ErrHTTPRequestFailed = errors.New("HTTP request failed") + ErrRedirectLimitExceeded = errors.New("stopped after 10 redirects") // Config loading errors. 
ErrAtmosDirConfigNotFound = errors.New("atmos config directory not found") diff --git a/go.mod b/go.mod index b470930be7..d8f2b3408c 100644 --- a/go.mod +++ b/go.mod @@ -71,6 +71,7 @@ require ( github.com/hairyhenderson/gomplate/v4 v4.3.3 github.com/hashicorp/go-getter v1.8.5 github.com/hashicorp/go-version v1.8.0 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/hcl v1.0.1-vault-7 github.com/hashicorp/hcl/v2 v2.24.0 github.com/hashicorp/terraform-config-inspect v0.0.0-20260224005459-813a97530220 @@ -288,7 +289,6 @@ require ( github.com/hashicorp/go-sockaddr v1.0.7 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect - github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hashicorp/serf v0.10.2 // indirect github.com/hashicorp/vault/api v1.22.0 // indirect github.com/hashicorp/vault/api/auth/approle v0.11.0 // indirect diff --git a/internal/exec/stack_processor_utils.go b/internal/exec/stack_processor_utils.go index ac1dbc2cdc..496541e0a1 100644 --- a/internal/exec/stack_processor_utils.go +++ b/internal/exec/stack_processor_utils.go @@ -1102,7 +1102,10 @@ func processYAMLConfigFileWithContextInternal( err, ) return nil, nil, nil, nil, nil, nil, nil, nil, errors.New(errorMessage) - } else if importMatches == nil { + } else if len(importMatches) == 0 { + // err == nil but importMatches is an empty slice (not nil): pkg/filesystem.GetGlobMatches + // guarantees a non-nil result, so the old "else if importMatches == nil" check was dead + // code. This branch fires only when no files matched and the call returned ([]string{}, nil).
errorMessage := fmt.Sprintf("no matches found for the import '%s' in the file '%s'", imp, relativeFilePath, diff --git a/lychee.toml b/lychee.toml index 92e1694bac..e9f8b2decd 100644 --- a/lychee.toml +++ b/lychee.toml @@ -70,6 +70,8 @@ exclude = [ "lfaidata\\.foundation", # Test data placeholder URLs (intentional) "github\\.com/test/", + # Taskfile.dev (intermittent connection resets from CI runner IP ranges) + "taskfile\\.dev", # Planned documentation paths (not yet published) "file://.*/ai/skills", "file://.*/ai/skill-marketplace", diff --git a/pkg/filesystem/doc.go b/pkg/filesystem/doc.go new file mode 100644 index 0000000000..6ded464245 --- /dev/null +++ b/pkg/filesystem/doc.go @@ -0,0 +1,32 @@ +// Package filesystem provides file-system utilities for the Atmos CLI, including +// atomic file writes (POSIX rename and a Windows-compatible remove-before-rename +// variant) and glob-pattern matching with a bounded, time-limited LRU cache. +// +// # GetGlobMatches contract +// +// [GetGlobMatches] always returns a non-nil []string. An empty result set is +// returned as []string{}, never nil. Callers must check len(result) == 0, not +// result == nil. +// +// # Cache configuration +// +// The glob LRU cache is configurable at startup via environment variables: +// +// - ATMOS_FS_GLOB_CACHE_MAX_ENTRIES – maximum number of cached patterns +// (default: 1024, minimum: 16; values below 16 are clamped up). +// - ATMOS_FS_GLOB_CACHE_TTL – TTL per entry as a Go duration string, e.g. +// "10m" (default: 5m, minimum: 1s; values below 1s are clamped up). +// - ATMOS_FS_GLOB_CACHE_EMPTY – set to "0" or "false" to disable caching +// of empty (no-match) results (default: "1" = enabled). +// +// # Observability +// +// Three atomic int64 counters track cache activity: +// - hits, misses, evictions (accessible via [GlobCacheHits], [GlobCacheMisses], +// [GlobCacheEvictions] in tests). 
+// +// Call [RegisterGlobCacheExpvars] once at startup to expose these counters via +// the expvar /debug/vars HTTP endpoint: +// +// filesystem.RegisterGlobCacheExpvars() +package filesystem diff --git a/pkg/filesystem/export_test.go b/pkg/filesystem/export_test.go new file mode 100644 index 0000000000..d39d196dbd --- /dev/null +++ b/pkg/filesystem/export_test.go @@ -0,0 +1,100 @@ +package filesystem + +import ( + "path/filepath" + "sync" + "sync/atomic" + "time" +) + +// ResetGlobMatchesCache clears the glob matches LRU cache and resets all counters. +// This is exported only for testing to avoid data races from direct struct assignment. +func ResetGlobMatchesCache() { + globMatchesLRUMu.Lock() + globMatchesLRU.Purge() + globMatchesLRUMu.Unlock() + atomic.StoreInt64(&globMatchesEvictions, 0) + atomic.StoreInt64(&globMatchesHits, 0) + atomic.StoreInt64(&globMatchesMisses, 0) +} + +// ResetPathMatchCache clears the path match cache. +// This is exported only for testing to ensure consistent state between tests. +func ResetPathMatchCache() { + pathMatchCacheMu.Lock() + pathMatchCache = make(map[pathMatchKey]bool) + pathMatchCacheMu.Unlock() +} + +// SetGlobCacheEntryExpired forcibly marks a cache entry as expired for testing TTL eviction. +// It re-adds the entry with an expiry in the past, simulating TTL expiry. +func SetGlobCacheEntryExpired(pattern string) { + normalizedPattern := filepath.ToSlash(pattern) + globMatchesLRUMu.Lock() + if entry, ok := globMatchesLRU.Get(normalizedPattern); ok { + entry.expiry = time.Time{} // zero time is in the past + globMatchesLRU.Add(normalizedPattern, entry) + } + globMatchesLRUMu.Unlock() +} + +// GlobCacheLen returns the number of entries currently in the glob LRU cache. +func GlobCacheLen() int { + globMatchesLRUMu.RLock() + defer globMatchesLRUMu.RUnlock() + return globMatchesLRU.Len() +} + +// GlobCacheEvictions returns the total number of LRU evictions since the last cache reset. 
+// This counter is incremented atomically by the LRU eviction callback. +func GlobCacheEvictions() int64 { + return atomic.LoadInt64(&globMatchesEvictions) +} + +// GlobCacheHits returns the total number of cache hits since the last cache reset. +func GlobCacheHits() int64 { + return atomic.LoadInt64(&globMatchesHits) +} + +// GlobCacheMisses returns the total number of cache misses since the last cache reset. +func GlobCacheMisses() int64 { + return atomic.LoadInt64(&globMatchesMisses) +} + +// ApplyGlobCacheConfigForTest re-reads ATMOS_FS_GLOB_CACHE_* env vars and reinitializes +// the glob LRU cache. Tests should call this after setting env vars via t.Setenv. +// It also resets all counters so tests start from a clean baseline. +func ApplyGlobCacheConfigForTest() { + applyGlobCacheConfig() + atomic.StoreInt64(&globMatchesEvictions, 0) + atomic.StoreInt64(&globMatchesHits, 0) + atomic.StoreInt64(&globMatchesMisses, 0) +} + +// GlobCacheEmptyEnabled returns the current empty-result caching setting. +func GlobCacheEmptyEnabled() bool { + globMatchesLRUMu.RLock() + defer globMatchesLRUMu.RUnlock() + return globCacheEmptyEnabled +} + +// ResetGlobExpvarOnce resets the sync.Once guard so RegisterGlobCacheExpvars +// can be called again in the same test binary. Only for use in tests that +// need to verify expvar registration after a cache reset. +func ResetGlobExpvarOnce() { + globExpvarOnce = sync.Once{} +} + +// GlobCacheTTL returns the currently active cache TTL for test introspection. +func GlobCacheTTL() time.Duration { + globMatchesLRUMu.RLock() + defer globMatchesLRUMu.RUnlock() + return globCacheTTL +} + +// GlobCacheMaxEntries returns the currently configured LRU capacity for test introspection. 
+func GlobCacheMaxEntries() int { + globMatchesLRUMu.RLock() + defer globMatchesLRUMu.RUnlock() + return globCacheMaxEntries +} diff --git a/pkg/filesystem/glob.go b/pkg/filesystem/glob.go index fe4c6d6652..57e5cd94a2 100644 --- a/pkg/filesystem/glob.go +++ b/pkg/filesystem/glob.go @@ -3,14 +3,46 @@ package filesystem import ( "os" "path/filepath" + "strconv" "sync" + "sync/atomic" + "time" + + lru "github.com/hashicorp/golang-lru/v2" "github.com/bmatcuk/doublestar/v4" errUtils "github.com/cloudposse/atmos/errors" + log "github.com/cloudposse/atmos/pkg/logger" "github.com/cloudposse/atmos/pkg/perf" ) +const ( + // defaultGlobCacheMaxEntries is the default maximum number of entries in the glob LRU cache. + // Override at startup with ATMOS_FS_GLOB_CACHE_MAX_ENTRIES. + defaultGlobCacheMaxEntries = 1024 + + // defaultGlobCacheTTL is the default time-to-live for each cache entry. + // Override at startup with ATMOS_FS_GLOB_CACHE_TTL (e.g. "10m", "30s"). + defaultGlobCacheTTL = 5 * time.Minute + + // minGlobCacheTTL is the minimum accepted TTL value. Values parsed from + // ATMOS_FS_GLOB_CACHE_TTL that are positive but below this floor are clamped up. + // A sub-second TTL would make the cache nearly useless and cause excessive I/O. + minGlobCacheTTL = time.Second + + // minGlobCacheMaxEntries is the minimum accepted LRU capacity. Values parsed + // from ATMOS_FS_GLOB_CACHE_MAX_ENTRIES that are positive but below this floor + // are clamped up to prevent near-empty caches that evict on nearly every call. + minGlobCacheMaxEntries = 16 +) + +// globCacheEntry holds a cached glob result together with its expiry timestamp. +type globCacheEntry struct { + matches []string + expiry time.Time +} + // pathMatchKey is a composite key for PathMatch cache to avoid collisions. // Using a struct prevents issues when pattern or name contains delimiter characters. 
type pathMatchKey struct { @@ -19,7 +51,27 @@ type pathMatchKey struct { } var ( - getGlobMatchesSyncMap = sync.Map{} + // globMatchesLRU is a bounded LRU cache for GetGlobMatches results. + // It replaces the unbounded sync.Map to prevent unbounded memory growth. + // Access is mediated by a mutex so that the LRU's internal state is not + // corrupted under concurrent use (hashicorp/golang-lru/v2 is thread-safe, + // but we still need the mutex for atomic load+check+store sequences in our TTL logic). + globMatchesLRU *lru.Cache[string, globCacheEntry] + globMatchesLRUMu sync.RWMutex + globMatchesLRUErr error // non-nil only if lru.New fails (should never happen at runtime) + globMatchesEvictions int64 // incremented atomically by the LRU eviction callback + globMatchesHits int64 // incremented atomically on each cache hit + globMatchesMisses int64 // incremented atomically on each cache miss + + // globCacheTTL is the active TTL, configurable via ATMOS_FS_GLOB_CACHE_TTL. + globCacheTTL = defaultGlobCacheTTL + + // globCacheMaxEntries is the active LRU capacity, configurable via ATMOS_FS_GLOB_CACHE_MAX_ENTRIES. + globCacheMaxEntries = defaultGlobCacheMaxEntries + + // globCacheEmptyEnabled controls whether empty-result sets are stored in the cache. + // Default true. Set ATMOS_FS_GLOB_CACHE_EMPTY=0 to disable. + globCacheEmptyEnabled = true // PathMatchCache stores PathMatch results to avoid redundant pattern matching. // Cache key: pathMatchKey{pattern, name} -> match result (bool). @@ -29,30 +81,105 @@ var ( pathMatchCacheMu sync.RWMutex ) -// GetGlobMatches tries to read and return the Glob matches content from the sync map if it exists in the map, -// otherwise it finds and returns all files matching the pattern, stores the files in the map and returns the files. +// applyGlobCacheConfig reads ATMOS_FS_GLOB_CACHE_* environment variables and +// (re-)initializes the glob LRU cache accordingly. 
It is called once from +// init() and may be called again from tests to pick up env changes. +func applyGlobCacheConfig() { + maxEntries := defaultGlobCacheMaxEntries + //nolint:forbidigo // Direct env lookup required for cache configuration. + if v := os.Getenv("ATMOS_FS_GLOB_CACHE_MAX_ENTRIES"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + if n < minGlobCacheMaxEntries { + log.Warn("ATMOS_FS_GLOB_CACHE_MAX_ENTRIES below minimum, clamping up", + "requested", n, "minimum", minGlobCacheMaxEntries) + n = minGlobCacheMaxEntries + } + maxEntries = n + } + } + + ttl := defaultGlobCacheTTL + //nolint:forbidigo // Direct env lookup required for cache configuration. + if v := os.Getenv("ATMOS_FS_GLOB_CACHE_TTL"); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + if d < minGlobCacheTTL { + log.Warn("ATMOS_FS_GLOB_CACHE_TTL below minimum, clamping up", + "requested", d, "minimum", minGlobCacheTTL) + d = minGlobCacheTTL + } + ttl = d + } + } + + emptyEnabled := true + //nolint:forbidigo // Direct env lookup required for cache configuration. + if v := os.Getenv("ATMOS_FS_GLOB_CACHE_EMPTY"); v != "" { + // Only "1" or "true" explicitly enables; "0" or "false" disables. + // Any other value is treated as disabled for safety. 
+ switch v { + case "1", "true": + emptyEnabled = true + default: + emptyEnabled = false + } + } + + newLRU, err := lru.NewWithEvict[string, globCacheEntry]( + maxEntries, + func(_ string, _ globCacheEntry) { + atomic.AddInt64(&globMatchesEvictions, 1) + }, + ) + + globMatchesLRUMu.Lock() + globMatchesLRU = newLRU + globMatchesLRUErr = err + globCacheTTL = ttl + globCacheMaxEntries = maxEntries + globCacheEmptyEnabled = emptyEnabled + globMatchesLRUMu.Unlock() +} + +func init() { + applyGlobCacheConfig() +} + +// GetGlobMatches tries to read and return the Glob matches content from the cache if it exists, +// otherwise it finds and returns all files matching the pattern, stores the files in the cache and returns the files. +// +// Contract: the returned slice is always non-nil (never nil). An empty result is returned as []string{}, not nil. +// This guarantee holds for both cache hits and misses, allowing callers to safely use len(result) without a nil check. +// +// Migration note: if your code uses "if result == nil" to detect no matches, update it to "if len(result) == 0". +// Callers should always use len(result) == 0 to detect no matches, not a nil comparison. +// +// Caching policy: +// - Results are bounded to a configurable LRU (default 1024 entries, minimum 16, ATMOS_FS_GLOB_CACHE_MAX_ENTRIES). +// - Each entry expires after a configurable TTL (default 5 minutes, minimum 1s, ATMOS_FS_GLOB_CACHE_TTL). +// - Empty results are cached by default; set ATMOS_FS_GLOB_CACHE_EMPTY=0 to disable. +// - Cached slices are cloned on read, so callers may safely mutate the returned slice. func GetGlobMatches(pattern string) ([]string, error) { defer perf.Track(nil, "filesystem.GetGlobMatches")() // Normalize pattern for cache lookup to ensure consistent keys across platforms. 
normalizedPattern := filepath.ToSlash(pattern) - existingMatches, found := getGlobMatchesSyncMap.Load(normalizedPattern) - if found && existingMatches != nil { - // Assert to []string and return a cloned copy so callers can't mutate cached data. - cached, ok := existingMatches.([]string) - if !ok { - // If assertion fails, invalidate cache and fall through to recompute. - getGlobMatchesSyncMap.Delete(normalizedPattern) - } - if ok { - // Return a clone to prevent callers from mutating the cached slice. - result := make([]string, len(cached)) - copy(result, cached) - return result, nil - } + // Try cache lookup (read lock on the LRU wrapper). + globMatchesLRUMu.RLock() + entry, found := globMatchesLRU.Get(normalizedPattern) + ttl := globCacheTTL + globMatchesLRUMu.RUnlock() + + if found && time.Now().Before(entry.expiry) { + atomic.AddInt64(&globMatchesHits, 1) + // Return a clone to prevent callers from mutating the cached slice. + result := make([]string, len(entry.matches)) + copy(result, entry.matches) + return result, nil } + atomic.AddInt64(&globMatchesMisses, 1) + base, cleanPattern := doublestar.SplitPattern(normalizedPattern) // Check if base directory exists before attempting glob. @@ -79,16 +206,29 @@ func GetGlobMatches(pattern string) ([]string, error) { matches = []string{} } - var fullMatches []string + fullMatches := make([]string, 0, len(matches)) for _, match := range matches { fullMatches = append(fullMatches, filepath.Join(base, match)) } - // Store a copy of the slice in the cache (not the shared backing slice). - // This prevents callers from mutating cached data and preserves empty results. - cachedCopy := make([]string, len(fullMatches)) - copy(cachedCopy, fullMatches) - getGlobMatchesSyncMap.Store(normalizedPattern, cachedCopy) + // Only store in cache when: (a) there are matches, or (b) empty caching is enabled. 
+ globMatchesLRUMu.RLock() + cacheEmpty := globCacheEmptyEnabled + globMatchesLRUMu.RUnlock() + + if len(fullMatches) > 0 || cacheEmpty { + // Store a copy of the slice in the cache (not the shared backing slice). + // This prevents callers from mutating cached data and preserves empty results. + cachedCopy := make([]string, len(fullMatches)) + copy(cachedCopy, fullMatches) + + globMatchesLRUMu.Lock() + globMatchesLRU.Add(normalizedPattern, globCacheEntry{ + matches: cachedCopy, + expiry: time.Now().Add(ttl), + }) + globMatchesLRUMu.Unlock() + } return fullMatches, nil } diff --git a/pkg/filesystem/glob_atomic_test.go b/pkg/filesystem/glob_atomic_test.go new file mode 100644 index 0000000000..3a387635d3 --- /dev/null +++ b/pkg/filesystem/glob_atomic_test.go @@ -0,0 +1,674 @@ +//go:build !windows + +package filesystem + +import ( + "errors" + "expvar" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + errUtils "github.com/cloudposse/atmos/errors" +) + +// TestGetGlobMatches_CacheHit verifies that GetGlobMatches returns cached results +// on a second call with the same pattern, without re-reading the filesystem. +func TestGetGlobMatches_CacheHit(t *testing.T) { + // Use a fresh cache state by clearing it. + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(""), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(""), 0o644)) + + pattern := filepath.Join(tmpDir, "*.yaml") + + // First call - cache miss. + first, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.Len(t, first, 2) + + // Second call with same pattern - should hit cache. + second, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.Len(t, second, 2) + + // Results should be strictly equal — same type (non-nil), same order, same content. 
+ // Using assert.Equal (not ElementsMatch) to lock in the "always non-nil" return contract + // and verify that cached results are identical (not just order-equivalent). + assert.Equal(t, first, second) +} + +// TestGetGlobMatches_CacheIsolation verifies that cached results are cloned, so +// mutating the returned slice does not corrupt subsequent calls. +func TestGetGlobMatches_CacheIsolation(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "c.yaml"), []byte(""), 0o644)) + + pattern := filepath.Join(tmpDir, "*.yaml") + + first, err := GetGlobMatches(pattern) + require.NoError(t, err) + + // Mutate the returned slice. + if len(first) > 0 { + first[0] = "mutated" + } + + // Second call should still return the original value. + second, err := GetGlobMatches(pattern) + require.NoError(t, err) + if len(second) > 0 { + assert.NotEqual(t, "mutated", second[0]) + } +} + +// TestGetGlobMatches_NonExistentBaseDir verifies that GetGlobMatches returns an +// appropriate error when the base directory does not exist. +func TestGetGlobMatches_NonExistentBaseDir(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + // Build a path guaranteed to not exist by using a non-existent sub-directory + // of a fresh t.TempDir() (which will be cleaned up, but we never create the subdir). + pattern := filepath.Join(t.TempDir(), "nonexistent", "*.yaml") + _, err := GetGlobMatches(pattern) + require.Error(t, err, "expected error for non-existent base directory") + assert.True(t, errors.Is(err, errUtils.ErrFailedToFindImport), "expected ErrFailedToFindImport, got: %v", err) +} + +// TestGetGlobMatches_EmptyResults verifies that a pattern matching no files returns +// an empty slice (not an error). 
+func TestGetGlobMatches_EmptyResults(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + // No files created in tmpDir. + + pattern := filepath.Join(tmpDir, "*.yaml") + matches, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.NotNil(t, matches, "GetGlobMatches must return non-nil slice for empty results") + assert.Empty(t, matches) +} + +// TestGetGlobMatches_NonNilContractOnCacheHit verifies the non-nil slice contract is +// preserved on a cache hit — the cached empty result must also be non-nil. +// This is the contract test for the Critical #1 behavior documented in the function docstring. +func TestGetGlobMatches_NonNilContractOnCacheHit(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + pattern := filepath.Join(tmpDir, "*.nomatches") + + // First call (cache miss): verify non-nil. + first, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.NotNil(t, first, "first call must return non-nil slice for empty results") + assert.Empty(t, first) + + // Second call (cache hit): must also be non-nil with identical content. + second, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.NotNil(t, second, "cached result must also be non-nil — never nil on cache hit") + assert.Empty(t, second) + + // Strict equality (same type, same content) between cache miss and cache hit. + assert.Equal(t, first, second, "cache hit must return same non-nil type as cache miss") +} + +// TestGetGlobMatches_EmptyResultsCache verifies that empty results are cached and +// retrieved without hitting the filesystem again. +func TestGetGlobMatches_EmptyResultsCache(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + + pattern := filepath.Join(tmpDir, "*.nonexistent") + + // First call - should return empty (not nil) slice and cache it. 
+ first, err := GetGlobMatches(pattern) + require.NoError(t, err) + + // Second call should use cache. + second, err := GetGlobMatches(pattern) + require.NoError(t, err) + + // Both should be strictly equal — same type (non-nil empty slice), same content. + // This catches a nil vs []string{} inconsistency between the first and cached call. + assert.Equal(t, first, second) +} + +// TestPathMatch_CacheHit verifies that the PathMatch cache is used on repeated calls. +func TestPathMatch_CacheHit(t *testing.T) { + // Clear the path match cache using the exported test helper. + ResetPathMatchCache() + t.Cleanup(ResetPathMatchCache) + + pattern := "stacks/**/*.yaml" + name := "stacks/dev/vpc.yaml" + + // First call - cache miss. + first, err := PathMatch(pattern, name) + require.NoError(t, err) + + // Second call - should hit cache. + second, err := PathMatch(pattern, name) + require.NoError(t, err) + + assert.Equal(t, first, second) + assert.True(t, first, "pattern should match the name") +} + +// TestPathMatch_CacheHit_NoMatch verifies that cache entries for non-matching patterns +// are also cached and returned correctly. +func TestPathMatch_CacheHit_NoMatch(t *testing.T) { + ResetPathMatchCache() + t.Cleanup(ResetPathMatchCache) + + pattern := "*.go" + name := "file.yaml" + + // First call - cache miss. + first, err := PathMatch(pattern, name) + require.NoError(t, err) + assert.False(t, first) + + // Second call - should hit cache. + second, err := PathMatch(pattern, name) + require.NoError(t, err) + + assert.Equal(t, first, second) +} + +// TestPathMatch_InvalidPattern verifies that an invalid glob pattern returns an error. +func TestPathMatch_InvalidPattern(t *testing.T) { + ResetPathMatchCache() + t.Cleanup(ResetPathMatchCache) + + // An invalid pattern with unclosed bracket. + _, err := PathMatch("[invalid", "file.yaml") + // doublestar.Match returns an error for invalid patterns. 
+ assert.Error(t, err) +} + +// TestWriteFileAtomic verifies that WriteFileAtomic writes file contents correctly. +func TestWriteFileAtomic(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "atomic-test.txt") + content := []byte("hello atomic world") + + err := WriteFileAtomicUnix(filePath, content, 0o644) + require.NoError(t, err) + + // Verify file was written. + got, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, content, got) +} + +// TestWriteFileAtomic_Overwrite verifies that WriteFileAtomic correctly overwrites +// an existing file atomically. +func TestWriteFileAtomic_Overwrite(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "atomic-overwrite.txt") + + // Write initial content. + require.NoError(t, os.WriteFile(filePath, []byte("initial content"), 0o644)) + + // Overwrite with atomic write. + newContent := []byte("new content") + err := WriteFileAtomicUnix(filePath, newContent, 0o644) + require.NoError(t, err) + + got, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, newContent, got) +} + +// TestOSFileSystem_WriteFileAtomic verifies that OSFileSystem.WriteFileAtomic works. +func TestOSFileSystem_WriteFileAtomic(t *testing.T) { + fs := NewOSFileSystem() + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "os-atomic.txt") + content := []byte("atomic content via OSFileSystem") + + err := fs.WriteFileAtomic(filePath, content, 0o644) + require.NoError(t, err) + + got, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, content, got) +} + +// TestOSFileSystem_WriteFileAtomic_Overwrite verifies that OSFileSystem.WriteFileAtomic +// correctly overwrites an existing file atomically. +func TestOSFileSystem_WriteFileAtomic_Overwrite(t *testing.T) { + fs := NewOSFileSystem() + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "os-atomic-overwrite.txt") + + // Write initial content. 
+ require.NoError(t, os.WriteFile(filePath, []byte("initial content"), 0o644)) + + // Overwrite with atomic write. + newContent := []byte("overwritten content via OSFileSystem") + err := fs.WriteFileAtomic(filePath, newContent, 0o644) + require.NoError(t, err) + + got, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, newContent, got) +} + +// TestGetGlobMatches_LRU_Eviction verifies that the LRU cache evicts the least-recently-used +// entry when the cache reaches its capacity (defaultGlobCacheMaxEntries). +// This test uses a small in-process simulation: it fills the cache to capacity + 1 and +// then checks that the first entry was evicted (i.e., a fresh filesystem read is triggered). +// It also verifies that the eviction counter increments as expected. +func TestGetGlobMatches_LRU_Eviction(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + + // Populate the LRU cache with defaultGlobCacheMaxEntries unique patterns (all non-matching + // since we only need entries in the cache, not actual files). + // Use sub-directories that don't exist — GetGlobMatches returns an error for + // non-existent base directories, so instead write empty-match YAML patterns. + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "seed.yaml"), []byte(""), 0o644)) + // Insert a "seed" entry that we will check for eviction later. + seedPattern := filepath.Join(tmpDir, "seed.yaml") + _, err := GetGlobMatches(seedPattern) + require.NoError(t, err) + initialLen := GlobCacheLen() + require.Equal(t, 1, initialLen, "seed entry should be in cache") + + // Fill the cache to defaultGlobCacheMaxEntries by using unique patterns that each match + // the same seed file (pattern variation, not file variation). + // We create defaultGlobCacheMaxEntries additional real files so all patterns resolve. 
+ for i := range defaultGlobCacheMaxEntries { + // Use fmt.Sprintf to guarantee unique filenames for all i values (i > 26 would + // cycle single-character names and produce duplicates). + name := filepath.Join(tmpDir, fmt.Sprintf("file_evict_%d.yaml", i)) + _ = os.WriteFile(name, []byte(""), 0o644) + _, err := GetGlobMatches(name) + require.NoError(t, err) + } + + // After inserting defaultGlobCacheMaxEntries more entries, the LRU should have evicted the + // seed entry (it was the oldest / least recently used). + // We verify this by checking the cache size is bounded at defaultGlobCacheMaxEntries. + afterLen := GlobCacheLen() + assert.LessOrEqual(t, afterLen, defaultGlobCacheMaxEntries, "LRU cache must not exceed max capacity") + + // The eviction counter must have incremented at least once. + evictions := GlobCacheEvictions() + assert.Positive(t, evictions, "eviction counter must increment when LRU capacity is exceeded") +} + +// TestGetGlobMatches_TTL_Expiry verifies that a stale cache entry (past TTL) +// is treated as a cache miss and triggers a fresh filesystem read. +func TestGetGlobMatches_TTL_Expiry(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + + // Create a file so the first call returns a result. + file1 := filepath.Join(tmpDir, "a.yaml") + require.NoError(t, os.WriteFile(file1, []byte(""), 0o644)) + + pattern := filepath.Join(tmpDir, "*.yaml") + + // First call — cache miss, should find the file. + res1, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.Len(t, res1, 1, "should find exactly one file") + + // Forcibly expire the cache entry via the test helper. + SetGlobCacheEntryExpired(pattern) + + // Add a second file before the second call. + file2 := filepath.Join(tmpDir, "b.yaml") + require.NoError(t, os.WriteFile(file2, []byte(""), 0o644)) + + // Second call — the TTL has expired, so the cache should be bypassed and + // both files should be discovered. 
+ res2, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.Len(t, res2, 2, "TTL expiry should trigger fresh filesystem read returning both files") +} + +// TestGetGlobMatches_EmptyResultCached verifies that empty results (no matching files) +// are cached and served from cache on subsequent calls. +func TestGetGlobMatches_EmptyResultCached(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + + // Pattern that matches nothing. + pattern := filepath.Join(tmpDir, "nonexistent_*.yaml") + + res1, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.Empty(t, res1, "should return empty result for non-matching pattern") + assert.NotNil(t, res1, "empty result must be non-nil (contract)") + + // Second call — should be a cache hit (no filesystem walk). + res2, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.Empty(t, res2, "second call should return empty result from cache") + assert.NotNil(t, res2, "cached empty result must be non-nil (contract)") +} + +// TestGetGlobMatches_HitMissCounters verifies that the hit and miss counters +// are incremented correctly across cache hits and misses. +func TestGetGlobMatches_HitMissCounters(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(""), 0o644)) + + pattern := filepath.Join(tmpDir, "*.yaml") + + // First call is always a miss. + _, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.Equal(t, int64(0), GlobCacheHits(), "no hits yet") + assert.Equal(t, int64(1), GlobCacheMisses(), "first call is a miss") + + // Second call should be a hit. 
+ _, err = GetGlobMatches(pattern) + require.NoError(t, err) + assert.Equal(t, int64(1), GlobCacheHits(), "second call is a hit") + assert.Equal(t, int64(1), GlobCacheMisses(), "miss count must not change") +} + +// TestGetGlobMatches_EmptyResultCachingDisabled verifies that when +// ATMOS_FS_GLOB_CACHE_EMPTY=0 is set, empty results are not cached. +func TestGetGlobMatches_EmptyResultCachingDisabled(t *testing.T) { + // Set the env var BEFORE applying config. + t.Setenv("ATMOS_FS_GLOB_CACHE_EMPTY", "0") + ApplyGlobCacheConfigForTest() + // Only reset cache in cleanup; t.Setenv restores the env var (LIFO) so subsequent + // tests that call ApplyGlobCacheConfigForTest() will see the restored unset value. + t.Cleanup(ResetGlobMatchesCache) + + assert.False(t, GlobCacheEmptyEnabled(), "empty caching must be disabled when env var is 0") + + tmpDir := t.TempDir() + // Use a wildcard that initially matches nothing. + pattern := filepath.Join(tmpDir, "*.yaml") + + // First call — cache miss, empty result; must NOT be stored. + res1, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.Empty(t, res1, "should return empty result") + assert.NotNil(t, res1, "must be non-nil per contract") + assert.Equal(t, 0, GlobCacheLen(), "empty result must not be cached when disabled") + + // Create a file so the next call returns a non-empty result. + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "found.yaml"), []byte(""), 0o644)) + + // Second call — since the empty result was NOT cached, the filesystem is re-read + // and the newly-created file should be discovered. + res2, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.Len(t, res2, 1, "new file should be found after cache bypass") +} + +// TestGetGlobMatches_RaceStress hammers the glob cache from many goroutines to +// surface data races. Run with -race to exercise the race detector. 
+func TestGetGlobMatches_RaceStress(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "stress.yaml"), []byte(""), 0o644)) + + const numGoroutines = 32 + const callsPerGoroutine = 50 + + done := make(chan struct{}) + for g := range numGoroutines { + g := g + go func() { + defer func() { done <- struct{}{} }() + for i := range callsPerGoroutine { + // Use a mix of unique and shared patterns to exercise both cache hits + // and cache misses concurrently. + var pattern string + if i%2 == 0 { + pattern = filepath.Join(tmpDir, "*.yaml") + } else { + pattern = filepath.Join(tmpDir, fmt.Sprintf("unique_%d_%d_*.yaml", g, i)) + } + _, _ = GetGlobMatches(pattern) + } + }() + } + + for range numGoroutines { + <-done + } +} + +// TestGetGlobMatches_EnvTTL verifies that ATMOS_FS_GLOB_CACHE_TTL is honoured. +// A very short TTL means entries expire immediately, so every call is a miss. +func TestGetGlobMatches_EnvTTL(t *testing.T) { + t.Setenv("ATMOS_FS_GLOB_CACHE_TTL", "1ns") + ApplyGlobCacheConfigForTest() + // Only reset the cache (not the config) in cleanup; the env restore from + // t.Setenv runs after this cleanup (LIFO), so ApplyGlobCacheConfigForTest + // in another test's setup will pick up the restored (empty) value. + t.Cleanup(ResetGlobMatchesCache) + + tmpDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "ttl.yaml"), []byte(""), 0o644)) + + pattern := filepath.Join(tmpDir, "*.yaml") + + _, err := GetGlobMatches(pattern) + require.NoError(t, err) + + // With 1ns TTL the entry will already be stale. Force it expired just to be safe. + SetGlobCacheEntryExpired(pattern) + + // Add a second file to prove the second call re-reads the filesystem. 
+ require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "ttl2.yaml"), []byte(""), 0o644)) + + res, err := GetGlobMatches(pattern) + require.NoError(t, err) + assert.Len(t, res, 2, "short TTL should cause re-read and find both files") +} + +// TestRegisterGlobCacheExpvars verifies that RegisterGlobCacheExpvars publishes +// counters that reflect actual cache activity. +func TestRegisterGlobCacheExpvars(t *testing.T) { + // ApplyGlobCacheConfigForTest re-reads env vars and reinitializes the LRU. + // This is essential when a prior test (e.g. TestGetGlobMatches_EnvTTL) left + // the in-package globCacheTTL at 1ns due to cleanup ordering. + ApplyGlobCacheConfigForTest() + ResetGlobMatchesCache() + ResetGlobExpvarOnce() + t.Cleanup(func() { + ApplyGlobCacheConfigForTest() + ResetGlobMatchesCache() + ResetGlobExpvarOnce() + }) + + RegisterGlobCacheExpvars() + + tmpDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "ev.yaml"), []byte(""), 0o644)) + pattern := filepath.Join(tmpDir, "*.yaml") + + // First call is a miss. + _, err := GetGlobMatches(pattern) + require.NoError(t, err) + + // Second call is a hit. + _, err = GetGlobMatches(pattern) + require.NoError(t, err) + + // Verify expvar values match atomic counters. 
+ hitsVar := expvar.Get("atmos_glob_cache_hits") + require.NotNil(t, hitsVar, "atmos_glob_cache_hits expvar must be registered") + assert.Equal(t, "1", hitsVar.String(), "hit counter must be 1 after second call") + + missesVar := expvar.Get("atmos_glob_cache_misses") + require.NotNil(t, missesVar) + assert.Equal(t, "1", missesVar.String(), "miss counter must be 1 after first call") + + lenVar := expvar.Get("atmos_glob_cache_len") + require.NotNil(t, lenVar) + assert.Equal(t, "1", lenVar.String(), "cache len must be 1 after one unique pattern") +} + +// TestApplyGlobCacheConfig_InvalidInputsClamped verifies that invalid or out-of-range +// values for ATMOS_FS_GLOB_CACHE_TTL and ATMOS_FS_GLOB_CACHE_MAX_ENTRIES are rejected +// and that the defaults are preserved. +func TestApplyGlobCacheConfig_InvalidInputsClamped(t *testing.T) { + // Cannot use t.Parallel() here because subtests call t.Setenv which modifies + // process-wide environment variables. + + type testCase struct { + name string + ttlEnv string + maxEntriesEnv string + wantTTL time.Duration + wantMaxEntries int + } + + const ( + defaultTTL = 5 * time.Minute + defaultMaxEntries = 1024 + ) + + cases := []testCase{ + // Zero values should fall back to defaults. + { + name: "zero_TTL_falls_back_to_default", + ttlEnv: "0s", + wantTTL: defaultTTL, + wantMaxEntries: defaultMaxEntries, + }, + { + name: "zero_maxEntries_falls_back_to_default", + maxEntriesEnv: "0", + wantTTL: defaultTTL, + wantMaxEntries: defaultMaxEntries, + }, + // Negative values should fall back to defaults. + { + name: "negative_TTL_falls_back_to_default", + ttlEnv: "-1m", + wantTTL: defaultTTL, + wantMaxEntries: defaultMaxEntries, + }, + { + name: "negative_maxEntries_falls_back_to_default", + maxEntriesEnv: "-5", + wantTTL: defaultTTL, + wantMaxEntries: defaultMaxEntries, + }, + // Unparseable values should fall back to defaults. 
+ { + name: "invalid_TTL_string_falls_back_to_default", + ttlEnv: "not-a-duration", + wantTTL: defaultTTL, + wantMaxEntries: defaultMaxEntries, + }, + { + name: "invalid_maxEntries_string_falls_back_to_default", + maxEntriesEnv: "not-a-number", + wantTTL: defaultTTL, + wantMaxEntries: defaultMaxEntries, + }, + // Valid values should be accepted. + { + name: "valid_TTL_accepted", + ttlEnv: "10m", + wantTTL: 10 * time.Minute, + wantMaxEntries: defaultMaxEntries, + }, + { + name: "valid_maxEntries_accepted", + maxEntriesEnv: "256", + wantTTL: defaultTTL, + wantMaxEntries: 256, + }, + // Values below the minimum should be clamped up, not rejected. + { + name: "TTL_below_minimum_clamped_to_1s", + ttlEnv: "100ms", + wantTTL: time.Second, + wantMaxEntries: defaultMaxEntries, + }, + { + name: "TTL_500ms_clamped_to_1s", + ttlEnv: "500ms", + wantTTL: time.Second, + wantMaxEntries: defaultMaxEntries, + }, + { + name: "maxEntries_below_minimum_clamped_to_16", + maxEntriesEnv: "5", + wantTTL: defaultTTL, + wantMaxEntries: 16, + }, + { + name: "maxEntries_15_clamped_to_16", + maxEntriesEnv: "15", + wantTTL: defaultTTL, + wantMaxEntries: 16, + }, + { + name: "maxEntries_exactly_16_accepted", + maxEntriesEnv: "16", + wantTTL: defaultTTL, + wantMaxEntries: 16, + }, + { + name: "TTL_exactly_1s_accepted", + ttlEnv: "1s", + wantTTL: time.Second, + wantMaxEntries: defaultMaxEntries, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.ttlEnv != "" { + t.Setenv("ATMOS_FS_GLOB_CACHE_TTL", tc.ttlEnv) + } + if tc.maxEntriesEnv != "" { + t.Setenv("ATMOS_FS_GLOB_CACHE_MAX_ENTRIES", tc.maxEntriesEnv) + } + ApplyGlobCacheConfigForTest() + t.Cleanup(func() { + ApplyGlobCacheConfigForTest() + ResetGlobMatchesCache() + }) + + assert.Equal(t, tc.wantTTL, GlobCacheTTL(), "TTL mismatch for env TTL=%q", tc.ttlEnv) + assert.Equal(t, tc.wantMaxEntries, GlobCacheMaxEntries(), "MaxEntries mismatch for env MAX=%q", tc.maxEntriesEnv) + }) + } +} diff --git 
a/pkg/filesystem/glob_atomic_windows_test.go b/pkg/filesystem/glob_atomic_windows_test.go new file mode 100644 index 0000000000..fceca45032 --- /dev/null +++ b/pkg/filesystem/glob_atomic_windows_test.go @@ -0,0 +1,80 @@ +//go:build windows + +package filesystem + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestWriteFileAtomicWindows_Create verifies that WriteFileAtomicWindows creates a new file. +func TestWriteFileAtomicWindows_Create(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "new-file.txt") + content := []byte("hello windows atomic create") + + err := WriteFileAtomicWindows(filePath, content, 0o644) + require.NoError(t, err, "WriteFileAtomicWindows should create a new file") + + got, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, content, got) +} + +// TestWriteFileAtomicWindows_Overwrite verifies that WriteFileAtomicWindows overwrites an +// existing file atomically (remove-before-rename path on Windows). +func TestWriteFileAtomicWindows_Overwrite(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "existing-file.txt") + + // Write initial content. + require.NoError(t, os.WriteFile(filePath, []byte("initial content"), 0o644)) + + newContent := []byte("overwritten content via WriteFileAtomicWindows") + err := WriteFileAtomicWindows(filePath, newContent, 0o644) + require.NoError(t, err, "WriteFileAtomicWindows should overwrite an existing file") + + got, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, newContent, got, "file must contain new content after overwrite") +} + +// TestWriteFileAtomicWindows_RemoveBeforeRename exercises the remove-before-rename code path +// by simulating the scenario where the destination file already exists. +// On Windows, os.Rename fails if the target exists so WriteFileAtomicWindows removes it first. 
+func TestWriteFileAtomicWindows_RemoveBeforeRename(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "replace-me.txt") + + // Create a non-empty existing file. + require.NoError(t, os.WriteFile(filePath, []byte("old data"), 0o644)) + + // Overwrite multiple times to ensure the remove-before-rename path is exercised reliably. + for i := range 3 { + content := []byte("iteration " + string(rune('0'+i))) + err := WriteFileAtomicWindows(filePath, content, 0o644) + require.NoError(t, err) + + got, err := os.ReadFile(filePath) + require.NoError(t, err) + assert.Equal(t, content, got) + } +} + +// TestWriteFileAtomicWindows_ModePreserved verifies that WriteFileAtomicWindows sets the +// requested file permissions on the written file. +func TestWriteFileAtomicWindows_ModePreserved(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "mode-check.txt") + + err := WriteFileAtomicWindows(filePath, []byte("content"), 0o644) + require.NoError(t, err, "WriteFileAtomicWindows should succeed") + + info, err := os.Stat(filePath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0o644), info.Mode().Perm(), "file permissions must match requested mode") +} diff --git a/pkg/filesystem/glob_metrics.go b/pkg/filesystem/glob_metrics.go new file mode 100644 index 0000000000..99a53b5fe5 --- /dev/null +++ b/pkg/filesystem/glob_metrics.go @@ -0,0 +1,45 @@ +package filesystem + +import ( + "expvar" + "sync" + "sync/atomic" +) + +// globExpvarOnce ensures expvar variables are registered at most once. +// expvar.Publish panics on duplicate registration, so callers that call +// RegisterGlobCacheExpvars multiple times (e.g., in tests) are protected. +var globExpvarOnce sync.Once + +// RegisterGlobCacheExpvars publishes the glob-cache counters as expvar integers +// under the /debug/vars HTTP endpoint. The function is no-op after the first call +// (duplicate registration would panic). 
+//
+// Call this once at program startup to expose cache performance metrics:
+//
+//	filesystem.RegisterGlobCacheExpvars()
+//	// /debug/vars is served by the expvar package's handler on http.DefaultServeMux.
+//
+// The following variables are published:
+//   - atmos_glob_cache_hits – number of cache hits since last reset
+//   - atmos_glob_cache_misses – number of cache misses since last reset
+//   - atmos_glob_cache_evictions – number of LRU evictions since last reset
+//   - atmos_glob_cache_len – current number of entries in the cache
+func RegisterGlobCacheExpvars() {
+	globExpvarOnce.Do(func() {
+		expvar.Publish("atmos_glob_cache_hits", expvar.Func(func() any {
+			return atomic.LoadInt64(&globMatchesHits)
+		}))
+		expvar.Publish("atmos_glob_cache_misses", expvar.Func(func() any {
+			return atomic.LoadInt64(&globMatchesMisses)
+		}))
+		expvar.Publish("atmos_glob_cache_evictions", expvar.Func(func() any {
+			return atomic.LoadInt64(&globMatchesEvictions)
+		}))
+		expvar.Publish("atmos_glob_cache_len", expvar.Func(func() any {
+			globMatchesLRUMu.RLock()
+			defer globMatchesLRUMu.RUnlock()
+			return globMatchesLRU.Len()
+		}))
+	})
+}
diff --git a/pkg/flags/compat/compatibility_flags_test.go b/pkg/flags/compat/compatibility_flags_test.go
index 7af911f0c5..5fa13b92db 100644
--- a/pkg/flags/compat/compatibility_flags_test.go
+++ b/pkg/flags/compat/compatibility_flags_test.go
@@ -998,3 +998,53 @@ func TestCompatibilityFlagTranslator_ShorthandNormalization(t *testing.T) {
 		})
 	}
 }
+
+// TestCompatibilityFlagTranslator_UnknownBehavior verifies that unknown/custom CompatibilityBehavior
+// values fall through to the default case (passed as-is to Atmos args).
+// This tests the defensive default branches in applyFlagBehaviorWithEquals and
+// applyFlagBehaviorWithoutEquals.
+func TestCompatibilityFlagTranslator_UnknownBehavior(t *testing.T) {
+	// Use a custom behavior value that is not MapToAtmosFlag or AppendToSeparated.
+ unknownBehavior := CompatibilityBehavior(999) + + tests := []struct { + name string + args []string + flagMap map[string]CompatibilityFlag + expectedAtmosArgs []string + expectedSeparated []string + }{ + { + name: "unknown behavior with equals syntax", + args: []string{"-custom=value"}, + flagMap: map[string]CompatibilityFlag{ + "-custom": {Behavior: unknownBehavior, Target: "--custom"}, + }, + // default branch: pass original arg to atmos args. + expectedAtmosArgs: []string{"-custom=value"}, + expectedSeparated: []string{}, + }, + { + name: "unknown behavior without equals syntax", + args: []string{"-custom", "value"}, + flagMap: map[string]CompatibilityFlag{ + "-custom": {Behavior: unknownBehavior, Target: "--custom"}, + }, + // default branch: pass the flag arg to atmos args, consumed=0 so the next + // token ("value") is NOT consumed here — it is processed independently in + // the next loop iteration and becomes a positional/atmos arg. + expectedAtmosArgs: []string{"-custom", "value"}, + expectedSeparated: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + translator := NewCompatibilityFlagTranslator(tt.flagMap) + atmosArgs, separatedArgs := translator.Translate(tt.args) + + assert.Equal(t, tt.expectedAtmosArgs, atmosArgs) + assert.Equal(t, tt.expectedSeparated, separatedArgs) + }) + } +} diff --git a/pkg/flags/compat/separated.go b/pkg/flags/compat/separated.go index c4d31d48d6..635f2fe477 100644 --- a/pkg/flags/compat/separated.go +++ b/pkg/flags/compat/separated.go @@ -17,6 +17,11 @@ var ( // Separated args are flags that should be passed through to the underlying command, // (e.g., terraform -out=/tmp/plan, -var=foo=bar) rather than being parsed by Atmos. // These are identified by the CompatibilityFlagTranslator during preprocessing. +// +// Note: passing an empty (non-nil) slice is equivalent to passing nil — the global +// state will be nil and GetSeparated() will return nil. 
This is intentional: callers +// that range over the result are unaffected, and it avoids spurious "no args" vs +// "zero-length args" ambiguity. This contract is tested and must be preserved. func SetSeparated(separatedArgs []string) { defer perf.Track(nil, "compat.SetSeparated")() diff --git a/pkg/flags/compat/separated_test.go b/pkg/flags/compat/separated_test.go new file mode 100644 index 0000000000..8f34e3c6f4 --- /dev/null +++ b/pkg/flags/compat/separated_test.go @@ -0,0 +1,116 @@ +package compat + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetSeparated(t *testing.T) { + t.Cleanup(func() { ResetSeparated() }) + + args := []string{"-var", "region=us-east-1", "-var-file", "prod.tfvars"} + SetSeparated(args) + + got := GetSeparated() + assert.Equal(t, args, got) +} + +func TestSetSeparated_DefensiveCopy(t *testing.T) { + t.Cleanup(func() { ResetSeparated() }) + + original := []string{"-var", "foo=bar"} + SetSeparated(original) + + // Mutate the original slice. + original[0] = "mutated" + + // GetSeparated should return the original values, not the mutated ones. + got := GetSeparated() + assert.Equal(t, "-var", got[0], "SetSeparated should make a defensive copy") +} + +func TestGetSeparated_ReturnsNilWhenNotSet(t *testing.T) { + t.Cleanup(func() { ResetSeparated() }) + ResetSeparated() + + got := GetSeparated() + assert.Nil(t, got) +} + +func TestGetSeparated_DefensiveCopy(t *testing.T) { + t.Cleanup(func() { ResetSeparated() }) + + SetSeparated([]string{"-var", "x=1"}) + + // Mutate the returned slice. + got1 := GetSeparated() + got1[0] = "mutated" + + // Second call should return original value. 
+ got2 := GetSeparated() + assert.Equal(t, "-var", got2[0], "GetSeparated should return a defensive copy") +} + +func TestGetSeparated_ReturnsEmptySliceForEmpty(t *testing.T) { + t.Cleanup(func() { ResetSeparated() }) + + // SetSeparated with empty slice: append([]string(nil), []string{}...) returns nil. + // This is expected Go behavior - appending zero elements to nil yields nil. + SetSeparated([]string{}) + + got := GetSeparated() + // An empty input slice results in nil globalSeparatedArgs (nil == nil is true). + assert.Nil(t, got) +} + +func TestResetSeparated(t *testing.T) { + t.Cleanup(func() { ResetSeparated() }) + + SetSeparated([]string{"-var", "x=1"}) + assert.NotNil(t, GetSeparated()) + + ResetSeparated() + assert.Nil(t, GetSeparated()) +} + +func TestSeparated_Concurrent(t *testing.T) { + t.Cleanup(func() { ResetSeparated() }) + + // Phase 1: Parallel Set + Get only — no Reset during reads. + // Verify that defensive copies from GetSeparated are independent. + const goroutines = 50 + SetSeparated([]string{"-var", "key=value"}) + + copies := make([][]string, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + for i := range goroutines { + // In Go 1.22+, loop variables are per-iteration; i := i is a no-op shadow. + go func(idx int) { + defer wg.Done() + copies[idx] = GetSeparated() + }(i) + } + wg.Wait() + + // All copies must equal the originally set value. + for i, c := range copies { + assert.Equal(t, []string{"-var", "key=value"}, c, "goroutine %d got unexpected result", i) + } + + // Mutating one copy must not affect others (defensive copy guarantee). + // Use require.NotNil so any future change that makes GetSeparated() return nil + // fails loudly here rather than silently skipping the independence assertion. 
+ require.True(t, len(copies) >= 2, "expected at least 2 goroutine copies") + require.NotNil(t, copies[0], "goroutine 0 must return non-nil slice") + require.NotNil(t, copies[1], "goroutine 1 must return non-nil slice") + copies[0][0] = "mutated" + assert.Equal(t, "-var", copies[1][0], "defensive copies must be independent") + + // Phase 2: Reset — run serially after all reads are complete. + ResetSeparated() + assert.Nil(t, GetSeparated()) +} diff --git a/pkg/flags/flag_parser_test.go b/pkg/flags/flag_parser_test.go index 45ebf84a3e..1a5e81c08f 100644 --- a/pkg/flags/flag_parser_test.go +++ b/pkg/flags/flag_parser_test.go @@ -531,3 +531,90 @@ func TestFlagParser_NoOptDefVal(t *testing.T) { }) } } + +// TestFlagParser_Reset verifies that Reset clears registered command flag state +// so parsers can be reused cleanly between test runs. +func TestFlagParser_Reset(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("stack", "", "Stack name") + + v := viper.New() + translator := compat.NewCompatibilityFlagTranslator(nil) + registry := NewFlagRegistry() + parser := NewAtmosFlagParser(cmd, v, translator, registry) + + // Parse once to populate flags. + _, err := parser.Parse([]string{"--stack", "dev"}) + require.NoError(t, err) + assert.Equal(t, "dev", v.GetString("stack")) + + // Verify the flag is marked as Changed after the first parse. + flag := cmd.Flags().Lookup("stack") + require.NotNil(t, flag) + assert.True(t, flag.Changed, "flag should be Changed after first parse") + assert.Equal(t, "dev", flag.Value.String()) + + // Reset should not panic and must clear the Changed state and restore defaults. + assert.NotPanics(t, func() { + parser.Reset() + }) + + // After Reset, the flag's Changed state must be cleared and value back to default. 
+ assert.False(t, flag.Changed, "flag Changed state must be false after Reset") + assert.Equal(t, "", flag.Value.String(), "flag value must be reset to default after Reset") + // Resetting the pflag clears the viper value bound via BindPFlags. + assert.Equal(t, "", v.GetString("stack"), "viper value must also be cleared after Reset") + + // A second parse with no flags should not see the value from the first parse. + result, err := parser.Parse([]string{}) + require.NoError(t, err) + assert.Equal(t, "", GetString(result.Flags, "stack"), "second parse must not inherit value from first parse") +} + +// TestParsedConfig_GetArgsForTool verifies that GetArgsForTool combines positional +// and separated args into the expected subprocess argument array. +func TestParsedConfig_GetArgsForTool(t *testing.T) { + tests := []struct { + name string + positionalArgs []string + separatedArgs []string + want []string + }{ + { + name: "positional only", + positionalArgs: []string{"plan", "vpc"}, + separatedArgs: []string{}, + want: []string{"plan", "vpc"}, + }, + { + name: "separated only", + positionalArgs: []string{}, + separatedArgs: []string{"-var", "region=us-east-1"}, + want: []string{"-var", "region=us-east-1"}, + }, + { + name: "both positional and separated", + positionalArgs: []string{"plan", "vpc"}, + separatedArgs: []string{"-var", "region=us-east-1", "-var-file", "prod.tfvars"}, + want: []string{"plan", "vpc", "-var", "region=us-east-1", "-var-file", "prod.tfvars"}, + }, + { + name: "empty both", + positionalArgs: []string{}, + separatedArgs: []string{}, + want: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pc := &ParsedConfig{ + PositionalArgs: tt.positionalArgs, + SeparatedArgs: tt.separatedArgs, + } + + got := pc.GetArgsForTool() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/flags/registry_test.go b/pkg/flags/registry_test.go index 77d310bf59..30a08a73cc 100644 --- a/pkg/flags/registry_test.go +++ 
b/pkg/flags/registry_test.go @@ -3,6 +3,8 @@ package flags import ( "testing" + "github.com/spf13/cobra" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -135,6 +137,209 @@ func TestPackerFlags(t *testing.T) { assert.False(t, registry.Has("identity"), "identity should be inherited from RootCmd") } +func TestFlagRegistry_RegisterStringFlag(t *testing.T) { + registry := NewFlagRegistry() + registry.RegisterStringFlag("my-flag", "m", "default-val", "My flag description", false) + + flag := registry.Get("my-flag") + require.NotNil(t, flag) + + sf, ok := flag.(*StringFlag) + require.True(t, ok) + assert.Equal(t, "my-flag", sf.Name) + assert.Equal(t, "m", sf.Shorthand) + assert.Equal(t, "default-val", sf.Default) + assert.Equal(t, "My flag description", sf.Description) + assert.False(t, sf.Required) +} + +func TestFlagRegistry_RegisterStringFlag_Required(t *testing.T) { + registry := NewFlagRegistry() + registry.RegisterStringFlag("required-flag", "", "", "A required flag", true) + + flag := registry.Get("required-flag") + require.NotNil(t, flag) + assert.True(t, flag.IsRequired()) +} + +func TestFlagRegistry_RegisterBoolFlag(t *testing.T) { + registry := NewFlagRegistry() + registry.RegisterBoolFlag("verbose", "v", true, "Enable verbose output") + + flag := registry.Get("verbose") + require.NotNil(t, flag) + + bf, ok := flag.(*BoolFlag) + require.True(t, ok) + assert.Equal(t, "verbose", bf.Name) + assert.Equal(t, "v", bf.Shorthand) + assert.True(t, bf.Default) + assert.Equal(t, "Enable verbose output", bf.Description) +} + +func TestFlagRegistry_RegisterIntFlag(t *testing.T) { + registry := NewFlagRegistry() + registry.RegisterIntFlag("count", "c", 5, "Number of items", false) + + flag := registry.Get("count") + require.NotNil(t, flag) + + intf, ok := flag.(*IntFlag) + require.True(t, ok) + assert.Equal(t, "count", intf.Name) + assert.Equal(t, "c", intf.Shorthand) + assert.Equal(t, 5, intf.Default) + 
assert.Equal(t, "Number of items", intf.Description) + assert.False(t, intf.Required) +} + +func TestFlagRegistry_RegisterFlags(t *testing.T) { + registry := NewFlagRegistry() + registry.Register(&StringFlag{Name: "stack", Shorthand: "s", Description: "Stack name"}) + registry.Register(&BoolFlag{Name: "dry-run", Description: "Dry run mode"}) + registry.Register(&IntFlag{Name: "timeout", Description: "Timeout seconds"}) + + cmd := &cobra.Command{Use: "test"} + registry.RegisterFlags(cmd) + + // Verify all flags were registered with the cobra command. + assert.NotNil(t, cmd.Flags().Lookup("stack"), "stack flag should be registered") + assert.NotNil(t, cmd.Flags().Lookup("dry-run"), "dry-run flag should be registered") + assert.NotNil(t, cmd.Flags().Lookup("timeout"), "timeout flag should be registered") + + // Verify shorthands. + assert.Equal(t, "s", cmd.Flags().Lookup("stack").Shorthand) +} + +func TestFlagRegistry_RegisterFlags_WithNoOptDefVal(t *testing.T) { + registry := NewFlagRegistry() + registry.Register(&StringFlag{ + Name: "identity", + NoOptDefVal: cfg.IdentityFlagSelectValue, + Description: "Identity selector", + }) + + cmd := &cobra.Command{Use: "test"} + registry.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("identity") + require.NotNil(t, flag) + assert.Equal(t, cfg.IdentityFlagSelectValue, flag.NoOptDefVal) +} + +func TestFlagRegistry_RegisterPersistentFlags(t *testing.T) { + registry := NewFlagRegistry() + registry.Register(&StringFlag{Name: "logs-level", Description: "Log level"}) + registry.Register(&BoolFlag{Name: "verbose", Description: "Verbose output"}) + + cmd := &cobra.Command{Use: "test"} + registry.RegisterPersistentFlags(cmd) + + // Persistent flags should appear in PersistentFlags(), not Flags(). 
+ assert.NotNil(t, cmd.PersistentFlags().Lookup("logs-level"), "logs-level should be registered as persistent") + assert.NotNil(t, cmd.PersistentFlags().Lookup("verbose"), "verbose should be registered as persistent") + + // PersistentFlags are NOT in non-persistent Flags(). + assert.Nil(t, cmd.Flags().Lookup("logs-level"), "logs-level should NOT be in non-persistent flags") +} + +func TestFlagRegistry_SetCompletionFunc(t *testing.T) { + registry := NewFlagRegistry() + registry.Register(&StringFlag{Name: "stack", Description: "Stack name"}) + + completionFn := func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return []string{"dev", "prod", "staging"}, cobra.ShellCompDirectiveNoFileComp + } + + // Should not panic. + assert.NotPanics(t, func() { + registry.SetCompletionFunc("stack", completionFn) + }) + + // Verify the function was set. + flag := registry.Get("stack") + sf, ok := flag.(*StringFlag) + require.True(t, ok) + require.NotNil(t, sf.CompletionFunc) + + // Verify the completion function works. + results, _ := sf.CompletionFunc(nil, nil, "") + assert.Equal(t, []string{"dev", "prod", "staging"}, results) +} + +func TestFlagRegistry_SetCompletionFunc_NonExistentFlag(t *testing.T) { + registry := NewFlagRegistry() + + completionFn := func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return nil, cobra.ShellCompDirectiveNoFileComp + } + + // Should not panic even when flag doesn't exist. 
+ assert.NotPanics(t, func() { + registry.SetCompletionFunc("nonexistent", completionFn) + }) +} + +func TestFlagRegistry_SetCompletionFunc_BoolFlagIgnored(t *testing.T) { + registry := NewFlagRegistry() + registry.Register(&BoolFlag{Name: "verbose", Description: "Verbose output"}) + + completionFn := func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return []string{"true", "false"}, cobra.ShellCompDirectiveNoFileComp + } + + // Should not panic, but SetCompletionFunc only works for StringFlags. + assert.NotPanics(t, func() { + registry.SetCompletionFunc("verbose", completionFn) + }) + + // BoolFlag doesn't support CompletionFunc - setting it on a non-StringFlag is a no-op. + // The flag should still be registered and functional. + flag := registry.Get("verbose") + assert.NotNil(t, flag) + assert.Equal(t, "verbose", flag.GetName()) + + // Verify the flag type was not changed - it must still be a *BoolFlag. + // BoolFlag has no CompletionFunc field, confirming the call was definitively a no-op. + _, ok := flag.(*BoolFlag) + assert.True(t, ok, "flag should still be *BoolFlag after calling SetCompletionFunc on it") +} + +func TestFlagRegistry_BindToViper(t *testing.T) { + registry := NewFlagRegistry() + registry.Register(&StringFlag{ + Name: "stack", + EnvVars: []string{"ATMOS_STACK"}, + }) + registry.Register(&BoolFlag{ + Name: "dry-run", + EnvVars: []string{"ATMOS_DRY_RUN"}, + }) + + v := viper.New() + err := registry.BindToViper(v) + require.NoError(t, err) + + // BindToViper calls viper.BindEnv (not AutomaticEnv) to map flag names to env vars. + // This is intentional: BindEnv binds a specific env var to a specific key without + // enabling global env-var lookup. If the implementation changes to rely on AutomaticEnv + // instead, this test would silently pass while production behavior breaks for keys + // that don't follow the ATMOS_ prefix convention. 
+ t.Setenv("ATMOS_STACK", "test-stack") + got := v.GetString("stack") + assert.Equal(t, "test-stack", got, "viper should read ATMOS_STACK via bound flag name 'stack'") +} + +func TestFlagRegistry_BindToViper_NoEnvVars(t *testing.T) { + registry := NewFlagRegistry() + registry.Register(&StringFlag{Name: "no-env-flag", Description: "Flag without env vars"}) + + v := viper.New() + err := registry.BindToViper(v) + // Flags without env vars should not cause errors. + assert.NoError(t, err) +} + func TestFlagRegistry_Validate(t *testing.T) { tests := []struct { name string diff --git a/pkg/flags/standard_builder_test.go b/pkg/flags/standard_builder_test.go index ae57babdf2..35ff77f42f 100644 --- a/pkg/flags/standard_builder_test.go +++ b/pkg/flags/standard_builder_test.go @@ -2,6 +2,7 @@ package flags import ( "context" + "path/filepath" "testing" "github.com/spf13/cobra" @@ -638,3 +639,496 @@ func TestStandardOptionsBuilder_WithModulePaths(t *testing.T) { parser := builder.Build() require.NotNil(t, parser) } + +func TestStandardOptionsBuilder_WithTimeout(t *testing.T) { + t.Run("flag registration", func(t *testing.T) { + builder := NewStandardOptionsBuilder().WithTimeout(30) + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("timeout") + require.NotNil(t, flag, "timeout flag should be registered") + assert.Equal(t, "30", flag.DefValue) + }) + + t.Run("default value", func(t *testing.T) { + builder := NewStandardOptionsBuilder().WithTimeout(30) + parser := builder.Build() + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{}) + require.NoError(t, err) + assert.Equal(t, 30, opts.Timeout) + }) + + t.Run("explicit value", func(t *testing.T) { + builder := NewStandardOptionsBuilder().WithTimeout(30) + parser := builder.Build() + cmd := &cobra.Command{Use: 
"test"} + parser.RegisterFlags(cmd) + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--timeout", "60"}) + require.NoError(t, err) + assert.Equal(t, 60, opts.Timeout) + }) +} + +func TestStandardOptionsBuilder_WithSchemasAtmosManifest(t *testing.T) { + builder := NewStandardOptionsBuilder().WithSchemasAtmosManifest("schema.json") + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("schemas-atmos-manifest") + require.NotNil(t, flag, "schemas-atmos-manifest flag should be registered") + assert.Equal(t, "schema.json", flag.DefValue) +} + +func TestStandardOptionsBuilder_WithLogin(t *testing.T) { + builder := NewStandardOptionsBuilder().WithLogin() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("login") + require.NotNil(t, flag, "login flag should be registered") + + // Verify that the parsed value is propagated correctly. + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--login"}) + require.NoError(t, err) + assert.True(t, opts.Login) +} + +func TestStandardOptionsBuilder_WithProvider(t *testing.T) { + builder := NewStandardOptionsBuilder().WithProvider() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("provider") + require.NotNil(t, flag, "provider flag should be registered") + + // Verify that the parsed value is propagated correctly. 
+ v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--provider", "aws"}) + require.NoError(t, err) + assert.Equal(t, "aws", opts.Provider) +} + +func TestStandardOptionsBuilder_WithProviders(t *testing.T) { + builder := NewStandardOptionsBuilder().WithProviders() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("providers") + require.NotNil(t, flag, "providers flag should be registered") + + // Verify that the parsed value is propagated correctly. + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--providers", "aws,gcp"}) + require.NoError(t, err) + assert.Equal(t, "aws,gcp", opts.Providers) +} + +func TestStandardOptionsBuilder_WithIdentities(t *testing.T) { + builder := NewStandardOptionsBuilder().WithIdentities() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("identities") + require.NotNil(t, flag, "identities flag should be registered") + + // Verify that the parsed value is propagated correctly. + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--identities", "role1,role2"}) + require.NoError(t, err) + assert.Equal(t, "role1,role2", opts.Identities) +} + +func TestStandardOptionsBuilder_WithAll(t *testing.T) { + builder := NewStandardOptionsBuilder().WithAll() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("all") + require.NotNil(t, flag, "all flag should be registered") + + // Verify that the parsed value is propagated correctly. 
+ v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--all"}) + require.NoError(t, err) + assert.True(t, opts.All) +} + +func TestStandardOptionsBuilder_WithEverything(t *testing.T) { + builder := NewStandardOptionsBuilder().WithEverything() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("everything") + require.NotNil(t, flag, "everything flag should be registered") + + // Verify that the parsed value is propagated correctly. + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--everything"}) + require.NoError(t, err) + assert.True(t, opts.Everything) +} + +func TestStandardOptionsBuilder_WithRef(t *testing.T) { + t.Run("flag registration with default", func(t *testing.T) { + builder := NewStandardOptionsBuilder().WithRef("main") + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("ref") + require.NotNil(t, flag, "ref flag should be registered") + assert.Equal(t, "main", flag.DefValue) + }) + + t.Run("default value", func(t *testing.T) { + builder := NewStandardOptionsBuilder().WithRef("main") + parser := builder.Build() + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{}) + require.NoError(t, err) + assert.Equal(t, "main", opts.Ref) + }) + + t.Run("explicit value overrides default", func(t *testing.T) { + builder := NewStandardOptionsBuilder().WithRef("main") + parser := builder.Build() + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--ref", "v1.0.0"}) + require.NoError(t, err) + assert.Equal(t, 
"v1.0.0", opts.Ref) + }) +} + +func TestStandardOptionsBuilder_WithSha(t *testing.T) { + builder := NewStandardOptionsBuilder().WithSha("") + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("sha") + require.NotNil(t, flag, "sha flag should be registered") + + // Verify that the parsed value is propagated correctly. + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--sha", "abc123def"}) + require.NoError(t, err) + assert.Equal(t, "abc123def", opts.Sha) +} + +func TestStandardOptionsBuilder_WithRepoPath(t *testing.T) { + repoPath := filepath.Join(t.TempDir(), "test-repo") + builder := NewStandardOptionsBuilder().WithRepoPath(repoPath) + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("repo-path") + require.NotNil(t, flag, "repo-path flag should be registered") + assert.Equal(t, repoPath, flag.DefValue) +} + +func TestStandardOptionsBuilder_WithSSHKey(t *testing.T) { + builder := NewStandardOptionsBuilder().WithSSHKey(filepath.Join(t.TempDir(), ".ssh", "id_rsa")) + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("ssh-key") + require.NotNil(t, flag, "ssh-key flag should be registered") +} + +func TestStandardOptionsBuilder_WithSSHKeyPassword(t *testing.T) { + builder := NewStandardOptionsBuilder().WithSSHKeyPassword("") + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("ssh-key-password") + require.NotNil(t, flag, "ssh-key-password flag should be registered") + + // Verify that the parsed value is propagated correctly. 
+ v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--ssh-key-password", "s3cr3t"}) + require.NoError(t, err) + assert.Equal(t, "s3cr3t", opts.SSHKeyPassword) +} + +func TestStandardOptionsBuilder_WithIncludeSpaceliftAdminStacks(t *testing.T) { + builder := NewStandardOptionsBuilder().WithIncludeSpaceliftAdminStacks() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("include-spacelift-admin-stacks") + require.NotNil(t, flag, "include-spacelift-admin-stacks flag should be registered") + + // Verify that the parsed value is propagated correctly. + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--include-spacelift-admin-stacks"}) + require.NoError(t, err) + assert.True(t, opts.IncludeSpaceliftAdminStacks) +} + +func TestStandardOptionsBuilder_WithIncludeDependents(t *testing.T) { + builder := NewStandardOptionsBuilder().WithIncludeDependents() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("include-dependents") + require.NotNil(t, flag, "include-dependents flag should be registered") + + // Verify that the parsed value is propagated correctly. 
+ v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--include-dependents"}) + require.NoError(t, err) + assert.True(t, opts.IncludeDependents) +} + +func TestStandardOptionsBuilder_WithIncludeSettings(t *testing.T) { + builder := NewStandardOptionsBuilder().WithIncludeSettings() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("include-settings") + require.NotNil(t, flag, "include-settings flag should be registered") + + // Verify that the parsed value is propagated correctly. + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--include-settings"}) + require.NoError(t, err) + assert.True(t, opts.IncludeSettings) +} + +func TestStandardOptionsBuilder_WithUpload(t *testing.T) { + builder := NewStandardOptionsBuilder().WithUpload() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("upload") + require.NotNil(t, flag, "upload flag should be registered") + + // Verify that the parsed value is propagated correctly. + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--upload"}) + require.NoError(t, err) + assert.True(t, opts.Upload) +} + +func TestStandardOptionsBuilder_WithCloneTargetRef(t *testing.T) { + builder := NewStandardOptionsBuilder().WithCloneTargetRef() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("clone-target-ref") + require.NotNil(t, flag, "clone-target-ref flag should be registered") + + // Verify that the parsed value is propagated correctly. 
+ v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--clone-target-ref"}) + require.NoError(t, err) + assert.True(t, opts.CloneTargetRef) +} + +func TestStandardOptionsBuilder_WithExcludeLocked(t *testing.T) { + builder := NewStandardOptionsBuilder().WithExcludeLocked() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("exclude-locked") + require.NotNil(t, flag, "exclude-locked flag should be registered") + + // Verify that the parsed value is propagated correctly. + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--exclude-locked"}) + require.NoError(t, err) + assert.True(t, opts.ExcludeLocked) +} + +func TestStandardOptionsBuilder_WithComponents(t *testing.T) { + builder := NewStandardOptionsBuilder().WithComponents() + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("components") + require.NotNil(t, flag, "components flag should be registered") + + // Verify that the parsed value is propagated correctly. 
+ v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--components", "vpc", "--components", "rds"}) + require.NoError(t, err) + assert.Equal(t, []string{"vpc", "rds"}, opts.Components) +} + +func TestStandardOptionsBuilder_WithOutput(t *testing.T) { + validOutputs := []string{"json", "yaml", "table"} + + t.Run("flag registration with default", func(t *testing.T) { + builder := NewStandardOptionsBuilder().WithOutput(validOutputs, "json") + parser := builder.Build() + require.NotNil(t, parser) + + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + flag := cmd.Flags().Lookup("output") + require.NotNil(t, flag, "output flag should be registered") + assert.Equal(t, "json", flag.DefValue) + }) + + t.Run("default value", func(t *testing.T) { + builder := NewStandardOptionsBuilder().WithOutput(validOutputs, "json") + parser := builder.Build() + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{}) + require.NoError(t, err) + assert.Equal(t, "json", opts.Output) + }) + + t.Run("explicit value overrides default", func(t *testing.T) { + builder := NewStandardOptionsBuilder().WithOutput(validOutputs, "json") + parser := builder.Build() + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"--output", "yaml"}) + require.NoError(t, err) + assert.Equal(t, "yaml", opts.Output) + }) +} + +func TestStandardOptionsBuilder_WithPositionalArgs(t *testing.T) { + specs := []*PositionalArgSpec{ + { + Name: "component", + Description: "Component name", + Required: true, + TargetField: "Component", + }, + } + validator := cobra.ExactArgs(1) + + builder := NewStandardOptionsBuilder().WithPositionalArgs(specs, validator, "component") + parser := builder.Build() + require.NotNil(t, parser) 
+ + // Test that positional args are extracted correctly via Parse. + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"vpc"}) + require.NoError(t, err) + + // Verify the component was extracted from positional args. + assert.Equal(t, "vpc", opts.Component) +} diff --git a/pkg/flags/standard_test.go b/pkg/flags/standard_test.go index 1307e80046..793fdbad5c 100644 --- a/pkg/flags/standard_test.go +++ b/pkg/flags/standard_test.go @@ -1413,3 +1413,104 @@ func TestStandardFlagParser_PromptForOptionalValueFlags_MultipleFlagsOrder(t *te assert.Equal(t, "z-default", result.Flags["zebra"]) }) } + +// TestStandardFlagParser_Registry verifies that Registry() returns the underlying flag registry. +func TestStandardFlagParser_Registry(t *testing.T) { + parser := NewStandardFlagParser( + WithStringFlag("stack", "s", "", "Stack name"), + WithBoolFlag("dry-run", "", false, "Dry run mode"), + ) + + registry := parser.Registry() + require.NotNil(t, registry) + assert.Equal(t, 2, registry.Count()) + assert.True(t, registry.Has("stack")) + assert.True(t, registry.Has("dry-run")) +} + +// TestRegisterCompletionRecursive verifies that registerCompletionRecursive propagates +// completion functions to all descendant commands that have the named flag. +func TestRegisterCompletionRecursive(t *testing.T) { + completionFn := func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return []string{"dev", "staging", "prod"}, cobra.ShellCompDirectiveNoFileComp + } + + // Build a command tree where child1 and grandchild each own the --stack flag + // (persistent flags from root are NOT visible via child.Flags().Lookup, so each + // descendant that should receive the completion must register the flag itself). 
+ // + // root + // └── child1 (owns --stack as a persistent flag) + // └── grandchild (owns --stack as a local flag) + root := &cobra.Command{Use: "root"} + + child1 := &cobra.Command{Use: "child1"} + child1.PersistentFlags().String("stack", "", "Stack name") + root.AddCommand(child1) + + grandchild := &cobra.Command{Use: "grandchild"} + grandchild.Flags().String("stack", "", "Stack name") + child1.AddCommand(grandchild) + + // Call registerCompletionRecursive starting from root. + registerCompletionRecursive(root, "stack", completionFn) + + // Verify child1 has the completion function registered and it returns expected values. + gotFn, found := child1.GetFlagCompletionFunc("stack") + require.True(t, found, "child1 should have the completion function registered") + results, directive := gotFn(child1, nil, "") + assert.Equal(t, []string{"dev", "staging", "prod"}, results) + assert.Equal(t, cobra.ShellCompDirectiveNoFileComp, directive) + + // Verify grandchild also has the completion function registered. + gotFn2, found2 := grandchild.GetFlagCompletionFunc("stack") + require.True(t, found2, "grandchild should have the completion function registered") + results2, directive2 := gotFn2(grandchild, nil, "") + assert.Equal(t, []string{"dev", "staging", "prod"}, results2) + assert.Equal(t, cobra.ShellCompDirectiveNoFileComp, directive2) +} + +// TestStandardParser_Registry verifies that StandardParser.Registry() delegates correctly. +func TestStandardParser_Registry(t *testing.T) { + parser := NewStandardParser( + WithStringFlag("format", "f", "yaml", "Output format"), + ) + + registry := parser.Registry() + require.NotNil(t, registry) + assert.True(t, registry.Has("format")) +} + +// TestStandardParser_SetPositionalArgs verifies that StandardParser.SetPositionalArgs +// correctly configures positional argument handling. 
+func TestStandardParser_SetPositionalArgs(t *testing.T) { + parser := NewStandardParser( + WithStringFlag("stack", "s", "", "Stack name"), + ) + + specs := []*PositionalArgSpec{ + { + Name: "component", + Description: "Component name", + Required: true, + TargetField: "Component", + }, + } + validator := cobra.ExactArgs(1) + + // Should not panic. + assert.NotPanics(t, func() { + parser.SetPositionalArgs(specs, validator, "component") + }) + + // Verify that positional args are actually extracted during Parse. + cmd := &cobra.Command{Use: "test"} + parser.RegisterFlags(cmd) + + v := viper.New() + _ = parser.BindToViper(v) + + opts, err := parser.Parse(context.Background(), []string{"vpc"}) + require.NoError(t, err) + assert.Equal(t, "vpc", opts.Component, "positional arg should be mapped to Component field") +} diff --git a/pkg/flags/types_test.go b/pkg/flags/types_test.go index a6601340a4..7ebe2dc677 100644 --- a/pkg/flags/types_test.go +++ b/pkg/flags/types_test.go @@ -84,3 +84,55 @@ func TestIdentityFlag(t *testing.T) { assert.Equal(t, "__SELECT__", flag.GetNoOptDefVal()) assert.Equal(t, []string{"ATMOS_IDENTITY", "IDENTITY"}, flag.GetEnvVars()) } + +// TestGetInt verifies that GetInt correctly retrieves integer values from flags map. +func TestGetInt(t *testing.T) { + t.Run("returns integer value when key exists", func(t *testing.T) { + m := map[string]interface{}{"count": 42} + assert.Equal(t, 42, GetInt(m, "count")) + }) + + t.Run("returns zero when key does not exist", func(t *testing.T) { + m := map[string]interface{}{} + assert.Equal(t, 0, GetInt(m, "count")) + }) + + t.Run("returns zero when value is not int", func(t *testing.T) { + m := map[string]interface{}{"count": "not-an-int"} + assert.Equal(t, 0, GetInt(m, "count")) + }) +} + +// TestParsedConfig_GetIdentity verifies GetIdentity delegates to GetString for "identity". 
+func TestParsedConfig_GetIdentity(t *testing.T) { + t.Run("returns identity when set", func(t *testing.T) { + pc := &ParsedConfig{ + Flags: map[string]interface{}{"identity": "prod"}, + } + assert.Equal(t, "prod", pc.GetIdentity()) + }) + + t.Run("returns empty when not set", func(t *testing.T) { + pc := &ParsedConfig{ + Flags: map[string]interface{}{}, + } + assert.Equal(t, "", pc.GetIdentity()) + }) +} + +// TestParsedConfig_GetStack verifies GetStack delegates to GetString for "stack". +func TestParsedConfig_GetStack(t *testing.T) { + t.Run("returns stack when set", func(t *testing.T) { + pc := &ParsedConfig{ + Flags: map[string]interface{}{"stack": "ue2-dev"}, + } + assert.Equal(t, "ue2-dev", pc.GetStack()) + }) + + t.Run("returns empty when not set", func(t *testing.T) { + pc := &ParsedConfig{ + Flags: map[string]interface{}{}, + } + assert.Equal(t, "", pc.GetStack()) + }) +} diff --git a/pkg/function/env_test.go b/pkg/function/env_test.go index b524e5d6d1..d1d3a5620d 100644 --- a/pkg/function/env_test.go +++ b/pkg/function/env_test.go @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/cloudposse/atmos/pkg/schema" ) func TestEnvFunction_Execute_EdgeCases(t *testing.T) { @@ -145,3 +147,42 @@ func TestEnvFunction_Metadata(t *testing.T) { assert.Equal(t, TagEnv, fn.Name()) assert.Equal(t, PreMerge, fn.Phase()) } + +func TestLookupEnvFromContext_WithEnvSection(t *testing.T) { + envSection := map[string]any{ + "MY_VAR": "my_value", + "PORT": 8080, + "ENABLED": true, + } + + stackInfo := &schema.ConfigAndStacksInfo{ + ComponentEnvSection: envSection, + } + execCtx := &ExecutionContext{StackInfo: stackInfo} + + t.Run("key found returns string value", func(t *testing.T) { + val, found := lookupEnvFromContext(execCtx, "MY_VAR") + assert.True(t, found) + assert.Equal(t, "my_value", val) + }) + + t.Run("key found returns formatted non-string value", func(t *testing.T) { + val, found := 
lookupEnvFromContext(execCtx, "PORT") + assert.True(t, found) + assert.Equal(t, "8080", val) + }) + + t.Run("key not found returns empty", func(t *testing.T) { + val, found := lookupEnvFromContext(execCtx, "NONEXISTENT") + assert.False(t, found) + assert.Empty(t, val) + }) + + t.Run("nil env section returns false", func(t *testing.T) { + emptyInfo := &schema.ConfigAndStacksInfo{} + emptyCtx := &ExecutionContext{StackInfo: emptyInfo} + val, found := lookupEnvFromContext(emptyCtx, "MY_VAR") + assert.False(t, found) + assert.Empty(t, val) + }) +} diff --git a/pkg/function/literal_test.go b/pkg/function/literal_test.go new file mode 100644 index 0000000000..689d3f9c58 --- /dev/null +++ b/pkg/function/literal_test.go @@ -0,0 +1,70 @@ +package function + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewLiteralFunction(t *testing.T) { + fn := NewLiteralFunction() + require.NotNil(t, fn) + assert.Equal(t, TagLiteral, fn.Name()) + assert.Equal(t, PreMerge, fn.Phase()) +} + +func TestLiteralFunction_Execute(t *testing.T) { + tests := []struct { + name string + args string + expected string + }{ + { + name: "terraform template syntax", + args: "{{external.email}}", + expected: "{{external.email}}", + }, + { + name: "helm template syntax", + args: "{{ .Values.ingress.class }}", + expected: "{{ .Values.ingress.class }}", + }, + { + name: "bash variable syntax", + args: "${USER}", + expected: "${USER}", + }, + { + name: "leading and trailing whitespace trimmed", + args: " hello world ", + expected: "hello world", + }, + { + name: "empty string", + args: "", + expected: "", + }, + { + name: "multiline value", + args: "#!/bin/bash\necho \"Hello ${USER}\"", + expected: "#!/bin/bash\necho \"Hello ${USER}\"", + }, + { + name: "plain string preserved", + args: "simple-value", + expected: "simple-value", + }, + } + + fn := NewLiteralFunction() + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + result, err := fn.Execute(context.Background(), tt.args, nil) + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/http/client.go b/pkg/http/client.go index 72fb8a2251..8f6b1fde6d 100644 --- a/pkg/http/client.go +++ b/pkg/http/client.go @@ -7,8 +7,11 @@ import ( "errors" "fmt" "io" + "net" "net/http" + "net/url" "os" + "strings" "time" "github.com/spf13/viper" @@ -31,12 +34,18 @@ const ( // Client defines the interface for making HTTP requests. // This interface allows for easy mocking in tests. +// +// See the package documentation (doc.go) for guidance on option ordering +// when composing [ClientOption] values such as [WithTransport] and [WithGitHubToken]. type Client interface { // Do performs an HTTP request and returns the response. Do(req *http.Request) (*http.Response, error) } // ClientOption is a functional option for configuring the DefaultClient. +// Options are applied in the order they are passed to [NewDefaultClient]. +// For composition rules when mixing [WithTransport] and [WithGitHubToken], see the +// "Option ordering" section in the package documentation (doc.go). type ClientOption func(*DefaultClient) // WithTimeout sets the HTTP client timeout. @@ -50,6 +59,11 @@ func WithTimeout(timeout time.Duration) ClientOption { // WithGitHubToken sets the GitHub token for authenticated requests. // Wraps the existing transport instead of replacing it to allow composition with WithTransport. +// It also installs a CheckRedirect handler that strips the managed Authorization header +// when a redirect crosses to a different host, preventing token leakage. +// +// Triple-composition caveat: if a second WithTransport call follows this option, the +// earlier base transport is silently replaced. See WithTransport for details. 
func WithGitHubToken(token string) ClientOption { defer perf.Track(nil, "http.WithGitHubToken")() @@ -64,16 +78,78 @@ func WithGitHubToken(token string) ClientOption { Base: base, GitHubToken: token, } + + // Install a redirect policy that strips Authorization on cross-host redirects. + // The transport adds Authorization per-hop (only for allowed hosts), but the + // http.Client may also forward headers from the original request on redirects. + // This ensures no stale Authorization header leaks to an unexpected host. + if c.client.CheckRedirect == nil { + c.client.CheckRedirect = stripAuthOnCrossHostRedirect + } + } + } +} + +// stripAuthOnCrossHostRedirect removes the Authorization header from a redirect request +// when the target host (host:port) differs from the originating host. +// Using req.URL.Host (not Hostname) preserves port information, so same-host but +// different-port redirects are treated as cross-host (stricter than Go's port-agnostic default). +func stripAuthOnCrossHostRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errUtils.ErrRedirectLimitExceeded + } + if len(via) > 0 && req.URL.Host != via[0].URL.Host { + req.Header.Del("Authorization") + } + return nil +} + +// WithGitHubHostMatcher sets a custom host-matching predicate on the GitHub authenticated +// transport. The predicate receives the request hostname (without port) and returns true +// when the host should receive GitHub authentication headers. +// +// This is useful for GitHub Enterprise Server (GHES) deployments or custom GitHub proxies +// where the API is hosted on a non-standard domain. +// +// Example usage: +// +// client := NewDefaultClient( +// WithGitHubToken("token"), +// WithGitHubHostMatcher(func(host string) bool { +// return host == "github.mycorp.example.com" +// }), +// ) +// +// If this option is applied before WithGitHubToken, it has no effect because there is no +// transport to configure yet. Apply it after WithGitHubToken. 
+func WithGitHubHostMatcher(matcher func(string) bool) ClientOption { + defer perf.Track(nil, "http.WithGitHubHostMatcher")() + + return func(c *DefaultClient) { + if authTransport, ok := c.client.Transport.(*GitHubAuthenticatedTransport); ok { + authTransport.hostMatcher = matcher } } } // WithTransport sets a custom HTTP transport. +// If a GitHubAuthenticatedTransport has already been applied (e.g., via WithGitHubToken), +// the provided transport is set as its Base rather than replacing the auth wrapper. +// This preserves GitHub authentication regardless of option order. +// +// Triple-composition note: when a second WithTransport call follows WithGitHubToken, the +// earlier base transport (from the first WithTransport) is silently replaced by the new one. +// Example: WithTransport(t1), WithGitHubToken("x"), WithTransport(t2) +// Result: GitHubAuthenticatedTransport{Base: t2, Token: "x"}; t1 is discarded. func WithTransport(transport http.RoundTripper) ClientOption { defer perf.Track(nil, "http.WithTransport")() return func(c *DefaultClient) { - c.client.Transport = transport + if authTransport, ok := c.client.Transport.(*GitHubAuthenticatedTransport); ok { + authTransport.Base = transport + } else { + c.client.Transport = transport + } } } @@ -103,6 +179,55 @@ func NewDefaultClient(opts ...ClientOption) *DefaultClient { type GitHubAuthenticatedTransport struct { Base http.RoundTripper GitHubToken string + + // hostMatcher is an optional custom predicate that decides whether a given hostname + // should receive GitHub authentication headers. If nil, the default allowlist is used. + // See WithGitHubHostMatcher for details. + hostMatcher func(string) bool +} + +// normalizeHost canonicalizes a hostname for allowlist comparison: +// it lower-cases the string, strips a trailing dot (FQDN form), and removes +// default HTTP/HTTPS ports (:80 and :443) so that "api.github.com:443" is +// treated identically to "api.github.com". 
+// +// net.SplitHostPort is used to handle IPv6 literals safely (e.g., "[::1]:443" +// is split to "::1" and "443", and the brackets are dropped for comparison). +// Non-default ports (e.g., :8443) are preserved unchanged. +// +// Note: in the hot path, callers pass url.URL.Hostname() which already strips +// the port, making the port-stripping here a defence-in-depth measure. +func normalizeHost(host string) string { + host = strings.ToLower(host) + host = strings.TrimSuffix(host, ".") + // Strip default ports so that "api.github.com:443" matches "api.github.com". + // Also strip any trailing dot from the host part (handles "api.github.com.:443"). + if h, port, err := net.SplitHostPort(host); err == nil && (port == "443" || port == "80") { + host = strings.TrimSuffix(h, ".") + } + return host +} + +// isGitHubHost is the default host allowlist. +// It is also used as the fallback when GitHubAuthenticatedTransport.hostMatcher is nil. +// +// Precedence: WithGitHubHostMatcher (explicit custom predicate) takes full precedence over +// this default allowlist, including the GITHUB_API_URL lookup. If you need GHES support +// together with a custom matcher, include the GHES host in your custom predicate. +func isGitHubHost(host string) bool { + host = normalizeHost(host) + + // Respect GITHUB_API_URL for GitHub Enterprise Server (GHES) and similar deployments. + // When set, the hostname of GITHUB_API_URL is treated as an allowed GitHub API host. + //nolint:forbidigo // Direct env lookup required for GHES configuration. + if apiURL := os.Getenv("GITHUB_API_URL"); apiURL != "" { + parsed, err := url.ParseRequestURI(apiURL) + if err == nil && normalizeHost(parsed.Hostname()) == host { + return true + } + } + + return host == "api.github.com" || host == "raw.githubusercontent.com" || host == "uploads.github.com" } // RoundTrip implements http.RoundTripper interface. 
@@ -112,9 +237,27 @@ func (t *GitHubAuthenticatedTransport) RoundTrip(req *http.Request) (*http.Respo // Clone request to avoid mutating caller's request. reqClone := req.Clone(req.Context()) - host := reqClone.URL.Hostname() - if (host == "api.github.com" || host == "raw.githubusercontent.com") && t.GitHubToken != "" { - reqClone.Header.Set("Authorization", "Bearer "+t.GitHubToken) + // Normalize the hostname to ensure consistent matching regardless of case, + // trailing dots (FQDN form), or port remnants. + host := normalizeHost(reqClone.URL.Hostname()) + scheme := reqClone.URL.Scheme + + // Determine whether the host is allowed to receive authentication headers. + // WithGitHubHostMatcher (t.hostMatcher) takes full precedence; the default + // allowlist (isGitHubHost, including GITHUB_API_URL lookup) is used as fallback. + matcher := t.hostMatcher + if matcher == nil { + matcher = isGitHubHost + } + + // Only inject Authorization when ALL of the following are true: + // 1. The scheme is "https" (prevent token leakage over plain HTTP). + // 2. The host is in the allowed list. + // 3. The header is not already set (outermost transport wins on multi-layer composition). + if scheme == "https" && matcher(host) && t.GitHubToken != "" { + if reqClone.Header.Get("Authorization") == "" { + reqClone.Header.Set("Authorization", "Bearer "+t.GitHubToken) + } reqClone.Header.Set("User-Agent", userAgent) } @@ -139,11 +282,20 @@ func (t *GitHubAuthenticatedTransport) RoundTrip(req *http.Request) (*http.Respo // // The viper binding is configured in cmd/toolchain/toolchain.go for toolchain commands. // For non-toolchain commands, we fall back to direct environment variable lookup. -func GetGitHubTokenFromEnv() string { +// +// An optional *viper.Viper instance may be passed; when provided it is used instead of +// the global viper singleton. This is primarily useful in tests to avoid mutating +// shared global state. 
+func GetGitHubTokenFromEnv(v ...*viper.Viper) string { defer perf.Track(nil, "http.GetGitHubTokenFromEnv")() + viperInst := viper.GetViper() + if len(v) > 0 && v[0] != nil { + viperInst = v[0] + } + // First try viper (for toolchain commands with --github-token flag). - if token := viper.GetString("github-token"); token != "" { + if token := viperInst.GetString("github-token"); token != "" { return token } diff --git a/pkg/http/client_test.go b/pkg/http/client_test.go index 9afef26e4f..36997d866c 100644 --- a/pkg/http/client_test.go +++ b/pkg/http/client_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "reflect" "strings" "testing" "time" @@ -92,13 +93,13 @@ func TestGetGitHubTokenFromEnv(t *testing.T) { t.Setenv("ATMOS_GITHUB_TOKEN", tt.atmosToken) t.Setenv("GITHUB_TOKEN", tt.githubToken) - // Since GetGitHubTokenFromEnv() now uses global viper, we need to - // manually bind the environment variables for this test. - // In production, this is done by GlobalOptionsBuilder. - v := viper.GetViper() + // Use an isolated viper instance to avoid mutating the global singleton. + // This prevents BindEnv from leaking env-var mappings into subsequent tests. + // In production, GlobalOptionsBuilder binds these on the global viper instance. + v := viper.New() _ = v.BindEnv("github-token", "ATMOS_GITHUB_TOKEN", "GITHUB_TOKEN") - got := GetGitHubTokenFromEnv() + got := GetGitHubTokenFromEnv(v) assert.Equal(t, tt.want, got) }) } @@ -133,6 +134,40 @@ func TestGitHubAuthenticatedTransport_RoundTrip(t *testing.T) { expectAuth: false, expectUserAgent: false, }, + { + name: "uploads.github.com sets headers", + url: "https://uploads.github.com/repos/test/repo/releases/1/assets", + token: "test-token", + expectAuth: true, + expectUserAgent: true, + }, + { + // github.example.com looks like it has "github" in it but is NOT an allowed host. + // Authorization must NOT be leaked to arbitrary subdomains. 
+ name: "github.example.com does not set auth header", + url: "https://github.example.com/api", + token: "test-token", + expectAuth: false, + expectUserAgent: false, + }, + { + // example.github.com is a GitHub-owned subdomain but NOT in the explicit allowlist. + // Authorization must NOT be set for unlisted GitHub subdomains. + name: "example.github.com does not set auth header", + url: "https://example.github.com/api", + token: "test-token", + expectAuth: false, + expectUserAgent: false, + }, + { + // Plain HTTP (not HTTPS) api.github.com — Authorization MUST NOT be set. + // Sending tokens over unencrypted HTTP would leak credentials. + name: "http scheme api.github.com does not set auth header", + url: "http://api.github.com/repos/test/repo", + token: "test-token", + expectAuth: false, + expectUserAgent: false, + }, { name: "empty token does not set auth header", url: "https://api.github.com/repos/test/repo", @@ -463,3 +498,803 @@ type errorReader struct{} func (e *errorReader) Read(p []byte) (n int, err error) { return 0, fmt.Errorf("read error") } + +func TestWithTransport(t *testing.T) { + // Create a mock transport that records requests. + var capturedReq *http.Request + mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + }, nil + }) + + // Create client with custom transport. + client := NewDefaultClient(WithTransport(mockTransport)) + assert.NotNil(t, client) + + // Make a request to verify custom transport is used. 
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com/test", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotNil(t, capturedReq, "mock transport should have captured the request") + assert.Equal(t, "http://example.com/test", capturedReq.URL.String()) +} + +// TestGitHubAuthenticatedTransport_NilBase verifies that when Base transport is nil, +// the GitHubAuthenticatedTransport falls back to http.DefaultTransport. +func TestGitHubAuthenticatedTransport_NilBase(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + transport := &GitHubAuthenticatedTransport{ + Base: nil, // Explicitly set to nil - should fall back to http.DefaultTransport. + GitHubToken: "", + } + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := transport.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +// TestGetGitHubTokenFromEnv_ViperPrecedence verifies that the viper value takes +// precedence over environment variables. +func TestGetGitHubTokenFromEnv_ViperPrecedence(t *testing.T) { + t.Setenv("ATMOS_GITHUB_TOKEN", "env-token") + t.Setenv("GITHUB_TOKEN", "fallback-token") + + // Use an isolated viper instance to avoid mutating the global singleton. + // This prevents BindEnv from leaking env-var mappings into subsequent tests. + v := viper.New() + _ = v.BindEnv("github-token", "ATMOS_GITHUB_TOKEN", "GITHUB_TOKEN") + + // Override viper to simulate --github-token flag. 
+ v.Set("github-token", "viper-token") + + got := GetGitHubTokenFromEnv(v) + assert.Equal(t, "viper-token", got) +} + +// TestGetGitHubTokenFromEnv_NilViperFallsBackToOsEnv verifies that passing an explicit +// nil viper instance falls back to the global viper singleton (which has no token binding +// in this test context), and then falls through to the os.Getenv fallback path. +func TestGetGitHubTokenFromEnv_NilViperFallsBackToOsEnv(t *testing.T) { + t.Setenv("ATMOS_GITHUB_TOKEN", "nil-guard-token") + + // Passing nil must not panic — it must fall back to global viper, which in turn + // falls back to os.Getenv for ATMOS_GITHUB_TOKEN (since no BindEnv is active here). + assert.NotPanics(t, func() { + _ = GetGitHubTokenFromEnv(nil) + }) + + // The token is returned via the os.Getenv("ATMOS_GITHUB_TOKEN") fallback path, + // not via the global viper key (which is unbound in this test). + got := GetGitHubTokenFromEnv(nil) + assert.Equal(t, "nil-guard-token", got) +} + +// TestWithTransport_AfterWithGitHubToken verifies that WithTransport applied after +// WithGitHubToken does NOT drop the auth wrapper; the provided transport becomes the +// inner base of the GitHubAuthenticatedTransport. +func TestWithTransport_AfterWithGitHubToken(t *testing.T) { + var capturedReq *http.Request + mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + }, nil + }) + + // WithGitHubToken first, then WithTransport. + client := NewDefaultClient( + WithGitHubToken("secret-token"), + WithTransport(mockTransport), + ) + + // Verify the transport chain structure: outer is GitHubAuthenticatedTransport, + // inner (Base) is the mock roundTripperFunc. 
+ authTransport, ok := client.client.Transport.(*GitHubAuthenticatedTransport) + require.True(t, ok, "client transport should be *GitHubAuthenticatedTransport") + _, baseIsRoundTripper := authTransport.Base.(roundTripperFunc) + assert.True(t, baseIsRoundTripper, "Base should be a roundTripperFunc (the mockTransport)") + assert.Equal(t, "secret-token", authTransport.GitHubToken) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.github.com/repos/test/repo", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.NotNil(t, capturedReq, "mock transport should have been reached") + assert.Equal(t, "Bearer secret-token", capturedReq.Header.Get("Authorization"), + "Authorization header must be set when WithGitHubToken is applied after WithTransport") +} + +// TestWithGitHubToken_AfterWithTransport verifies that WithGitHubToken applied after +// WithTransport wraps the custom transport inside the auth layer. +func TestWithGitHubToken_AfterWithTransport(t *testing.T) { + var capturedReq *http.Request + mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + }, nil + }) + + // WithTransport first, then WithGitHubToken wraps it. + client := NewDefaultClient( + WithTransport(mockTransport), + WithGitHubToken("secret-token"), + ) + + // Verify the transport chain structure: outer is GitHubAuthenticatedTransport, + // inner (Base) is the mock roundTripperFunc. + authTransport, ok := client.client.Transport.(*GitHubAuthenticatedTransport) + require.True(t, ok, "client transport should be *GitHubAuthenticatedTransport") + + // Structural check: verify Base is the exact mockTransport instance (not just the type). 
+ // Function values are not == comparable in Go; use reflect.ValueOf().Pointer() to compare + // the underlying function pointer, which is stable and not affected by interface boxing. + baseTransport, baseIsRoundTripper := authTransport.Base.(roundTripperFunc) + require.True(t, baseIsRoundTripper, "Base should be a roundTripperFunc (the mockTransport)") + assert.Equal(t, reflect.ValueOf(http.RoundTripper(mockTransport)).Pointer(), + reflect.ValueOf(http.RoundTripper(baseTransport)).Pointer(), + "Base transport pointer must match the exact mockTransport instance") + assert.Equal(t, "secret-token", authTransport.GitHubToken) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.github.com/repos/test/repo", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.NotNil(t, capturedReq, "mock transport should have been reached") + assert.Equal(t, "Bearer secret-token", capturedReq.Header.Get("Authorization"), + "Authorization header must be set when WithGitHubToken is applied after WithTransport") +} + +// TestWithTransport_TripleComposition verifies that the last WithTransport call replaces +// the Base of any existing GitHubAuthenticatedTransport, not the auth wrapper itself. +// Applied as: WithTransport(t1) → WithGitHubToken("x") → WithTransport(t2) +// Result: GitHubAuthenticatedTransport{Base: t2, Token: "x"} (t1 is discarded by the second WithTransport). 
+func TestWithTransport_TripleComposition(t *testing.T) { + var t1Reached, t2Reached bool + + transport1 := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + t1Reached = true + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("t1"))}, nil + }) + + var capturedReq *http.Request + transport2 := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + t2Reached = true + capturedReq = req + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("t2"))}, nil + }) + + // WithTransport(t1), WithGitHubToken wraps t1, then WithTransport(t2) replaces Base with t2. + client := NewDefaultClient( + WithTransport(transport1), + WithGitHubToken("triple-token"), + WithTransport(transport2), + ) + + // The auth wrapper must still be present (not replaced by the second WithTransport). + authTransport, ok := client.client.Transport.(*GitHubAuthenticatedTransport) + require.True(t, ok, "client transport must still be *GitHubAuthenticatedTransport") + assert.Equal(t, "triple-token", authTransport.GitHubToken) + // The second WithTransport replaces Base with transport2 (a roundTripperFunc). + _, baseIsRoundTripper := authTransport.Base.(roundTripperFunc) + assert.True(t, baseIsRoundTripper, "Base should be transport2 (roundTripperFunc) after second WithTransport") + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.github.com/repos/test/repo", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Only transport2 must be reached; transport1 was replaced. 
+ assert.True(t, t2Reached, "transport2 should have been reached") + assert.False(t, t1Reached, "transport1 should NOT be reached (replaced by transport2)") + require.NotNil(t, capturedReq) + assert.Equal(t, "Bearer triple-token", capturedReq.Header.Get("Authorization"), + "Authorization header must be present after triple composition") +} + +// TestWithGitHubToken_MultipleCallsLastWins is a regression test for the multiple +// WithGitHubToken wrappers bug. When two WithGitHubToken calls are composed, the INNER +// (earlier-applied) transport's RoundTrip previously overwrote the OUTER (later-applied) +// transport's Authorization header, causing the wrong token to be sent. +// After the fix (only set Authorization if not already set), the outermost (last-applied) +// token must win. +func TestWithGitHubToken_MultipleCallsLastWins(t *testing.T) { + var capturedReq *http.Request + mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + }, nil + }) + + // Apply t1 first, t2 second — t2 is the outermost (last-applied) wrapper. + // t2's token must win. Before the fix, t1 (inner) would overwrite t2 (outer). + client := NewDefaultClient( + WithTransport(mockTransport), + WithGitHubToken("token-t1"), + WithGitHubToken("token-t2"), + ) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.github.com/repos/test/repo", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.NotNil(t, capturedReq, "mock transport should have been reached") + // The outermost (last-applied, t2) token must be in the Authorization header. + // Before the fix, the inner t1 would overwrite: Authorization: Bearer token-t1 (wrong). + // After the fix: Authorization: Bearer token-t2 (correct). 
+ assert.Equal(t, "Bearer token-t2", capturedReq.Header.Get("Authorization"), + "last-applied (outermost) token must win when multiple WithGitHubToken calls are composed") +} + +// TestGitHubAuthenticatedTransport_PresetAuthorizationNotClobbered verifies that a +// pre-existing Authorization header on the request is NOT overwritten by the transport. +// This tests the "only set if empty" guard at the transport level. +func TestGitHubAuthenticatedTransport_PresetAuthorizationNotClobbered(t *testing.T) { + var capturedReq *http.Request + mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + }, nil + }) + + client := NewDefaultClient( + WithTransport(mockTransport), + WithGitHubToken("injected-token"), + ) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.github.com/repos/test/repo", nil) + require.NoError(t, err) + // Pre-set a caller-supplied Authorization header — must survive the transport. + req.Header.Set("Authorization", "Bearer preset-token") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.NotNil(t, capturedReq, "mock transport must be reached") + assert.Equal(t, "Bearer preset-token", capturedReq.Header.Get("Authorization"), + "transport must not overwrite a pre-set Authorization header on the request") +} + +// TestGitHubAuthenticatedTransport_GHES verifies that a GitHub Enterprise Server host +// (specified via GITHUB_API_URL) receives authentication headers. +func TestGitHubAuthenticatedTransport_GHES(t *testing.T) { + ghesHost := "github.mycorp.example.com" + ghesURL := "https://" + ghesHost + + // Set GITHUB_API_URL so isGitHubHost recognises the GHES host. 
+	t.Setenv("GITHUB_API_URL", ghesURL)
+
+	var capturedReq *http.Request
+	mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
+		capturedReq = req
+		return &http.Response{
+			StatusCode: http.StatusOK,
+			Body:       io.NopCloser(strings.NewReader("OK")),
+		}, nil
+	})
+
+	transport := &GitHubAuthenticatedTransport{
+		Base:        mockTransport,
+		GitHubToken: "ghes-token",
+	}
+
+	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ghesURL+"/api/v3/repos/org/repo", nil)
+	require.NoError(t, err)
+
+	resp, err := transport.RoundTrip(req)
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	defer resp.Body.Close()
+
+	require.NotNil(t, capturedReq, "mock transport must be reached")
+	assert.Equal(t, "Bearer ghes-token", capturedReq.Header.Get("Authorization"),
+		"GHES host from GITHUB_API_URL must receive Authorization header")
+	assert.Equal(t, userAgent, capturedReq.Header.Get("User-Agent"))
+}
+
+// TestGitHubAuthenticatedTransport_GHES_NegativeSubdomain verifies that a host that
+// contains the GHES hostname as a substring (e.g. attacker.github.mycorp.example.com)
+// does NOT receive the Authorization header.
+func TestGitHubAuthenticatedTransport_GHES_NegativeSubdomain(t *testing.T) {
+	// Set GITHUB_API_URL for a specific GHES host.
+	t.Setenv("GITHUB_API_URL", "https://github.mycorp.example.com")
+
+	var capturedReq *http.Request
+	mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
+		capturedReq = req
+		return &http.Response{
+			StatusCode: http.StatusOK,
+			Body:       io.NopCloser(strings.NewReader("OK")),
+		}, nil
+	})
+
+	transport := &GitHubAuthenticatedTransport{
+		Base:        mockTransport,
+		GitHubToken: "ghes-token",
+	}
+
+	// This host is a subdomain of the GHES host — must NOT get auth headers.
+	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://attacker.github.mycorp.example.com/evil", nil)
+	require.NoError(t, err)
+
+	resp, err := transport.RoundTrip(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	require.NotNil(t, capturedReq)
+	assert.Empty(t, capturedReq.Header.Get("Authorization"),
+		"subdomain of GHES host must NOT receive Authorization header")
+}
+
+// TestWithGitHubHostMatcher verifies that WithGitHubHostMatcher allows configuring
+// a custom host predicate on the GitHubAuthenticatedTransport.
+func TestWithGitHubHostMatcher(t *testing.T) {
+	var capturedReq *http.Request
+	mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
+		capturedReq = req
+		return &http.Response{
+			StatusCode: http.StatusOK,
+			Body:       io.NopCloser(strings.NewReader("OK")),
+		}, nil
+	})
+
+	// Create a client with a custom host matcher that allows only our test host.
+	client := NewDefaultClient(
+		WithTransport(mockTransport),
+		WithGitHubToken("custom-token"),
+		WithGitHubHostMatcher(func(host string) bool {
+			return host == "custom-git.example.com"
+		}),
+	)
+
+	// Allowed custom host — should get auth headers.
+	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://custom-git.example.com/api/v1/repos", nil)
+	require.NoError(t, err)
+
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	require.NotNil(t, capturedReq)
+	assert.Equal(t, "Bearer custom-token", capturedReq.Header.Get("Authorization"),
+		"custom host matcher must allow the configured host")
+
+	// Default allowed host (api.github.com) — should NOT get auth with a custom matcher
+	// that only allows custom-git.example.com.
+ req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.github.com/repos", nil) + require.NoError(t, err) + + resp2, err := client.Do(req2) + require.NoError(t, err) + defer resp2.Body.Close() + + assert.Empty(t, capturedReq.Header.Get("Authorization"), + "custom host matcher must override the default allowlist") +} + +// TestWithGitHubHostMatcher_Precedence is a table-driven test that validates the +// three-level host-matcher precedence documented in pkg/http/doc.go: +// +// 1. [WithGitHubHostMatcher] — an explicit custom predicate always wins. +// 2. GITHUB_API_URL — when set and [WithGitHubHostMatcher] was NOT applied. +// 3. Built-in allowlist — api.github.com, raw.githubusercontent.com, uploads.github.com. +func TestWithGitHubHostMatcher_Precedence(t *testing.T) { + // Cannot use t.Parallel() here because subtests call t.Setenv which modifies + // the process-wide GITHUB_API_URL environment variable. + + cases := []struct { + name string + gitHubAPIURL string // GITHUB_API_URL env value ("" = not set) + customMatcher func(string) bool // nil = don't call WithGitHubHostMatcher + requestURL string // HTTPS URL to test + wantAuth bool // whether Authorization should be injected + }{ + // ── Level 3: built-in allowlist ────────────────────────────────────────── + { + name: "builtin_api_github_com", + requestURL: "https://api.github.com/repos", + wantAuth: true, + }, + { + name: "builtin_raw_githubusercontent_com", + requestURL: "https://raw.githubusercontent.com/owner/repo/main/file.go", + wantAuth: true, + }, + { + name: "builtin_uploads_github_com", + requestURL: "https://uploads.github.com/releases/assets", + wantAuth: true, + }, + { + name: "builtin_negative_example_com", + requestURL: "https://example.com/api", + wantAuth: false, + }, + { + name: "builtin_negative_github_example_com", + requestURL: "https://github.example.com/api", + wantAuth: false, + }, + // ── Level 2: GITHUB_API_URL overrides the default allowlist 
────────────── + { + name: "github_api_url_adds_ghes_host", + gitHubAPIURL: "https://github.mycorp.example.com", + requestURL: "https://github.mycorp.example.com/api/v3/repos", + wantAuth: true, + }, + { + name: "github_api_url_still_allows_builtin", + gitHubAPIURL: "https://github.mycorp.example.com", + requestURL: "https://api.github.com/repos", + wantAuth: true, + }, + { + name: "github_api_url_does_not_allow_unrelated_host", + gitHubAPIURL: "https://github.mycorp.example.com", + requestURL: "https://other.example.com/api", + wantAuth: false, + }, + // ── Level 1: WithGitHubHostMatcher overrides GITHUB_API_URL ────────────── + { + // Custom matcher for "custom-git.example.com" — GHES host set via env + // must NOT receive auth because the custom matcher doesn't include it. + name: "custom_matcher_overrides_github_api_url", + gitHubAPIURL: "https://ghes.mycorp.example.com", + customMatcher: func(host string) bool { + return host == "custom-git.example.com" + }, + requestURL: "https://ghes.mycorp.example.com/api/v3/repos", + wantAuth: false, // custom matcher wins — GHES host is excluded + }, + { + // Custom matcher — host included in custom predicate gets auth. + name: "custom_matcher_allows_its_own_host", + gitHubAPIURL: "https://ghes.mycorp.example.com", + customMatcher: func(host string) bool { + return host == "custom-git.example.com" + }, + requestURL: "https://custom-git.example.com/api", + wantAuth: true, + }, + { + // Custom matcher replaces the default allowlist too. + name: "custom_matcher_overrides_builtin_allowlist", + customMatcher: func(host string) bool { + return host == "custom-git.example.com" + }, + requestURL: "https://api.github.com/repos", + wantAuth: false, // api.github.com is NOT in the custom matcher + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // No t.Parallel() — subtests may call t.Setenv. 
+
+			var capturedReq *http.Request
+			mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
+				capturedReq = req.Clone(req.Context())
+				return &http.Response{
+					StatusCode: http.StatusOK,
+					Body:       io.NopCloser(strings.NewReader("OK")),
+				}, nil
+			})
+
+			if tc.gitHubAPIURL != "" {
+				t.Setenv("GITHUB_API_URL", tc.gitHubAPIURL)
+			}
+
+			opts := []ClientOption{
+				WithTransport(mockTransport),
+				WithGitHubToken("precedence-test-token"),
+			}
+			if tc.customMatcher != nil {
+				opts = append(opts, WithGitHubHostMatcher(tc.customMatcher))
+			}
+			client := NewDefaultClient(opts...)
+
+			req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, tc.requestURL, nil)
+			require.NoError(t, err)
+
+			resp, err := client.Do(req)
+			require.NoError(t, err)
+			defer resp.Body.Close()
+
+			require.NotNil(t, capturedReq, "transport must have been invoked")
+			got := capturedReq.Header.Get("Authorization")
+			if tc.wantAuth {
+				assert.Equal(t, "Bearer precedence-test-token", got,
+					"[%s] expected Authorization to be set on %s", tc.name, tc.requestURL)
+			} else {
+				assert.Empty(t, got,
+					"[%s] expected Authorization to be absent on %s", tc.name, tc.requestURL)
+			}
+		})
+	}
+}
+
+// TestGet_LargeErrorBodyTruncation verifies that when an HTTP server returns a non-2xx
+// response with a body larger than maxErrorBodySize, the error message:
+// - contains the "[truncated]" marker
+// - contains the content-type from the response
+// - wraps ErrHTTPRequestFailed
+func TestGet_LargeErrorBodyTruncation(t *testing.T) {
+	// Build a response body that is one byte larger than the limit.
+	// A single repeated character suffices — only the body SIZE matters for truncation.
+ oversizeBody := strings.Repeat("x", maxErrorBodySize+1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadGateway) + _, _ = fmt.Fprint(w, oversizeBody) + })) + defer server.Close() + + client := NewDefaultClient(WithTimeout(10 * time.Second)) + _, err := Get(context.Background(), server.URL, client) + + require.Error(t, err, "a non-2xx response should return an error") + assert.True(t, errors.Is(err, errUtils.ErrHTTPRequestFailed), "error must wrap ErrHTTPRequestFailed") + assert.Contains(t, err.Error(), "[truncated]", "error message must contain truncation marker") + assert.Contains(t, err.Error(), "text/plain", "error message must contain content-type from response") + assert.Contains(t, err.Error(), "returned status 502", "error message must contain the status code") +} + +// TestGet_ErrorBodyContentType verifies that the content-type header value from a +// non-2xx response is correctly reported in the error message when the body fits within +// the truncation limit. +func TestGet_ErrorBodyContentType(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + _, _ = fmt.Fprint(w, `{"error":"forbidden"}`) + })) + defer server.Close() + + client := NewDefaultClient(WithTimeout(10 * time.Second)) + _, err := Get(context.Background(), server.URL, client) + + require.Error(t, err) + assert.Contains(t, err.Error(), "application/json", "error must contain the Content-Type header value") + assert.NotContains(t, err.Error(), "[truncated]", "short body must not be truncated") +} + +// TestIsGitHubHost_DefaultAllowlist verifies the default isGitHubHost allowlist +// without any GITHUB_API_URL override. 
+func TestIsGitHubHost_DefaultAllowlist(t *testing.T) { + // Ensure GITHUB_API_URL is not set so we test default behavior. + t.Setenv("GITHUB_API_URL", "") + + assert.True(t, isGitHubHost("api.github.com")) + assert.True(t, isGitHubHost("raw.githubusercontent.com")) + assert.True(t, isGitHubHost("uploads.github.com"), "uploads.github.com must be in the default allowlist") + + assert.False(t, isGitHubHost("github.com")) + assert.False(t, isGitHubHost("example.com")) + assert.False(t, isGitHubHost("github.example.com")) + assert.False(t, isGitHubHost("example.github.com")) + assert.False(t, isGitHubHost("")) +} + +// TestIsGitHubHost_GITHUB_API_URL verifies that GITHUB_API_URL adds a GHES host. +func TestIsGitHubHost_GITHUB_API_URL(t *testing.T) { + t.Setenv("GITHUB_API_URL", "https://github.mycorp.example.com") + + assert.True(t, isGitHubHost("github.mycorp.example.com"), "GITHUB_API_URL hostname should be allowed") + assert.True(t, isGitHubHost("api.github.com"), "default allowlist still applies") + + // Only exact hostname match, not substring. + assert.False(t, isGitHubHost("evil.github.mycorp.example.com")) + assert.False(t, isGitHubHost("github.mycorp.example.com.evil.tld")) +} + +// TestIsGitHubHost_InvalidGITHUB_API_URL verifies that an unparsable GITHUB_API_URL +// does not panic and falls back to the default allowlist. +func TestIsGitHubHost_InvalidGITHUB_API_URL(t *testing.T) { + t.Setenv("GITHUB_API_URL", "://not-a-valid-url") + + // Default allowlist still applies even when GITHUB_API_URL is invalid. + assert.True(t, isGitHubHost("api.github.com")) + assert.False(t, isGitHubHost("example.com")) +} + +// TestNormalizeHost verifies that normalizeHost canonicalises hostnames correctly. +func TestNormalizeHost(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"api.github.com", "api.github.com"}, + {"API.GITHUB.COM", "api.github.com"}, + {"Api.GitHub.Com", "api.github.com"}, + // Trailing dot (FQDN form). 
+		{"api.github.com.", "api.github.com"},
+		// Upper-case + trailing dot.
+		{"API.GITHUB.COM.", "api.github.com"},
+		{"", ""},
+		// Default port 443 should be stripped.
+		{"api.github.com:443", "api.github.com"},
+		// Default port 80 should be stripped.
+		{"api.github.com:80", "api.github.com"},
+		// Non-default port should be preserved.
+		{"api.github.com:8443", "api.github.com:8443"},
+		// Port 443 + upper-case: both normalised.
+		{"API.GITHUB.COM:443", "api.github.com"},
+		// Port 443 + trailing dot: trailing dot stripped then port stripped.
+		{"api.github.com.:443", "api.github.com"},
+		// IPv6 with default port: brackets are stripped by net.SplitHostPort.
+		{"[::1]:443", "::1"},
+		// IPv6 with non-default port: returned unchanged (port is not 443/80, so the
+		// SplitHostPort result is discarded and the brackets are preserved).
+		{"[::1]:8080", "[::1]:8080"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.input, func(t *testing.T) {
+			assert.Equal(t, tt.want, normalizeHost(tt.input))
+		})
+	}
+}
+
+// TestIsGitHubHost_CaseAndTrailingDot verifies that isGitHubHost tolerates case
+// variations and trailing dots in the host parameter.
+func TestIsGitHubHost_CaseAndTrailingDot(t *testing.T) {
+	t.Setenv("GITHUB_API_URL", "")
+
+	positives := []string{
+		"API.GITHUB.COM",
+		"api.github.com.",
+		"API.GITHUB.COM.",
+		"Raw.GitHubUserContent.com",
+		"UPLOADS.GITHUB.COM",
+		// Port variants: default port should be stripped before matching.
+		"api.github.com:443",
+		"API.GITHUB.COM:443",
+		"uploads.github.com:443",
+		"raw.githubusercontent.com:80",
+	}
+	for _, h := range positives {
+		assert.True(t, isGitHubHost(h), "expected %q to be allowed", h)
+	}
+
+	negatives := []string{
+		"GITHUB.EXAMPLE.COM",
+		"EXAMPLE.GITHUB.COM",
+		"github.com",
+		// Port variants on disallowed hosts should still be denied.
+ "github.example.com:443", + "example.github.com:443", + } + for _, h := range negatives { + assert.False(t, isGitHubHost(h), "expected %q to be denied", h) + } +} + +// TestGitHubAuthenticatedTransport_CrossHostRedirect verifies that Authorization is +// NOT forwarded when the http.Client follows a redirect to a different host. +// The transport only adds auth per-hop for allowed hosts; this test also verifies +// that the CheckRedirect installed by WithGitHubToken strips any stale Authorization. +func TestGitHubAuthenticatedTransport_CrossHostRedirect(t *testing.T) { + // Target server — asserts Authorization is NOT present. + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("Authorization"), + "Authorization must not be forwarded to the redirect target host") + w.WriteHeader(http.StatusOK) + })) + defer target.Close() + + // Origin server — redirects to target (different host). + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, target.URL+"/landed", http.StatusFound) + })) + defer origin.Close() + + // Build a client with a GitHub token. The origin and target are both plain HTTP + // httptest servers, so Authorization will NOT be added by the transport anyway + // (HTTPS-only rule). But we still verify that CheckRedirect is wired up and + // does not add the header on the redirect leg. + client := NewDefaultClient( + WithGitHubToken("test-redir-token"), + ) + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, origin.URL+"/start", nil) + require.NoError(t, err) + + // Manually add an Authorization header to simulate a caller that pre-set it. 
+ req.Header.Set("Authorization", "Bearer manual-token") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +// TestGitHubAuthenticatedTransport_AllRedirectStatusCodes verifies that Authorization +// is stripped when the http.Client follows any of the standard redirect status codes +// (301 MovedPermanently, 302 Found, 303 SeeOther, 307 TemporaryRedirect, 308 PermanentRedirect) +// to a different host. +func TestGitHubAuthenticatedTransport_AllRedirectStatusCodes(t *testing.T) { + t.Parallel() + + redirectCases := []struct { + code int + name string + }{ + {http.StatusMovedPermanently, "301_MovedPermanently"}, + {http.StatusFound, "302_Found"}, + {http.StatusSeeOther, "303_SeeOther"}, + {http.StatusTemporaryRedirect, "307_TemporaryRedirect"}, + {http.StatusPermanentRedirect, "308_PermanentRedirect"}, + } + + for _, rc := range redirectCases { + t.Run(rc.name, func(t *testing.T) { + t.Parallel() + + // Target: a different host that must NOT receive Authorization. + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("Authorization"), + "Authorization must not be forwarded after %d redirect", rc.code) + w.WriteHeader(http.StatusOK) + })) + defer target.Close() + + // 307 and 308 require the same method and body so we use a GET to keep it simple. + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, target.URL+"/landed", rc.code) + })) + defer origin.Close() + + client := NewDefaultClient( + WithGitHubToken("redir-test-token"), + ) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, origin.URL+"/start", nil) + require.NoError(t, err) + // Pre-set Authorization to simulate a caller that added it manually. 
+ req.Header.Set("Authorization", "Bearer manual-token") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode, + "final response after %d redirect must be 200 OK", rc.code) + }) + } +} diff --git a/pkg/http/doc.go b/pkg/http/doc.go new file mode 100644 index 0000000000..34ee222f40 --- /dev/null +++ b/pkg/http/doc.go @@ -0,0 +1,110 @@ +// Package http provides a configurable HTTP client with GitHub authentication support. +// +// # Creating a client +// +// Use [NewDefaultClient] with functional options to configure the client: +// +// client := http.NewDefaultClient( +// http.WithTimeout(30 * time.Second), +// http.WithGitHubToken("mytoken"), +// ) +// +// # Option ordering +// +// The order in which options are applied matters when composing transport layers. +// Understanding the ordering rules prevents subtle authentication bugs. +// +// ## WithGitHubToken and WithTransport +// +// [WithGitHubToken] wraps the current transport in a [GitHubAuthenticatedTransport]. +// [WithTransport] sets a new base transport. When applied after [WithGitHubToken], +// [WithTransport] replaces the base (inner) transport while preserving the auth wrapper: +// +// // Recommended: set transport first, then add authentication. +// client := http.NewDefaultClient( +// http.WithTransport(myCustomTransport), // base transport +// http.WithGitHubToken("token"), // wraps myCustomTransport +// ) +// +// // Also valid: apply in reverse order — WithTransport after WithGitHubToken +// // replaces the base of the existing GitHubAuthenticatedTransport. 
+// client := http.NewDefaultClient(
+// http.WithGitHubToken("token"), // wraps http.DefaultTransport
+// http.WithTransport(myCustomTransport), // updates base to myCustomTransport
+// )
+//
+// Triple-composition note: a second [WithTransport] applied after the sequence
+// [WithTransport] then [WithGitHubToken] silently discards the first base transport:
+//
+// // t1 is discarded; result is GitHubAuthenticatedTransport{Base: t2, Token: "x"}
+// client := http.NewDefaultClient(
+// http.WithTransport(t1),
+// http.WithGitHubToken("x"),
+// http.WithTransport(t2), // t1 is gone
+// )
+//
+// # WithGitHubHostMatcher
+//
+// [WithGitHubHostMatcher] must be applied AFTER [WithGitHubToken], because the host
+// matcher is stored on the [GitHubAuthenticatedTransport] which is only created by
+// [WithGitHubToken]. Applying it before [WithGitHubToken] has no effect.
+//
+// // Correct: token first, then custom host matcher.
+// client := http.NewDefaultClient(
+// http.WithGitHubToken("mytoken"),
+// http.WithGitHubHostMatcher(func(host string) bool {
+// return host == "github.mycorp.example.com" // GHES
+// }),
+// )
+//
+// // Incorrect: matcher applied before token — has no effect.
+// client := http.NewDefaultClient(
+// http.WithGitHubHostMatcher(...), // no-op: no transport yet
+// http.WithGitHubToken("mytoken"),
+// )
+//
+// # GitHub Enterprise Server (GHES)
+//
+// For GHES deployments set the GITHUB_API_URL environment variable to the GHES API
+// base URL. The default host matcher ([isGitHubHost]) reads this variable and treats
+// the configured hostname as an additional allowed host:
+//
+// GITHUB_API_URL=https://github.mycorp.example.com
+//
+// Alternatively, use [WithGitHubHostMatcher] for programmatic control.
+//
+// # Host-matcher precedence
+//
+// The host predicate used to decide whether to inject Authorization follows this
+// precedence order (highest to lowest):
+//
+// 1. [WithGitHubHostMatcher] — an explicit custom predicate always wins. 
+// 2. GITHUB_API_URL — when set and [WithGitHubHostMatcher] was NOT applied, +// the GHES hostname from the environment variable is added to the allowlist. +// 3. Built-in allowlist — api.github.com, raw.githubusercontent.com, uploads.github.com. +// +// If you need GHES support together with a custom matcher, include the GHES host +// in your custom predicate; [WithGitHubHostMatcher] bypasses the GITHUB_API_URL lookup: +// +// ghesHost := "github.mycorp.example.com" +// client := http.NewDefaultClient( +// http.WithGitHubToken("mytoken"), +// http.WithGitHubHostMatcher(func(host string) bool { +// return host == "api.github.com" || host == ghesHost +// }), +// ) +// +// # Security notes +// +// Authorization headers are only injected when ALL of the following are true: +// - The request URL scheme is "https". +// - The request hostname matches the host predicate. +// - The Authorization header is not already set on the request. +// +// Cross-host redirects: [WithGitHubToken] installs a CheckRedirect handler that +// strips the Authorization header from redirect requests that target a different +// host:port, preventing token leakage via open redirects. +// +// This prevents accidental token leakage over unencrypted HTTP and ensures +// that caller-supplied Authorization headers are never overwritten. +package http diff --git a/pkg/utils/export_test.go b/pkg/utils/export_test.go new file mode 100644 index 0000000000..1d50a93f6f --- /dev/null +++ b/pkg/utils/export_test.go @@ -0,0 +1,15 @@ +package utils + +// ResetGlobMatchesCache clears the glob matches sync map. +// This is exported only for testing to avoid data races from direct struct assignment. +func ResetGlobMatchesCache() { + getGlobMatchesSyncMap.Clear() +} + +// ResetPathMatchCache clears the path match cache. +// This is exported only for testing to ensure consistent state between tests. 
+func ResetPathMatchCache() { + pathMatchCacheMu.Lock() + pathMatchCache = make(map[pathMatchKey]bool) + pathMatchCacheMu.Unlock() +} diff --git a/pkg/utils/glob_utils.go b/pkg/utils/glob_utils.go index f0b0470f78..edb1a8d8fd 100644 --- a/pkg/utils/glob_utils.go +++ b/pkg/utils/glob_utils.go @@ -3,7 +3,6 @@ package utils import ( "os" "path/filepath" - "strings" "sync" "github.com/bmatcuk/doublestar/v4" @@ -32,15 +31,38 @@ var ( // GetGlobMatches tries to read and return the Glob matches content from the sync map if it exists in the map, // otherwise it finds and returns all files matching the pattern, stores the files in the map and returns the files. +// +// Note: unlike pkg/filesystem.GetGlobMatches, this function returns an error when no files match the pattern +// (consistent with its use as an import-path resolver). The returned slice may be nil when an error is returned. +// See pkg/filesystem.GetGlobMatches for the variant that returns ([]string{}, nil) instead of an error +// when no files match. +// +// Caching contract: only non-empty result sets are cached. The cache stores []string directly (not a +// comma-joined string) so that paths containing commas are preserved correctly on cache hits. +// Cached slices are cloned before being returned, so callers may safely mutate the returned slice without +// affecting the cache. func GetGlobMatches(pattern string) ([]string, error) { defer perf.Track(nil, "utils.GetGlobMatches")() + // Normalize pattern before cache lookup so that Windows backslash paths and + // forward-slash paths resolve to the same cache key (mirrors pkg/filesystem behavior). 
+ pattern = filepath.ToSlash(pattern) + existingMatches, found := getGlobMatchesSyncMap.Load(pattern) if found && existingMatches != nil { - return strings.Split(existingMatches.(string), ","), nil + // Cache stores []string directly to avoid the comma-splitting bug: + // paths containing commas would be split incorrectly if stored as a + // comma-joined string and then re-split on read. + if cached, ok := existingMatches.([]string); ok { + // Return a clone so callers cannot mutate the cached slice. + result := make([]string, len(cached)) + copy(result, cached) + return result, nil + } + // Unexpected cache type: invalidate and recompute. + getGlobMatchesSyncMap.Delete(pattern) } - pattern = filepath.ToSlash(pattern) base, cleanPattern := doublestar.SplitPattern(pattern) f := os.DirFS(base) @@ -62,7 +84,17 @@ func GetGlobMatches(pattern string) ([]string, error) { fullMatches = append(fullMatches, filepath.Join(filepath.FromSlash(base), match)) } - getGlobMatchesSyncMap.Store(pattern, strings.Join(fullMatches, ",")) + // Only cache non-empty results. Empty results are not cached because + // pkg/utils.GetGlobMatches treats "no matches" as an error, so there is + // nothing useful to cache (the error is re-computed on every call). + // Storing []string directly avoids the comma-splitting bug: paths that + // contain commas would be mangled if stored/read as a joined string. + if len(fullMatches) > 0 { + // Store a clone so callers cannot mutate cached data. + cached := make([]string, len(fullMatches)) + copy(cached, fullMatches) + getGlobMatchesSyncMap.Store(pattern, cached) + } return fullMatches, nil } diff --git a/pkg/utils/glob_utils_test.go b/pkg/utils/glob_utils_test.go index c4e56822ef..59958d3840 100644 --- a/pkg/utils/glob_utils_test.go +++ b/pkg/utils/glob_utils_test.go @@ -217,6 +217,9 @@ func TestPathMatch_AtmosStackPatterns(t *testing.T) { // TestPathMatch_ConsistentResults tests that multiple calls with same inputs return consistent results. 
// This indirectly validates that caching doesn't break behavior. func TestPathMatch_ConsistentResults(t *testing.T) { + ResetPathMatchCache() + t.Cleanup(ResetPathMatchCache) + pattern := "stacks/**/*.yaml" path := "stacks/catalog/vpc.yaml" @@ -346,7 +349,10 @@ func TestGetGlobMatches_Basic(t *testing.T) { matches, err := GetGlobMatches(pattern) require.NoError(t, err) - assert.NotNil(t, matches) + // pkg/utils.GetGlobMatches returns non-nil only when matches are found (unlike + // pkg/filesystem.GetGlobMatches which always returns a non-nil slice). For a "*.go" + // pattern in a Go package directory there must be at least one match. + assert.NotEmpty(t, matches) // We can't assert exact matches since it depends on the directory contents. // But we can verify the function completes without error. } @@ -445,6 +451,9 @@ func TestGetGlobMatches_WindowsAbsolutePath(t *testing.T) { // Before the fix, pattern="a|b" + name="c" and pattern="a" + name="b|c" would collide // because both produced cache key "a|b|c" when using string concatenation. func TestPathMatch_PipeCharacterNoCollision(t *testing.T) { + ResetPathMatchCache() + t.Cleanup(ResetPathMatchCache) + tests := []struct { name string pattern string @@ -512,3 +521,86 @@ func TestPathMatch_PipeCharacterNoCollision(t *testing.T) { assert.Equal(t, match1, match2, "Both should have same result (false)") } } + +// TestGetGlobMatches_EmptyResultCachingBug is a regression test for the phantom-path bug. +// When doublestar.Glob returns a non-nil but empty slice, the original code stored +// strings.Join([]string{}, ",") == "" in the cache. A subsequent call (cache hit) would +// return strings.Split("", ",") == []string{""} — a single empty-string "phantom path". +// The fix: only cache non-empty result sets; return and guard against empty cached entries. 
+//
+// Note: this test relies on doublestar.Glob returning nil (not []string{}) for no matches —
+// documented behavior per the doublestar library: "returns nil, nil when no matches are found".
+// If that library contract changes, GetGlobMatches would return (nil, nil) rather than an error, and this test would need updating.
+func TestGetGlobMatches_EmptyResultCachingBug(t *testing.T) {
+ // Reset cache before and after to avoid inter-test pollution.
+ ResetGlobMatchesCache()
+ t.Cleanup(ResetGlobMatchesCache)
+
+ // Use a pattern that legitimately matches no files in the tmp directory.
+ tmpDir := t.TempDir()
+ pattern := filepath.Join(tmpDir, "no-such-file-*.xyz")
+
+ // First call — no matches, should return an error (pkg/utils contract).
+ result1, err1 := GetGlobMatches(pattern)
+ assert.Error(t, err1, "first call with no matches should return an error")
+ assert.Nil(t, result1)
+
+ // Second call — should not return a phantom empty-string entry from a cached "".
+ // Expected: same error as first call (no-op cache miss — nothing was cached for empty results).
+ result2, err2 := GetGlobMatches(pattern)
+ assert.Error(t, err2, "second call (potential cache hit) should still return an error")
+ // Must return nil, not []string{""} (the phantom path from the original bug).
+ assert.Nil(t, result2, "second call must not return a phantom empty-string path from cache")
+}
+
+// TestGetGlobMatches_CommaSafeCache verifies that filenames containing commas are
+// preserved correctly through a cache round-trip. Before the fix, the cache stored
+// paths as a comma-joined string; a path like "stack,env.yaml" would be split into
+// ["stack", "env.yaml"] on a cache hit, producing two wrong entries.
+func TestGetGlobMatches_CommaSafeCache(t *testing.T) {
+ ResetGlobMatchesCache()
+ t.Cleanup(ResetGlobMatchesCache)
+
+ tmpDir := t.TempDir()
+
+ // Create a file whose name contains a comma. 
+ commaFile := filepath.Join(tmpDir, "stack,env.yaml") + require.NoError(t, os.WriteFile(commaFile, []byte(""), 0o644)) + + pattern := filepath.Join(tmpDir, "*.yaml") + + // First call (cache miss): verify we get the correct single result. + first, err := GetGlobMatches(pattern) + require.NoError(t, err) + require.Len(t, first, 1, "expected exactly one file matching *.yaml") + assert.True(t, strings.HasSuffix(filepath.ToSlash(first[0]), "stack,env.yaml"), + "returned path must contain the full filename with comma, got: %s", first[0]) + + // Second call (cache hit): verify the comma-containing filename survived round-trip. + second, err := GetGlobMatches(pattern) + require.NoError(t, err) + require.Len(t, second, 1, "cache hit must return exactly one entry, not split on comma") + assert.Equal(t, first[0], second[0], "cached filename must be identical to original") +} + +// TestGetGlobMatches_NonExistentBaseDirError verifies that a pattern whose base directory +// does not exist returns an error regardless of doublestar.Glob's nil-vs-empty semantics. +// This is the "library-contract-independent" companion to TestGetGlobMatches_EmptyResultCachingBug: +// it proves the error path without relying on doublestar returning nil for no matches. +func TestGetGlobMatches_NonExistentBaseDirError(t *testing.T) { + ResetGlobMatchesCache() + t.Cleanup(ResetGlobMatchesCache) + + // Use a base dir that cannot exist — a path inside a temp dir that was never created. + pattern := filepath.Join(t.TempDir(), "subdir-that-does-not-exist", "*.yaml") + + // First call: the base dir doesn't exist; doublestar.Glob will fail to open it. + result1, err1 := GetGlobMatches(pattern) + assert.Error(t, err1, "pattern with non-existent base dir must return an error") + assert.Nil(t, result1, "error path must not return a non-nil slice") + + // Second call: nothing was cached (empty results are not stored), so same error. 
+ result2, err2 := GetGlobMatches(pattern) + assert.Error(t, err2, "second call must also return an error — nothing was cached") + assert.Nil(t, result2, "second call must also return nil slice") +}