diff --git a/internal/cas/benchmark_test.go b/internal/cas/benchmark_test.go index 681c8299d6..550558eec9 100644 --- a/internal/cas/benchmark_test.go +++ b/internal/cas/benchmark_test.go @@ -12,11 +12,11 @@ import ( "github.com/gruntwork-io/terragrunt/internal/cas" "github.com/gruntwork-io/terragrunt/internal/git" "github.com/gruntwork-io/terragrunt/test/helpers/logger" + "github.com/stretchr/testify/require" ) func BenchmarkClone(b *testing.B) { - // Use a small, public repository for consistent results - repo := "https://github.com/gruntwork-io/terragrunt.git" + repoURL := startBenchServer(b) l := logger.CreateLogger() @@ -31,20 +31,15 @@ func BenchmarkClone(b *testing.B) { storePath := filepath.Join(tempDir, "store", strconv.Itoa(i)) targetPath := filepath.Join(tempDir, "repo", strconv.Itoa(i)) - c, err := cas.New(cas.Options{ - StorePath: storePath, - }) - if err != nil { - b.Fatal(err) - } + c, err := cas.New(cas.WithStorePath(storePath)) + require.NoError(b, err) b.StartTimer() - if err := c.Clone(b.Context(), l, &cas.CloneOptions{ - Dir: targetPath, - }, repo); err != nil { - b.Fatal(err) - } + require.NoError(b, c.Clone(b.Context(), l, &cas.CloneOptions{ + Dir: targetPath, + Depth: -1, + }, repoURL)) } }) @@ -53,18 +48,13 @@ func BenchmarkClone(b *testing.B) { storePath := filepath.Join(tempDir, "store") // First clone to populate store - c, err := cas.New(cas.Options{ - StorePath: storePath, - }) - if err != nil { - b.Fatal(err) - } + c, err := cas.New(cas.WithStorePath(storePath)) + require.NoError(b, err) - if err := c.Clone(b.Context(), l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "initial"), - }, repo); err != nil { - b.Fatal(err) - } + require.NoError(b, c.Clone(b.Context(), l, &cas.CloneOptions{ + Dir: filepath.Join(tempDir, "initial"), + Depth: -1, + }, repoURL)) b.ResetTimer() @@ -73,20 +63,15 @@ func BenchmarkClone(b *testing.B) { targetPath := filepath.Join(tempDir, "repo", strconv.Itoa(i)) - c, err := cas.New(cas.Options{ - StorePath: 
storePath, - }) - if err != nil { - b.Fatal(err) - } + c, err := cas.New(cas.WithStorePath(storePath)) + require.NoError(b, err) b.StartTimer() - if err := c.Clone(b.Context(), l, &cas.CloneOptions{ - Dir: targetPath, - }, repo); err != nil { - b.Fatal(err) - } + require.NoError(b, c.Clone(b.Context(), l, &cas.CloneOptions{ + Dir: targetPath, + Depth: -1, + }, repoURL)) } }) } @@ -109,9 +94,7 @@ func BenchmarkContent(b *testing.B) { b.StartTimer() - if err := content.Store(l, hash, testData); err != nil { - b.Fatal(err) - } + require.NoError(b, content.Store(l, hash, testData)) } }) @@ -148,37 +131,29 @@ func BenchmarkContent(b *testing.B) { } func BenchmarkGitOperations(b *testing.B) { - // Setup a git repository for testing + repoURL := startBenchServer(b) + + // Clone the repo locally for tree operations repoDir := b.TempDir() g, err := git.NewGitRunner() - if err != nil { - b.Fatal(err) - } + require.NoError(b, err) g = g.WithWorkDir(repoDir) ctx := b.Context() - if err = g.Clone(ctx, "https://github.com/gruntwork-io/terragrunt.git", false, 1, "main"); err != nil { - b.Fatal(err) - } + require.NoError(b, g.Clone(ctx, repoURL, false, 0, "")) b.Run("ls-remote", func(b *testing.B) { - g, err = git.NewGitRunner() - if err != nil { - b.Fatal(err) - } - - g = g.WithWorkDir(repoDir) + runner, err := git.NewGitRunner() + require.NoError(b, err) b.ResetTimer() for b.Loop() { - _, err := g.LsRemote(ctx, "https://github.com/gruntwork-io/terragrunt.git", "HEAD") - if err != nil { - b.Fatal(err) - } + _, err := runner.LsRemote(ctx, repoURL, "HEAD") + require.NoError(b, err) } }) @@ -187,31 +162,21 @@ func BenchmarkGitOperations(b *testing.B) { for b.Loop() { _, err := g.LsTreeRecursive(ctx, "HEAD") - if err != nil { - b.Fatal(err) - } + require.NoError(b, err) } }) b.Run("cat-file", func(b *testing.B) { - // First get a valid hash tree, err := g.LsTreeRecursive(ctx, "HEAD") - if err != nil { - b.Fatal(err) - } - - if len(tree.Entries()) == 0 { - b.Fatal("no entries in 
tree") - } + require.NoError(b, err) + require.NotEmpty(b, tree.Entries(), "no entries in tree") hash := tree.Entries()[0].Hash tmpFile := b.TempDir() + "/cat-file" tmp, err := os.Create(tmpFile) - if err != nil { - b.Fatal(err) - } + require.NoError(b, err) defer os.Remove(tmpFile) defer tmp.Close() @@ -220,9 +185,25 @@ func BenchmarkGitOperations(b *testing.B) { for b.Loop() { err := g.CatFile(ctx, hash, tmp) - if err != nil { - b.Fatal(err) - } + require.NoError(b, err) } }) } + +func startBenchServer(b *testing.B) string { + b.Helper() + + srv, err := git.NewServer() + require.NoError(b, err) + + b.Cleanup(func() { _ = srv.Close() }) + + require.NoError(b, srv.CommitFile("README.md", []byte("# test repo"), "add readme")) + require.NoError(b, srv.CommitFile("main.tf", []byte(`resource "null_resource" "test" {}`), "add main.tf")) + require.NoError(b, srv.CommitFile("test/integration_test.go", []byte("package test"), "add test file")) + + url, err := srv.Start(b.Context()) + require.NoError(b, err) + + return url +} diff --git a/internal/cas/cas.go b/internal/cas/cas.go index cea55f79b1..880c829b94 100644 --- a/internal/cas/cas.go +++ b/internal/cas/cas.go @@ -15,71 +15,101 @@ import ( "path/filepath" "runtime" - "github.com/gofrs/flock" "github.com/gruntwork-io/terragrunt/internal/errors" "github.com/gruntwork-io/terragrunt/internal/git" "github.com/gruntwork-io/terragrunt/internal/telemetry" + "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/pkg/log" ) -// Options configures the behavior of CAS -type Options struct { - // StorePath specifies a custom path for the content store - // If empty, uses $HOME/.cache/terragrunt/cas/store - StorePath string -} +const defaultCloneDepth = 1 + +// Option configures the behavior of CAS. +type Option func(*CAS) -// CloneOptions configures the behavior of a specific clone operation +// CloneOptions configures the behavior of a specific clone operation. 
type CloneOptions struct { - // Dir specifies the target directory for the clone - // If empty, uses the repository name + // Dir specifies the target directory for the clone. + // If empty, uses the repository name. Dir string - // Branch specifies which branch to clone - // If empty, uses HEAD + // Branch specifies which branch to clone. + // If empty, uses HEAD. Branch string - // IncludedGitFiles specifies the files to preserve from the .git directory - // If empty, does not preserve any files + // IncludedGitFiles specifies the files to preserve from the .git directory. + // If empty, does not preserve any files. IncludedGitFiles []string + + // Depth limits the clone history to the given number of commits. + // If zero, defaults to 1 (shallow clone). Set to -1 for full history. + Depth int } // CAS clones a git repository using content-addressable storage. type CAS struct { - store *Store - git *git.GitRunner - opts Options + fs vfs.FS + store *Store + git *git.GitRunner + storePath string } -// New creates a new CAS instance with the given options -// -// TODO: Make these options optional -func New(opts Options) (*CAS, error) { - if opts.StorePath == "" { +// WithStorePath specifies a custom path for the content store. +// If not set, defaults to $HOME/.cache/terragrunt/cas/store. +func WithStorePath(path string) Option { + return func(c *CAS) { + c.storePath = path + } +} + +// WithFS specifies the filesystem for file operations. +// If not set, defaults to the real OS filesystem. +func WithFS(fs vfs.FS) Option { + return func(c *CAS) { + c.fs = fs + } +} + +// New creates a new CAS instance with the given options. 
+func New(opts ...Option) (*CAS, error) { + c := &CAS{} + + for _, opt := range opts { + opt(c) + } + + if c.fs == nil { + c.fs = vfs.NewOSFS() + } + + if c.storePath == "" { home, err := os.UserHomeDir() if err != nil { return nil, err } - opts.StorePath = filepath.Join(home, ".cache", "terragrunt", "cas", "store") + c.storePath = filepath.Join(home, ".cache", "terragrunt", "cas", "store") } - if err := os.MkdirAll(opts.StorePath, DefaultDirPerms); err != nil { + if err := c.fs.MkdirAll(c.storePath, DefaultDirPerms); err != nil { return nil, fmt.Errorf("failed to create CAS store path: %w", err) } - store := NewStore(opts.StorePath) + c.store = NewStore(c.storePath).WithFS(c.fs) g, err := git.NewGitRunner() if err != nil { return nil, err } - return &CAS{ - store: store, - git: g, - opts: opts, - }, nil + c.git = g + + return c, nil +} + +// FS returns the configured filesystem. +func (c *CAS) FS() vfs.FS { + return c.fs } // Clone performs the clone operation @@ -87,14 +117,13 @@ func New(opts Options) (*CAS, error) { // TODO: Make options optional func (c *CAS) Clone(ctx context.Context, l log.Logger, opts *CloneOptions, url string) error { // Ensure the store path exists - if err := os.MkdirAll(c.store.Path(), DefaultDirPerms); err != nil { + if err := c.fs.MkdirAll(c.store.Path(), DefaultDirPerms); err != nil { return fmt.Errorf("failed to create store path: %w", err) } // Acquire global clone lock to ensure only one clone at a time - globalLock := flock.New(filepath.Join(c.store.Path(), "clone.lock")) - - if err := globalLock.Lock(); err != nil { + globalLock, err := vfs.Lock(c.fs, filepath.Join(c.store.Path(), "clone.lock")) + if err != nil { return fmt.Errorf("failed to acquire global clone lock: %w", err) } @@ -182,7 +211,16 @@ func (c *CAS) cloneAndStoreContent( url, hash string, ) error { - if err := c.git.Clone(ctx, url, true, 1, opts.Branch); err != nil { + depth := opts.Depth + if depth == 0 { + depth = defaultCloneDepth + } + + if depth < 0 { + depth 
= 0 + } + + if err := c.git.Clone(ctx, url, true, depth, opts.Branch); err != nil { return err } @@ -211,7 +249,7 @@ func (c *CAS) storeRootTree(ctx context.Context, l log.Logger, hash string, opts } for _, file := range opts.IncludedGitFiles { - stat, err := os.Stat(filepath.Join(c.git.WorkDir, file)) + stat, err := c.fs.Stat(filepath.Join(c.git.WorkDir, file)) if err != nil { return err } @@ -222,7 +260,7 @@ func (c *CAS) storeRootTree(ctx context.Context, l log.Logger, hash string, opts workDirPath := filepath.Join(c.git.WorkDir, file) - includedHash, err := hashFile(workDirPath) + includedHash, err := hashFile(c.fs, workDirPath) if err != nil { return err } @@ -310,8 +348,8 @@ func (c *CAS) ensureBlob(ctx context.Context, hash string) error { // We want to make sure we remove the temporary file // if we encounter an error defer func() { - if _, osStatErr := os.Stat(tmpPath); osStatErr == nil { - err = errors.Join(err, os.Remove(tmpPath)) + if _, statErr := c.fs.Stat(tmpPath); statErr == nil { + err = errors.Join(err, c.fs.Remove(tmpPath)) } }() @@ -331,30 +369,40 @@ func (c *CAS) ensureBlob(ctx context.Context, hash string) error { return err } - if err = os.Rename(tmpPath, content.getPath(hash)); err != nil { + if err = c.fs.Rename(tmpPath, content.getPath(hash)); err != nil { return err } - if err = os.Chmod(content.getPath(hash), StoredFilePerms); err != nil { + if err = c.fs.Chmod(content.getPath(hash), StoredFilePerms); err != nil { return err } return nil } -func hashFile(path string) (string, error) { - file, err := os.Open(path) +func hashFile(fs vfs.FS, path string) (string, error) { + file, err := fs.Open(path) if err != nil { return "", err } - defer file.Close() - h := sha1.New() if _, err := io.Copy(h, file); err != nil { return "", err } - return hex.EncodeToString(h.Sum(nil)), nil + hash := hex.EncodeToString(h.Sum(nil)) + + if err := file.Close(); err != nil { + return hash, fmt.Errorf( + "hash of %s successfully computed as "+ + "%s, but 
closing the file failed: %w", + path, + hash, + err, + ) + } + + return hash, nil } diff --git a/internal/cas/cas_test.go b/internal/cas/cas_test.go index 6be75226a9..bd082a8f6c 100644 --- a/internal/cas/cas_test.go +++ b/internal/cas/cas_test.go @@ -15,6 +15,7 @@ func TestCAS_Clone(t *testing.T) { t.Parallel() l := logger.CreateLogger() + repoURL := startTestServer(t) t.Run("clone new repository", func(t *testing.T) { t.Parallel() @@ -22,14 +23,13 @@ func TestCAS_Clone(t *testing.T) { storePath := filepath.Join(tempDir, "store") targetPath := filepath.Join(tempDir, "repo") - c, err := cas.New(cas.Options{ - StorePath: storePath, - }) + c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) err = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: targetPath, - }, "https://github.com/gruntwork-io/terragrunt.git") + Dir: targetPath, + Depth: -1, + }, repoURL) require.NoError(t, err) // Verify repository was cloned @@ -47,15 +47,14 @@ func TestCAS_Clone(t *testing.T) { storePath := filepath.Join(tempDir, "store") targetPath := filepath.Join(tempDir, "repo") - c, err := cas.New(cas.Options{ - StorePath: storePath, - }) + c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) err = c.Clone(t.Context(), l, &cas.CloneOptions{ Dir: targetPath, Branch: "main", - }, "https://github.com/gruntwork-io/terragrunt.git") + Depth: -1, + }, repoURL) require.NoError(t, err) // Verify repository was cloned @@ -69,15 +68,14 @@ func TestCAS_Clone(t *testing.T) { storePath := filepath.Join(tempDir, "store") targetPath := filepath.Join(tempDir, "repo") - c, err := cas.New(cas.Options{ - StorePath: storePath, - }) + c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) err = c.Clone(t.Context(), l, &cas.CloneOptions{ Dir: targetPath, IncludedGitFiles: []string{"HEAD", "config"}, - }, "https://github.com/gruntwork-io/terragrunt.git") + Depth: -1, + }, repoURL) require.NoError(t, err) // Verify repository was cloned diff --git 
a/internal/cas/content.go b/internal/cas/content.go index 0990409e83..823b372a35 100644 --- a/internal/cas/content.go +++ b/internal/cas/content.go @@ -3,20 +3,17 @@ package cas import ( "bufio" "context" + "errors" "io" "os" "path/filepath" "runtime" "github.com/gruntwork-io/terragrunt/internal/telemetry" + "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/pkg/log" ) -// Content manages git object storage and linking -type Content struct { - store *Store -} - const ( // DefaultDirPerms represents standard directory permissions (rwxr-xr-x) DefaultDirPerms = os.FileMode(0755) @@ -28,10 +25,17 @@ const ( WindowsOS = "windows" ) +// Content manages git object storage and linking +type Content struct { + store *Store + fs vfs.FS +} + // NewContent creates a new Content instance func NewContent(store *Store) *Content { return &Content{ store: store, + fs: store.FS(), } } @@ -44,7 +48,7 @@ func (c *Content) Link(ctx context.Context, hash, targetPath string) error { sourcePath := c.getPath(hash) // Try to create hard link directly (most efficient path) - if err := os.Link(sourcePath, targetPath); err != nil { + if err := vfs.Link(c.fs, sourcePath, targetPath); err != nil { // Check if it's because target already exists if os.IsExist(err) { // File already exists, which is fine @@ -52,7 +56,7 @@ func (c *Content) Link(ctx context.Context, hash, targetPath string) error { } // If hard link fails for other reasons, try to copy the file - data, readErr := os.ReadFile(sourcePath) + data, readErr := vfs.ReadFile(c.fs, sourcePath) if readErr != nil { return &WrappedError{ Op: "read_source", @@ -63,7 +67,7 @@ func (c *Content) Link(ctx context.Context, hash, targetPath string) error { // Write to temporary file first tempPath := targetPath + ".tmp" - if err := os.WriteFile(tempPath, data, RegularFilePerms); err != nil { + if err := vfs.WriteFile(c.fs, tempPath, data, RegularFilePerms); err != nil { return &WrappedError{ Op: "write_target", 
Path: tempPath, @@ -72,7 +76,7 @@ func (c *Content) Link(ctx context.Context, hash, targetPath string) error { } // Atomic rename to final path - if err := os.Rename(tempPath, targetPath); err != nil { + if err := c.fs.Rename(tempPath, targetPath); err != nil { return &WrappedError{ Op: "rename_target", Path: tempPath, @@ -99,13 +103,13 @@ func (c *Content) Store(l log.Logger, hash string, data []byte) error { } }() - if err = os.MkdirAll(c.store.Path(), DefaultDirPerms); err != nil { + if err = c.fs.MkdirAll(c.store.Path(), DefaultDirPerms); err != nil { return wrapError("create_store_dir", c.store.Path(), ErrCreateDir) } // Ensure partition directory exists partitionDir := c.getPartition(hash) - if err = os.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + if err = c.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { return wrapError("create_partition_dir", partitionDir, ErrCreateDir) } @@ -142,26 +146,97 @@ func (c *Content) EnsureWithWait(l log.Logger, hash string, data []byte) error { } }() - if err = os.MkdirAll(c.store.Path(), DefaultDirPerms); err != nil { + if err = c.fs.MkdirAll(c.store.Path(), DefaultDirPerms); err != nil { return wrapError("create_store_dir", c.store.Path(), ErrCreateDir) } // Ensure partition directory exists partitionDir := c.getPartition(hash) - if err = os.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + if err = c.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { return wrapError("create_partition_dir", partitionDir, ErrCreateDir) } return c.writeContentToFile(l, hash, data) } +// EnsureCopy ensures that a content item exists in the store by copying from a file +func (c *Content) EnsureCopy(l log.Logger, hash, src string) (err error) { + path := c.getPath(hash) + if c.store.hasContent(path) { + return nil + } + + lock, err := c.store.AcquireLock(hash) + if err != nil { + return wrapError("acquire_lock", hash, err) + } + + defer func() { + err = errors.Join(err, lock.Unlock()) + }() + + // Ensure partition 
directory exists + partitionDir := c.getPartition(hash) + if err = c.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + return wrapError("create_partition_dir", partitionDir, ErrCreateDir) + } + + f, err := c.fs.Create(path) + if err != nil { + return wrapError("create_file", path, err) + } + + defer func() { + err = errors.Join(err, f.Close()) + }() + + r, err := c.fs.Open(src) + if err != nil { + return wrapError("open_source", src, err) + } + + defer func() { + err = errors.Join(err, r.Close()) + }() + + if _, err := io.Copy(f, r); err != nil { + return wrapError("copy_file", src, err) + } + + return nil +} + +// GetTmpHandle returns a file handle to a temporary file where content will be stored. +func (c *Content) GetTmpHandle(hash string) (vfs.File, error) { + partitionDir := c.getPartition(hash) + if err := c.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + return nil, wrapError("create_partition_dir", partitionDir, ErrCreateDir) + } + + path := c.getPath(hash) + tempPath := path + ".tmp" + + f, err := c.fs.Create(tempPath) + if err != nil { + return nil, wrapError("create_temp_file", tempPath, err) + } + + return f, err +} + +// Read retrieves content from the store by hash +func (c *Content) Read(hash string) ([]byte, error) { + path := c.getPath(hash) + return vfs.ReadFile(c.fs, path) +} + // writeContentToFile writes data to a temporary file, // sets appropriate permissions, and performs an atomic rename. 
func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) error { path := c.getPath(hash) tempPath := path + ".tmp" - f, err := os.OpenFile(tempPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, RegularFilePerms) + f, err := c.fs.OpenFile(tempPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, RegularFilePerms) if err != nil { return wrapError("create_temp_file", tempPath, err) } @@ -169,9 +244,11 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err buf := bufio.NewWriter(f) if _, err := buf.Write(data); err != nil { - f.Close() + if closeErr := f.Close(); closeErr != nil { + l.Warnf("failed to close temp file %s: %v", tempPath, closeErr) + } - if removeErr := os.Remove(tempPath); removeErr != nil { + if removeErr := c.fs.Remove(tempPath); removeErr != nil { l.Warnf("failed to remove temp file %s: %v", tempPath, removeErr) } @@ -179,9 +256,11 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err } if err := buf.Flush(); err != nil { - f.Close() + if closeErr := f.Close(); closeErr != nil { + l.Warnf("failed to close temp file %s: %v", tempPath, closeErr) + } - if removeErr := os.Remove(tempPath); removeErr != nil { + if removeErr := c.fs.Remove(tempPath); removeErr != nil { l.Warnf("failed to remove temp file %s: %v", tempPath, removeErr) } @@ -189,7 +268,7 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err } if err := f.Close(); err != nil { - if removeErr := os.Remove(tempPath); removeErr != nil { + if removeErr := c.fs.Remove(tempPath); removeErr != nil { l.Warnf("failed to remove temp file %s: %v", tempPath, removeErr) } @@ -197,8 +276,8 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err } // Set read-only permissions on the temporary file - if err := os.Chmod(tempPath, StoredFilePerms); err != nil { - if removeErr := os.Remove(tempPath); removeErr != nil { + if err := c.fs.Chmod(tempPath, StoredFilePerms); err != nil { + if removeErr 
:= c.fs.Remove(tempPath); removeErr != nil { l.Warnf("failed to remove temp file %s: %v", tempPath, removeErr) } @@ -208,17 +287,17 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err // For Windows, handle readonly attributes specifically if runtime.GOOS == WindowsOS { // Check if a destination file exists and is read-only - if _, err := os.Stat(path); err == nil { + if _, err := c.fs.Stat(path); err == nil { // File exists, make it writable before rename operation - if err := os.Chmod(path, RegularFilePerms); err != nil { + if err := c.fs.Chmod(path, RegularFilePerms); err != nil { l.Warnf("failed to make destination file writable %s: %v", path, err) } } } // Atomic rename - if err := os.Rename(tempPath, path); err != nil { - if removeErr := os.Remove(tempPath); removeErr != nil { + if err := c.fs.Rename(tempPath, path); err != nil { + if removeErr := c.fs.Remove(tempPath); removeErr != nil { l.Warnf("failed to remove temp file %s: %v", tempPath, removeErr) } @@ -228,7 +307,7 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err // For Windows, we need to set the permissions again after rename if runtime.GOOS == WindowsOS { // Ensure the file has read-only permissions after rename - if err := os.Chmod(path, StoredFilePerms); err != nil { + if err := c.fs.Chmod(path, StoredFilePerms); err != nil { return wrapError("chmod_final_file", path, err) } } @@ -236,75 +315,6 @@ func (c *Content) writeContentToFile(l log.Logger, hash string, data []byte) err return nil } -// EnsureCopy ensures that a content item exists in the store by copying from a file -func (c *Content) EnsureCopy(l log.Logger, hash, src string) error { - path := c.getPath(hash) - if c.store.hasContent(path) { - return nil - } - - lock, err := c.store.AcquireLock(hash) - if err != nil { - return wrapError("acquire_lock", hash, err) - } - - defer func() { - if unlockErr := lock.Unlock(); unlockErr != nil { - l.Warnf("failed to unlock 
filesystem lock for hash %s: %v", hash, unlockErr) - } - }() - - // Ensure partition directory exists - partitionDir := c.getPartition(hash) - if err = os.MkdirAll(partitionDir, DefaultDirPerms); err != nil { - return wrapError("create_partition_dir", partitionDir, ErrCreateDir) - } - - f, err := os.Create(path) - if err != nil { - return wrapError("create_file", path, err) - } - - defer f.Close() - - r, err := os.Open(src) - if err != nil { - return wrapError("open_source", src, err) - } - - defer r.Close() - - if _, err := io.Copy(f, r); err != nil { - return wrapError("copy_file", src, err) - } - - return nil -} - -// GetTmpHandle returns a file handle to a temporary file where content will be stored. -func (c *Content) GetTmpHandle(hash string) (*os.File, error) { - partitionDir := c.getPartition(hash) - if err := os.MkdirAll(partitionDir, DefaultDirPerms); err != nil { - return nil, wrapError("create_partition_dir", partitionDir, ErrCreateDir) - } - - path := c.getPath(hash) - tempPath := path + ".tmp" - - f, err := os.Create(tempPath) - if err != nil { - return nil, wrapError("create_temp_file", tempPath, err) - } - - return f, err -} - -// Read retrieves content from the store by hash -func (c *Content) Read(hash string) ([]byte, error) { - path := c.getPath(hash) - return os.ReadFile(path) -} - // getPartition returns the partition path for a given hash func (c *Content) getPartition(hash string) string { return filepath.Join(c.store.Path(), hash[:2]) diff --git a/internal/cas/content_test.go b/internal/cas/content_test.go index ce44b05a7d..9b510ba259 100644 --- a/internal/cas/content_test.go +++ b/internal/cas/content_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/gruntwork-io/terragrunt/internal/cas" - "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/test/helpers/logger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -21,7 +21,10 
@@ func TestContent_Store(t *testing.T) { t.Run("store new content", func(t *testing.T) { t.Parallel() - store := cas.NewStore(helpers.TmpDirWOSymlinks(t)) + + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + store := cas.NewStore("/store").WithFS(memFs) content := cas.NewContent(store) testHash := testHashValue @@ -33,7 +36,7 @@ func TestContent_Store(t *testing.T) { // Verify content was stored partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := os.ReadFile(storedPath) + storedData, err := vfs.ReadFile(memFs, storedPath) require.NoError(t, err) assert.Equal(t, testData, storedData) }) @@ -41,7 +44,9 @@ func TestContent_Store(t *testing.T) { t.Run("ensure existing content", func(t *testing.T) { t.Parallel() - store := cas.NewStore(helpers.TmpDirWOSymlinks(t)) + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + store := cas.NewStore("/store").WithFS(memFs) content := cas.NewContent(store) testHash := testHashValue @@ -57,7 +62,7 @@ func TestContent_Store(t *testing.T) { // Verify original content remains partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := os.ReadFile(storedPath) + storedData, err := vfs.ReadFile(memFs, storedPath) require.NoError(t, err) assert.Equal(t, testData, storedData) }) @@ -65,7 +70,9 @@ func TestContent_Store(t *testing.T) { t.Run("overwrite existing content", func(t *testing.T) { t.Parallel() - store := cas.NewStore(helpers.TmpDirWOSymlinks(t)) + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + store := cas.NewStore("/store").WithFS(memFs) content := cas.NewContent(store) testHash := testHashValue @@ -78,10 +85,10 @@ func TestContent_Store(t *testing.T) { err = content.Store(l, testHash, differentData) require.NoError(t, err) - // Verify original content remains + // Verify content was 
overwritten partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := os.ReadFile(storedPath) + storedData, err := vfs.ReadFile(memFs, storedPath) require.NoError(t, err) assert.Equal(t, differentData, storedData) }) @@ -94,8 +101,11 @@ func TestContent_Link(t *testing.T) { t.Run("create new link", func(t *testing.T) { t.Parallel() - storeDir := helpers.TmpDirWOSymlinks(t) - store := cas.NewStore(storeDir) + + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + require.NoError(t, memFs.MkdirAll("/target", 0755)) + store := cas.NewStore("/store").WithFS(memFs) content := cas.NewContent(store) testHash := testHashValue @@ -106,29 +116,52 @@ func TestContent_Link(t *testing.T) { require.NoError(t, err) // Then create a link to it - targetDir := helpers.TmpDirWOSymlinks(t) - targetPath := filepath.Join(targetDir, "test.txt") + targetPath := filepath.Join("/target", "test.txt") err = content.Link(t.Context(), testHash, targetPath) require.NoError(t, err) // Verify link was created and contains correct content - linkedData, err := os.ReadFile(targetPath) + linkedData, err := vfs.ReadFile(memFs, targetPath) require.NoError(t, err) assert.Equal(t, testData, linkedData) + }) - // Verify it's a hard link by checking inode numbers - partitionDir := filepath.Join(store.Path(), testHash[:2]) - sourceInfo, err := os.Stat(filepath.Join(partitionDir, testHash)) + t.Run("create hard link on real filesystem", func(t *testing.T) { + t.Parallel() + + osFs := vfs.NewOSFS() + storeDir := t.TempDir() + targetDir := t.TempDir() + store := cas.NewStore(storeDir).WithFS(osFs) + + content := cas.NewContent(store) + testHash := testHashValue + testData := []byte("test content") + + err := content.Store(l, testHash, testData) + require.NoError(t, err) + + targetPath := filepath.Join(targetDir, "test.txt") + err = content.Link(t.Context(), testHash, targetPath) + require.NoError(t, err) + + // 
Verify hard link by comparing inodes + sourcePath := filepath.Join(storeDir, testHash[:2], testHash) + sourceInfo, err := os.Stat(sourcePath) require.NoError(t, err) targetInfo, err := os.Stat(targetPath) require.NoError(t, err) - assert.Equal(t, sourceInfo.Sys(), targetInfo.Sys()) + assert.True(t, os.SameFile(sourceInfo, targetInfo), "expected hard link (same inode)") }) t.Run("link to existing file", func(t *testing.T) { t.Parallel() - store := cas.NewStore(helpers.TmpDirWOSymlinks(t)) + + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + require.NoError(t, memFs.MkdirAll("/target", 0755)) + store := cas.NewStore("/store").WithFS(memFs) content := cas.NewContent(store) testHash := testHashValue @@ -139,9 +172,8 @@ func TestContent_Link(t *testing.T) { require.NoError(t, err) // Create target file - targetDir := helpers.TmpDirWOSymlinks(t) - targetPath := filepath.Join(targetDir, "test.txt") - err = os.WriteFile(targetPath, []byte("existing content"), 0644) + targetPath := filepath.Join("/target", "test.txt") + err = vfs.WriteFile(memFs, targetPath, []byte("existing content"), 0644) require.NoError(t, err) // Try to create link @@ -149,7 +181,7 @@ func TestContent_Link(t *testing.T) { require.NoError(t, err) // Verify original content remains - existingData, err := os.ReadFile(targetPath) + existingData, err := vfs.ReadFile(memFs, targetPath) require.NoError(t, err) assert.Equal(t, []byte("existing content"), existingData) }) @@ -163,7 +195,10 @@ func TestContent_EnsureWithWait(t *testing.T) { t.Run("content already exists", func(t *testing.T) { t.Parallel() - store := cas.NewStore(helpers.TmpDirWOSymlinks(t)) + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + store := cas.NewStore("/store").WithFS(memFs) + content := cas.NewContent(store) testHash := testHashValue testData := []byte("test content") @@ -179,7 +214,7 @@ func TestContent_EnsureWithWait(t *testing.T) { // Verify original content remains 
partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := os.ReadFile(storedPath) + storedData, err := vfs.ReadFile(memFs, storedPath) require.NoError(t, err) assert.Equal(t, testData, storedData) }) @@ -187,7 +222,10 @@ func TestContent_EnsureWithWait(t *testing.T) { t.Run("content doesn't exist", func(t *testing.T) { t.Parallel() - store := cas.NewStore(helpers.TmpDirWOSymlinks(t)) + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + store := cas.NewStore("/store").WithFS(memFs) + content := cas.NewContent(store) testHash := "newcontent123456" testData := []byte("new test content") @@ -199,7 +237,7 @@ func TestContent_EnsureWithWait(t *testing.T) { // Verify content was stored partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := os.ReadFile(storedPath) + storedData, err := vfs.ReadFile(memFs, storedPath) require.NoError(t, err) assert.Equal(t, testData, storedData) }) @@ -207,7 +245,10 @@ func TestContent_EnsureWithWait(t *testing.T) { t.Run("concurrent writes - optimization", func(t *testing.T) { t.Parallel() - store := cas.NewStore(helpers.TmpDirWOSymlinks(t)) + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + store := cas.NewStore("/store").WithFS(memFs) + content := cas.NewContent(store) testHash := "concurrent123456" @@ -244,7 +285,7 @@ func TestContent_EnsureWithWait(t *testing.T) { // Verify only one content exists (from process 1) partitionDir := filepath.Join(store.Path(), testHash[:2]) storedPath := filepath.Join(partitionDir, testHash) - storedData, err := os.ReadFile(storedPath) + storedData, err := vfs.ReadFile(memFs, storedPath) require.NoError(t, err) assert.Equal(t, []byte("process 1 data"), storedData) }) diff --git a/internal/cas/getter.go b/internal/cas/getter.go index 3bed8515e1..f8020be531 100644 --- a/internal/cas/getter.go +++ 
b/internal/cas/getter.go @@ -4,10 +4,10 @@ import ( "context" "fmt" "net/url" - "os" "strings" "github.com/gruntwork-io/terragrunt/internal/errors" + "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/pkg/log" "github.com/hashicorp/go-getter/v2" ) @@ -15,6 +15,11 @@ import ( // Assert that CASGetter implements the Getter interface var _ getter.Getter = &CASGetter{} +var ( + ErrDirectoryNotFound = errors.New("directory not found") + ErrNotADirectory = errors.New("not a directory") +) + // CASGetter is a go-getter Getter implementation. type CASGetter struct { CAS *CAS @@ -107,7 +112,9 @@ func (g *CASGetter) Detect(req *getter.Request) (bool, error) { if ok { // Check if this is a FileDetector using type assertion if _, isFileDetector := detector.(*getter.FileDetector); isFileDetector { - info, statErr := os.Stat(src) + fs := g.getFS() + + info, statErr := fs.Stat(src) if statErr != nil { return false, fmt.Errorf("%w: %s", ErrDirectoryNotFound, src) } @@ -129,7 +136,11 @@ func (g *CASGetter) Detect(req *getter.Request) (bool, error) { return false, nil } -var ( - ErrDirectoryNotFound = errors.New("directory not found") - ErrNotADirectory = errors.New("not a directory") -) +// getFS returns the filesystem from the CAS instance, or a default OSFS if CAS is nil. 
+func (g *CASGetter) getFS() vfs.FS { + if g.CAS != nil { + return g.CAS.FS() + } + + return vfs.NewOSFS() +} diff --git a/internal/cas/getter_ssh_test.go b/internal/cas/getter_ssh_test.go index 19dc774554..a6bf01228b 100644 --- a/internal/cas/getter_ssh_test.go +++ b/internal/cas/getter_ssh_test.go @@ -48,7 +48,7 @@ func TestSSHCASGetterGet(t *testing.T) { tmpDir := helpers.TmpDirWOSymlinks(t) storePath := filepath.Join(tmpDir, "store") - c, err := cas.New(cas.Options{StorePath: storePath}) + c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) opts := &cas.CloneOptions{ diff --git a/internal/cas/getter_test.go b/internal/cas/getter_test.go index 5faf5ca70f..554ac13f08 100644 --- a/internal/cas/getter_test.go +++ b/internal/cas/getter_test.go @@ -97,16 +97,16 @@ func TestCASGetterDetect(t *testing.T) { func TestCASGetterGet(t *testing.T) { t.Parallel() + repoURL := startTestServer(t) + tempDir := helpers.TmpDirWOSymlinks(t) storePath := filepath.Join(tempDir, "store") - c, err := cas.New(cas.Options{ - StorePath: storePath, - }) + c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) opts := &cas.CloneOptions{ - Branch: "main", + Depth: -1, } l := logger.CreateLogger() @@ -117,15 +117,12 @@ func TestCASGetterGet(t *testing.T) { } tests := []struct { - name string - url string - queryRef string - expectRef string + name string + url string }{ { - name: "URL with ref parameter", - url: "github.com/gruntwork-io/terragrunt?ref=v0.75.0", - expectRef: "v0.75.0", + name: "clone via getter with ref", + url: "git::" + repoURL + "?ref=main", }, } @@ -155,9 +152,7 @@ func TestCASGetterLocalDir(t *testing.T) { tmp := helpers.TmpDirWOSymlinks(t) storePath := filepath.Join(tmp, "store") - c, err := cas.New(cas.Options{ - StorePath: storePath, - }) + c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) opts := &cas.CloneOptions{ diff --git a/internal/cas/integration_test.go b/internal/cas/integration_test.go index 
29921cb440..5c412ab136 100644 --- a/internal/cas/integration_test.go +++ b/internal/cas/integration_test.go @@ -17,6 +17,7 @@ func TestIntegration_CloneAndReuse(t *testing.T) { t.Parallel() l := logger.CreateLogger() + repoURL := startTestServer(t) t.Run("clone same repo twice uses store", func(t *testing.T) { t.Parallel() @@ -25,13 +26,12 @@ func TestIntegration_CloneAndReuse(t *testing.T) { // First clone firstClonePath := filepath.Join(tempDir, "first") - cas1, err := cas.New(cas.Options{ - StorePath: storePath, - }) + cas1, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) require.NoError(t, cas1.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: firstClonePath, - }, "https://github.com/gruntwork-io/terragrunt.git")) + Dir: firstClonePath, + Depth: -1, + }, repoURL)) // Get info about first clone firstReadme := filepath.Join(firstClonePath, "README.md") @@ -40,13 +40,12 @@ func TestIntegration_CloneAndReuse(t *testing.T) { // Second clone secondClonePath := filepath.Join(tempDir, "second") - cas2, err := cas.New(cas.Options{ - StorePath: storePath, - }) + cas2, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) require.NoError(t, cas2.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: secondClonePath, - }, "https://github.com/gruntwork-io/terragrunt.git")) + Dir: secondClonePath, + Depth: -1, + }, repoURL)) // Get info about second clone secondReadme := filepath.Join(secondClonePath, "README.md") @@ -65,15 +64,14 @@ func TestIntegration_CloneAndReuse(t *testing.T) { t.Parallel() tempDir := helpers.TmpDirWOSymlinks(t) - c, err := cas.New(cas.Options{ - StorePath: filepath.Join(tempDir, "store"), - }) + c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) require.NoError(t, err) err = c.Clone(t.Context(), l, &cas.CloneOptions{ Dir: filepath.Join(tempDir, "repo"), Branch: "nonexistent-branch", - }, "https://github.com/gruntwork-io/terragrunt.git") + Depth: -1, + }, repoURL) require.Error(t, err) var wrappedErr 
*git.WrappedError @@ -85,19 +83,14 @@ func TestIntegration_CloneAndReuse(t *testing.T) { t.Parallel() tempDir := helpers.TmpDirWOSymlinks(t) - c, err := cas.New(cas.Options{ - StorePath: filepath.Join(tempDir, "store"), - }) + c, err := cas.New(cas.WithStorePath(filepath.Join(tempDir, "store"))) require.NoError(t, err) err = c.Clone(t.Context(), l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "repo"), - }, "https://github.com/yhakbar/nonexistent-repo.git") + Dir: filepath.Join(tempDir, "repo"), + Depth: -1, + }, "http://127.0.0.1:1/nonexistent-repo.git") require.Error(t, err) - - var wrappedErr *git.WrappedError - require.ErrorAs(t, err, &wrappedErr) - assert.ErrorIs(t, wrappedErr.Err, git.ErrCommandSpawn) }) } @@ -105,31 +98,27 @@ func TestIntegration_TreeStorage(t *testing.T) { t.Parallel() ctx := t.Context() - l := logger.CreateLogger() + repoURL := startTestServer(t) t.Run("stores tree objects", func(t *testing.T) { t.Parallel() tempDir := helpers.TmpDirWOSymlinks(t) storePath := filepath.Join(tempDir, "store") - const testTag = "v0.98.0" - // First clone to populate store - c, err := cas.New(cas.Options{ - StorePath: storePath, - }) + c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) require.NoError(t, c.Clone(ctx, l, &cas.CloneOptions{ - Dir: filepath.Join(tempDir, "repo"), - Branch: testTag, - }, "https://github.com/gruntwork-io/terragrunt.git")) + Dir: filepath.Join(tempDir, "repo"), + Depth: -1, + }, repoURL)) - // Get the commit hash for the tag + // Get the commit hash for HEAD g, err := git.NewGitRunner() require.NoError(t, err) - results, err := g.LsRemote(ctx, "https://github.com/gruntwork-io/terragrunt.git", testTag) + results, err := g.LsRemote(ctx, repoURL, "HEAD") require.NoError(t, err) require.NotEmpty(t, results) commitHash := results[0].Hash diff --git a/internal/cas/local.go b/internal/cas/local.go index 8b61cd91e1..18c1badef4 100644 --- a/internal/cas/local.go +++ b/internal/cas/local.go @@ -5,11 +5,12 @@ import ( 
"crypto/sha1" "encoding/hex" "fmt" - "os" + "io/fs" "path/filepath" "strings" "github.com/gruntwork-io/terragrunt/internal/git" + "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/gruntwork-io/terragrunt/pkg/log" ) @@ -42,13 +43,13 @@ func (c *CAS) hashDirectory(sourceDir string) (string, []byte, error) { var allHashes []string - err := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error { + err := vfs.WalkDir(c.fs, sourceDir, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } // Implicitly handled by tracking the file hashes. - if info.IsDir() { + if d.IsDir() { return nil } @@ -60,11 +61,16 @@ func (c *CAS) hashDirectory(sourceDir string) (string, []byte, error) { // Convert to forward slashes for consistency (git-style paths) relPath = strings.ReplaceAll(relPath, string(filepath.Separator), "/") - fileHash, err := hashFile(path) + fileHash, err := hashFile(c.fs, path) if err != nil { return fmt.Errorf("failed to hash file %s: %w", path, err) } + info, err := d.Info() + if err != nil { + return fmt.Errorf("failed to stat file %s: %w", path, err) + } + // Artificially create a tree entry for the file. 
mode := fmt.Sprintf("%06o", info.Mode().Perm()) treeLine := fmt.Sprintf("%s blob %s\t%s\n", mode, fileHash, relPath) @@ -95,18 +101,18 @@ func (c *CAS) storeLocalContent(l log.Logger, sourceDir, dirHash string, treeDat } // Walk the directory and store all files - return filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error { + return vfs.WalkDir(c.fs, sourceDir, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } // Skip directories and the root directory itself - if info.IsDir() { + if d.IsDir() { return nil } // Hash the file to get its content hash - fileHash, err := hashFile(path) + fileHash, err := hashFile(c.fs, path) if err != nil { return fmt.Errorf("failed to hash file %s: %w", path, err) } diff --git a/internal/cas/race_test.go b/internal/cas/race_test.go index 296a0e2524..7742153997 100644 --- a/internal/cas/race_test.go +++ b/internal/cas/race_test.go @@ -17,16 +17,16 @@ import ( func TestCASGetterGetWithRacing(t *testing.T) { t.Parallel() + repoURL := startTestServer(t) + tempDir := helpers.TmpDirWOSymlinks(t) storePath := filepath.Join(tempDir, "store") - c, err := cas.New(cas.Options{ - StorePath: storePath, - }) + c, err := cas.New(cas.WithStorePath(storePath)) require.NoError(t, err) opts := &cas.CloneOptions{ - Branch: "main", + Depth: -1, } l := logger.CreateLogger() @@ -37,15 +37,12 @@ func TestCASGetterGetWithRacing(t *testing.T) { } tests := []struct { - name string - url string - queryRef string - expectRef string + name string + url string }{ { - name: "URL with ref parameter", - url: "github.com/gruntwork-io/terragrunt?ref=v0.75.0", - expectRef: "v0.75.0", + name: "clone via getter with ref", + url: "git::" + repoURL + "?ref=main", }, } diff --git a/internal/cas/store.go b/internal/cas/store.go index ce85578ae8..68df120abc 100644 --- a/internal/cas/store.go +++ b/internal/cas/store.go @@ -1,24 +1,36 @@ package cas import ( - "os" "path/filepath" - "github.com/gofrs/flock" + 
"github.com/gruntwork-io/terragrunt/internal/vfs" ) // Store manages the store directory and filesystem locks to prevent concurrent writes type Store struct { + fs vfs.FS path string } -// NewStore creates a new Store instance. +// NewStore creates a new Store instance with the OS filesystem. func NewStore(path string) *Store { return &Store{ path: path, + fs: vfs.NewOSFS(), } } +// WithFS sets the filesystem for file operations and returns the Store for method chaining. +func (s *Store) WithFS(fs vfs.FS) *Store { + s.fs = fs + return s +} + +// FS returns the configured filesystem. +func (s *Store) FS() vfs.FS { + return s.fs +} + // Path returns the current store path func (s *Store) Path() string { return s.path @@ -32,55 +44,32 @@ func (s *Store) NeedsWrite(hash string) bool { return !s.hasContent(path) } -// HasContent checks if a given hash exists in the store -func (s *Store) hasContent(path string) bool { - _, err := os.Stat(path) - - return err == nil -} - // AcquireLock acquires a filesystem lock for the given hash -// Returns the flock instance that should be unlocked when done -func (s *Store) AcquireLock(hash string) (*flock.Flock, error) { +// Returns the lock that should be unlocked when done +func (s *Store) AcquireLock(hash string) (vfs.Unlocker, error) { partitionDir := filepath.Join(s.path, hash[:2]) lockPath := filepath.Join(partitionDir, hash+".lock") // Ensure the partition directory exists - if err := os.MkdirAll(partitionDir, DefaultDirPerms); err != nil { - return nil, err - } - - lock := flock.New(lockPath) - if err := lock.Lock(); err != nil { + if err := s.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { return nil, err } - return lock, nil + return vfs.Lock(s.fs, lockPath) } // TryAcquireLock attempts to acquire a filesystem lock for the given hash without blocking -// Returns the flock instance and true if successful, nil and false if the lock is already held -func (s *Store) TryAcquireLock(hash string) (*flock.Flock, bool, 
error) { +// Returns the lock and true if successful, nil and false if the lock is already held +func (s *Store) TryAcquireLock(hash string) (vfs.Unlocker, bool, error) { partitionDir := filepath.Join(s.path, hash[:2]) lockPath := filepath.Join(partitionDir, hash+".lock") // Ensure the partition directory exists - if err := os.MkdirAll(partitionDir, DefaultDirPerms); err != nil { + if err := s.fs.MkdirAll(partitionDir, DefaultDirPerms); err != nil { return nil, false, err } - lock := flock.New(lockPath) - - acquired, err := lock.TryLock() - if err != nil { - return nil, false, err - } - - if !acquired { - return nil, false, nil - } - - return lock, true, nil + return vfs.TryLock(s.fs, lockPath) } // EnsureWithWait tries to acquire a lock for the given hash, and if another process @@ -91,7 +80,7 @@ func (s *Store) TryAcquireLock(hash string) (*flock.Flock, bool, error) { // - needsWrite: true if content doesn't exist and caller should write it // - lock: the acquired lock (nil if needsWrite is false) // - error: any error that occurred -func (s *Store) EnsureWithWait(hash string) (needsWrite bool, lock *flock.Flock, err error) { +func (s *Store) EnsureWithWait(hash string) (needsWrite bool, lock vfs.Unlocker, err error) { // Fast path: check if content already exists partitionDir := filepath.Join(s.path, hash[:2]) path := filepath.Join(partitionDir, hash) @@ -101,7 +90,7 @@ func (s *Store) EnsureWithWait(hash string) (needsWrite bool, lock *flock.Flock, } // Try to acquire lock without blocking - flockLock, acquired, err := s.TryAcquireLock(hash) + tryLock, acquired, err := s.TryAcquireLock(hash) if err != nil { return false, nil, err } @@ -111,14 +100,14 @@ func (s *Store) EnsureWithWait(hash string) (needsWrite bool, lock *flock.Flock, // (another process might have completed while we were trying) if !s.NeedsWrite(hash) { // Content appeared while we were acquiring lock, no write needed - if err = flockLock.Unlock(); err != nil { + if err = tryLock.Unlock(); err 
!= nil { return false, nil, err } return false, nil, nil } // We have the lock and content doesn't exist, caller should write - return true, flockLock, nil + return true, tryLock, nil } // Lock is held by another process, wait for it to complete @@ -140,3 +129,9 @@ func (s *Store) EnsureWithWait(hash string) (needsWrite bool, lock *flock.Flock, // Content still doesn't exist, caller should write it return true, waitLock, nil } + +func (s *Store) hasContent(path string) bool { + _, err := s.fs.Stat(path) + + return err == nil +} diff --git a/internal/cas/store_test.go b/internal/cas/store_test.go index 7e385a9d00..d6b6c2dd70 100644 --- a/internal/cas/store_test.go +++ b/internal/cas/store_test.go @@ -1,44 +1,48 @@ package cas_test import ( - "os" "path/filepath" "testing" "time" "github.com/gruntwork-io/terragrunt/internal/cas" - "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +const defaultStorePath = "/store" + func TestStore(t *testing.T) { t.Parallel() t.Run("custom path", func(t *testing.T) { t.Parallel() - tempDir := helpers.TmpDirWOSymlinks(t) - customPath := filepath.Join(tempDir, "custom-store") - store := cas.NewStore(customPath) + memFs := vfs.NewMemMapFS() + customPath := "/custom-store" + + store := cas.NewStore(customPath).WithFS(memFs) assert.Equal(t, customPath, store.Path()) }) } func TestStore_NeedsWrite(t *testing.T) { t.Parallel() - tempDir := helpers.TmpDirWOSymlinks(t) - store := cas.NewStore(tempDir) + + memFs := vfs.NewMemMapFS() + storePath := defaultStorePath + store := cas.NewStore(storePath).WithFS(memFs) // Create a fake content file testHash := "abcdef123456" // Create partition directory partitionDir := filepath.Join(store.Path(), testHash[:2]) - err := os.MkdirAll(partitionDir, 0755) + err := memFs.MkdirAll(partitionDir, 0755) require.NoError(t, err, "Failed to create partition directory") testPath := 
filepath.Join(partitionDir, testHash) - err = os.WriteFile(testPath, []byte("test"), 0644) + err = vfs.WriteFile(memFs, testPath, []byte("test"), 0644) require.NoError(t, err, "Failed to create test file") tests := []struct { @@ -68,8 +72,10 @@ func TestStore_NeedsWrite(t *testing.T) { func TestStore_AcquireLock(t *testing.T) { t.Parallel() - tempDir := helpers.TmpDirWOSymlinks(t) - store := cas.NewStore(tempDir) + + memFs := vfs.NewMemMapFS() + storePath := defaultStorePath + store := cas.NewStore(storePath).WithFS(memFs) testHash := "abcdef1234567890abcdef1234567890abcdef12" // Test successful lock acquisition @@ -77,9 +83,10 @@ func TestStore_AcquireLock(t *testing.T) { require.NoError(t, err) assert.NotNil(t, lock) - // Verify lock file exists - lockPath := filepath.Join(tempDir, testHash[:2], testHash+".lock") - assert.FileExists(t, lockPath) + // Verify partition directory was created + partitionDir := filepath.Join(storePath, testHash[:2]) + _, err = memFs.Stat(partitionDir) + require.NoError(t, err) // Clean up err = lock.Unlock() @@ -88,8 +95,10 @@ func TestStore_AcquireLock(t *testing.T) { func TestStore_TryAcquireLock(t *testing.T) { t.Parallel() - tempDir := helpers.TmpDirWOSymlinks(t) - store := cas.NewStore(tempDir) + + memFs := vfs.NewMemMapFS() + storePath := defaultStorePath + store := cas.NewStore(storePath).WithFS(memFs) testHash := "abcdef1234567890abcdef1234567890abcdef12" // Test successful lock acquisition @@ -121,8 +130,10 @@ func TestStore_TryAcquireLock(t *testing.T) { func TestStore_LockConcurrency(t *testing.T) { t.Parallel() - tempDir := helpers.TmpDirWOSymlinks(t) - store := cas.NewStore(tempDir) + + memFs := vfs.NewMemMapFS() + storePath := defaultStorePath + store := cas.NewStore(storePath).WithFS(memFs) testHash := "abcdef1234567890abcdef1234567890abcdef12" // Test that multiple goroutines can't acquire the same lock @@ -169,20 +180,22 @@ func TestStore_LockConcurrency(t *testing.T) { func TestStore_EnsureWithWait(t *testing.T) { 
t.Parallel() - tempDir := helpers.TmpDirWOSymlinks(t) - store := cas.NewStore(tempDir) + + memFs := vfs.NewMemMapFS() + storePath := defaultStorePath + store := cas.NewStore(storePath).WithFS(memFs) testHash := "abcdef1234567890abcdef1234567890abcdef12" t.Run("content already exists", func(t *testing.T) { t.Parallel() // Create the content manually - partitionDir := filepath.Join(tempDir, testHash[:2]) - err := os.MkdirAll(partitionDir, 0755) + partitionDir := filepath.Join(storePath, testHash[:2]) + err := memFs.MkdirAll(partitionDir, 0755) require.NoError(t, err) contentPath := filepath.Join(partitionDir, testHash) - err = os.WriteFile(contentPath, []byte("existing content"), 0644) + err = vfs.WriteFile(memFs, contentPath, []byte("existing content"), 0644) require.NoError(t, err) // EnsureWithWait should return false (no write needed) diff --git a/internal/cas/testserver_test.go b/internal/cas/testserver_test.go new file mode 100644 index 0000000000..210772b714 --- /dev/null +++ b/internal/cas/testserver_test.go @@ -0,0 +1,27 @@ +package cas_test + +import ( + "testing" + + "github.com/gruntwork-io/terragrunt/internal/git" + "github.com/stretchr/testify/require" +) + +// startTestServer creates a local Git server with a few test files and +// returns its URL. The server is shut down when the test completes. 
+func startTestServer(t *testing.T) string { + t.Helper() + + srv, err := git.NewServer() + require.NoError(t, err) + t.Cleanup(func() { _ = srv.Close() }) + + require.NoError(t, srv.CommitFile("README.md", []byte("# test repo"), "add readme")) + require.NoError(t, srv.CommitFile("main.tf", []byte(`resource "null_resource" "test" {}`), "add main.tf")) + require.NoError(t, srv.CommitFile("test/integration_test.go", []byte("package test"), "add test file")) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + return url +} diff --git a/internal/cas/tree.go b/internal/cas/tree.go index e4a2266782..040f138d64 100644 --- a/internal/cas/tree.go +++ b/internal/cas/tree.go @@ -2,7 +2,6 @@ package cas import ( "context" - "os" "path/filepath" "runtime" @@ -56,8 +55,10 @@ func LinkTree(ctx context.Context, store *Store, t *git.Tree, targetDir string) } } + fs := store.FS() + for dirPath := range dirsToCreate { - if err := os.MkdirAll(dirPath, DefaultDirPerms); err != nil { + if err := fs.MkdirAll(dirPath, DefaultDirPerms); err != nil { return wrapError("mkdir_all", dirPath, err) } } diff --git a/internal/cas/tree_test.go b/internal/cas/tree_test.go index ff27aad0f2..bf8f8baf67 100644 --- a/internal/cas/tree_test.go +++ b/internal/cas/tree_test.go @@ -1,13 +1,12 @@ package cas_test import ( - "os" "path/filepath" "testing" "github.com/gruntwork-io/terragrunt/internal/cas" "github.com/gruntwork-io/terragrunt/internal/git" - "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/gruntwork-io/terragrunt/internal/vfs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -142,7 +141,7 @@ func TestLinkTree(t *testing.T) { tests := []struct { name string - setupStore func(t *testing.T) (*cas.Store, string) + setupStore func(t *testing.T) (*cas.Store, vfs.FS, string) treeData []byte wantFiles []struct { path string @@ -154,11 +153,12 @@ func TestLinkTree(t *testing.T) { }{ { name: "basic tree with files and directories", - 
setupStore: func(t *testing.T) (*cas.Store, string) { + setupStore: func(t *testing.T) (*cas.Store, vfs.FS, string) { t.Helper() - storeDir := helpers.TmpDirWOSymlinks(t) - store := cas.NewStore(storeDir) + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + store := cas.NewStore("/store").WithFS(memFs) content := cas.NewContent(store) // Create test content @@ -173,7 +173,7 @@ func TestLinkTree(t *testing.T) { err = content.Store(nil, srcTreeHash, []byte(srcTreeData)) require.NoError(t, err) - return store, testHash + return store, memFs, testHash }, treeData: []byte(`100644 blob a1b2c3d4 README.md 100755 blob a1b2c3d4 scripts/test.sh @@ -210,13 +210,14 @@ func TestLinkTree(t *testing.T) { }, { name: "empty tree", - setupStore: func(t *testing.T) (*cas.Store, string) { + setupStore: func(t *testing.T) (*cas.Store, vfs.FS, string) { t.Helper() - storeDir := helpers.TmpDirWOSymlinks(t) - store := cas.NewStore(storeDir) + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + store := cas.NewStore("/store").WithFS(memFs) - return store, "" + return store, memFs, "" }, treeData: []byte(""), wantFiles: []struct { @@ -228,13 +229,14 @@ func TestLinkTree(t *testing.T) { }, { name: "tree with missing content", - setupStore: func(t *testing.T) (*cas.Store, string) { + setupStore: func(t *testing.T) (*cas.Store, vfs.FS, string) { t.Helper() - storeDir := helpers.TmpDirWOSymlinks(t) - store := cas.NewStore(storeDir) + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/store", 0755)) + store := cas.NewStore("/store").WithFS(memFs) - return store, "" + return store, memFs, "" }, treeData: []byte(`100644 blob missing123 README.md`), wantErr: true, @@ -246,14 +248,15 @@ func TestLinkTree(t *testing.T) { t.Parallel() // Setup store - store, _ := tt.setupStore(t) + store, memFs, _ := tt.setupStore(t) // Parse the tree tree, err := git.ParseTree(tt.treeData, "test-repo") require.NoError(t, err) // Create target 
directory - targetDir := helpers.TmpDirWOSymlinks(t) + targetDir := "/target" + require.NoError(t, memFs.MkdirAll(targetDir, 0755)) // Link the tree err = cas.LinkTree(t.Context(), store, tree, targetDir) @@ -269,26 +272,21 @@ func TestLinkTree(t *testing.T) { path := filepath.Join(targetDir, want.path) // Check if file/directory exists - info, err := os.Stat(path) + info, err := memFs.Stat(path) require.NoError(t, err) assert.Equal(t, want.isDir, info.IsDir()) if !want.isDir { // Check file content - data, err := os.ReadFile(path) + data, err := vfs.ReadFile(memFs, path) require.NoError(t, err) assert.Equal(t, want.content, data) - dataStat, err := os.Stat(path) - require.NoError(t, err) - - // Verify hard link by comparing content. - // We don't compare inode numbers because the test might be running on Windows. + // Verify content matches store by reading from both locations storePath := filepath.Join(store.Path(), want.hash[:2], want.hash) - storeStat, err := os.Stat(storePath) + storeData, err := vfs.ReadFile(memFs, storePath) require.NoError(t, err) - - assert.True(t, os.SameFile(dataStat, storeStat)) + assert.Equal(t, storeData, data) } } }) diff --git a/internal/git/server.go b/internal/git/server.go new file mode 100644 index 0000000000..429b386054 --- /dev/null +++ b/internal/git/server.go @@ -0,0 +1,137 @@ +package git + +import ( + "context" + "fmt" + "net" + "net/http" + "time" + + "github.com/go-git/go-billy/v6/memfs" + gogit "github.com/go-git/go-git/v6" + backendhttp "github.com/go-git/go-git/v6/backend/http" + "github.com/go-git/go-git/v6/plumbing" + "github.com/go-git/go-git/v6/plumbing/object" + "github.com/go-git/go-git/v6/plumbing/transport" + "github.com/go-git/go-git/v6/storage" + "github.com/go-git/go-git/v6/storage/memory" +) + +// Server is a pure-Go HTTP Git server backed by in-memory storage. +// It is intended for use in tests. 
+type Server struct { + store storage.Storer + repo *gogit.Repository + ln net.Listener + srv *http.Server +} + +// NewServer creates a Server with an empty in-memory repository. +func NewServer() (*Server, error) { + store := memory.NewStorage() + wt := memfs.New() + + repo, err := gogit.Init( + store, + gogit.WithWorkTree(wt), + gogit.WithDefaultBranch(plumbing.NewBranchReferenceName("main")), + ) + if err != nil { + return nil, fmt.Errorf("init repo: %w", err) + } + + return &Server{ + store: store, + repo: repo, + }, nil +} + +// Repo returns the underlying go-git repository so callers can create +// commits, branches, etc. before starting the server. +func (s *Server) Repo() *gogit.Repository { + return s.repo +} + +// CommitFile is a convenience that writes a single file to the worktree and +// commits it. The commit hash is discarded; only an error is returned. +func (s *Server) CommitFile(path string, data []byte, msg string) error { + w, err := s.repo.Worktree() + if err != nil { + return fmt.Errorf("worktree: %w", err) + } + + f, err := w.Filesystem.Create(path) + if err != nil { + return fmt.Errorf("create file %s: %w", path, err) + } + + if _, err := f.Write(data); err != nil { + return fmt.Errorf("write file %s: %w", path, err) + } + + if err := f.Close(); err != nil { + return fmt.Errorf("close file %s: %w", path, err) + } + + if _, err := w.Add(path); err != nil { + return fmt.Errorf("add %s: %w", path, err) + } + + sig := &object.Signature{ + Name: "Test", + Email: "test@test.com", + When: time.Now(), + } + + _, err = w.Commit(msg, &gogit.CommitOptions{ + Author: sig, + Committer: sig, + }) + if err != nil { + return fmt.Errorf("commit: %w", err) + } + + return nil +} + +// Start begins serving Git HTTP on a random local port. +// Returns the base URL (e.g. "http://127.0.0.1:12345").
+func (s *Server) Start(ctx context.Context) (string, error) { + loader := &singleRepoLoader{store: s.store} + backend := backendhttp.NewBackend(loader) + + var lc net.ListenConfig + + ln, err := lc.Listen(ctx, "tcp", "127.0.0.1:0") + if err != nil { + return "", fmt.Errorf("listen: %w", err) + } + + s.ln = ln + s.srv = &http.Server{ + Handler: backend, + } + + go func() { _ = s.srv.Serve(ln) }() + + return "http://" + ln.Addr().String(), nil +} + +// Close shuts down the server. +func (s *Server) Close() error { + if s.srv != nil { + return s.srv.Close() + } + + return nil +} + +// singleRepoLoader implements transport.Loader by always returning the same +// storer, regardless of the endpoint path. +type singleRepoLoader struct { + store storage.Storer +} + +func (l *singleRepoLoader) Load(_ *transport.Endpoint) (storage.Storer, error) { + return l.store, nil +} diff --git a/internal/git/server_test.go b/internal/git/server_test.go new file mode 100644 index 0000000000..e1d5a11241 --- /dev/null +++ b/internal/git/server_test.go @@ -0,0 +1,120 @@ +package git_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/gruntwork-io/terragrunt/internal/git" + "github.com/gruntwork-io/terragrunt/test/helpers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServer(t *testing.T) { + t.Parallel() + + t.Run("start and clone", func(t *testing.T) { + t.Parallel() + + srv, err := git.NewServer() + require.NoError(t, err) + t.Cleanup(func() { _ = srv.Close() }) + + require.NoError(t, srv.CommitFile("README.md", []byte("# test repo"), "initial commit")) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + cloneDir := helpers.TmpDirWOSymlinks(t) + runner, err := git.NewGitRunner() + require.NoError(t, err) + + runner = runner.WithWorkDir(cloneDir) + + err = runner.Clone(t.Context(), url, false, 0, "") + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(cloneDir, "README.md")) + 
require.NoError(t, err) + assert.Equal(t, "# test repo", string(data)) + }) + + t.Run("ls-remote returns HEAD", func(t *testing.T) { + t.Parallel() + + srv, err := git.NewServer() + require.NoError(t, err) + t.Cleanup(func() { _ = srv.Close() }) + + require.NoError(t, srv.CommitFile("file.txt", []byte("content"), "commit")) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + runner, err := git.NewGitRunner() + require.NoError(t, err) + + results, err := runner.LsRemote(t.Context(), url, "HEAD") + require.NoError(t, err) + require.NotEmpty(t, results) + assert.Regexp(t, "^[0-9a-f]{40}$", results[0].Hash) + }) + + t.Run("multiple files", func(t *testing.T) { + t.Parallel() + + srv, err := git.NewServer() + require.NoError(t, err) + t.Cleanup(func() { _ = srv.Close() }) + + require.NoError(t, srv.CommitFile("a.txt", []byte("aaa"), "add a")) + require.NoError(t, srv.CommitFile("dir/b.txt", []byte("bbb"), "add b")) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + cloneDir := helpers.TmpDirWOSymlinks(t) + runner, err := git.NewGitRunner() + require.NoError(t, err) + + runner = runner.WithWorkDir(cloneDir) + + err = runner.Clone(t.Context(), url, false, 0, "") + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(cloneDir, "a.txt")) + require.NoError(t, err) + assert.Equal(t, "aaa", string(data)) + + data, err = os.ReadFile(filepath.Join(cloneDir, "dir", "b.txt")) + require.NoError(t, err) + assert.Equal(t, "bbb", string(data)) + }) + + t.Run("clone bare", func(t *testing.T) { + t.Parallel() + + srv, err := git.NewServer() + require.NoError(t, err) + t.Cleanup(func() { _ = srv.Close() }) + + require.NoError(t, srv.CommitFile("file.txt", []byte("content"), "commit")) + + url, err := srv.Start(t.Context()) + require.NoError(t, err) + + cloneDir := helpers.TmpDirWOSymlinks(t) + runner, err := git.NewGitRunner() + require.NoError(t, err) + + runner = runner.WithWorkDir(cloneDir) + + err = runner.Clone(t.Context(), url, true, 
0, "") + require.NoError(t, err) + + // Bare clone has HEAD file at root + _, err = os.Stat(filepath.Join(cloneDir, "HEAD")) + require.NoError(t, err) + }) +} diff --git a/internal/runner/run/download_source.go b/internal/runner/run/download_source.go index b1f872ff63..33112f043a 100644 --- a/internal/runner/run/download_source.go +++ b/internal/runner/run/download_source.go @@ -381,7 +381,7 @@ func downloadSource( if allowCAS && !isLocalSource { l.Debugf("CAS experiment enabled: attempting to use Content Addressable Storage for source: %s", canonicalSourceURL) - c, err := cas.New(cas.Options{}) + c, err := cas.New() if err != nil { l.Warnf("Failed to initialize CAS: %v. Falling back to standard getter.", err) } else { diff --git a/internal/services/catalog/module/repo.go b/internal/services/catalog/module/repo.go index bef1da2460..bf9f40dd45 100644 --- a/internal/services/catalog/module/repo.go +++ b/internal/services/catalog/module/repo.go @@ -287,7 +287,7 @@ func (repo *Repo) performClone(ctx context.Context, l log.Logger, opts *CloneOpt client := getter.DefaultClient if repo.allowCAS { - c, err := cas.New(cas.Options{}) + c, err := cas.New() if err != nil { return err } diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go index 38987de5ac..db53006eea 100644 --- a/internal/vfs/vfs.go +++ b/internal/vfs/vfs.go @@ -12,8 +12,11 @@ import ( "os" "path/filepath" "slices" + "sort" "strings" + "sync" + "github.com/gofrs/flock" "github.com/gruntwork-io/terragrunt/pkg/log" "github.com/spf13/afero" ) @@ -22,11 +25,155 @@ import ( // It provides an abstraction over real and in-memory filesystems. type FS = afero.Fs +// File represents a file in the filesystem. +type File = afero.File + +// HardLinker is an optional interface for filesystems that support hard links. +type HardLinker interface { + LinkIfPossible(oldname, newname string) error +} + +// Unlocker can release a held lock. 
+type Unlocker interface { + Unlock() error +} + +// Locker is an optional interface for filesystems that support locking. +type Locker interface { + // Lock acquires a blocking lock for the given name. + Lock(name string) (Unlocker, error) + // TryLock attempts a non-blocking lock for the given name. + // Returns the unlocker and true if acquired, nil and false otherwise. + TryLock(name string) (Unlocker, bool, error) +} + +// ErrNoHardLink is returned when a filesystem does not support hard links. +var ErrNoHardLink = errors.New("hard link not supported") + +// ErrNoLock is returned when a filesystem does not support locking. +var ErrNoLock = errors.New("locking not supported") + // NewOSFS returns a filesystem backed by the real operating system filesystem. func NewOSFS() FS { return &osFS{afero.NewOsFs()} } +// NewMemMapFS returns an in-memory filesystem for testing purposes. +// The returned filesystem supports symlink operations via an in-memory link table. +func NewMemMapFS() FS { + return &memMapFS{ + Fs: afero.NewMemMapFs(), + symlinks: make(map[string]string), + locks: make(map[string]*memLock), + } +} + +// FileExists checks if a path exists using the given filesystem. +// Returns (true, nil) if the file exists, (false, nil) if it does not exist, +// and (false, error) for other errors (e.g., permission denied). +func FileExists(vfs FS, path string) (bool, error) { + _, err := vfs.Stat(path) + if err == nil { + return true, nil + } + + if errors.Is(err, fs.ErrNotExist) { + return false, nil + } + + return false, err +} + +// WriteFile writes data to a file on the given filesystem. +func WriteFile(fs FS, filename string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(filename) + if err := fs.MkdirAll(dir, os.ModePerm); err != nil { + return err + } + + return afero.WriteFile(fs, filename, data, perm) +} + +// ReadFile reads the contents of a file from the given filesystem. 
+func ReadFile(fs FS, filename string) ([]byte, error) { + return afero.ReadFile(fs, filename) +} + +// MkdirTemp creates a temporary directory on the given filesystem. +func MkdirTemp(fs FS, dir, pattern string) (string, error) { + return afero.TempDir(fs, dir, pattern) +} + +// Link creates a hard link. It delegates to LinkIfPossible for filesystems +// that implement the HardLinker interface. +func Link(fs FS, oldname, newname string) error { + linker, ok := fs.(HardLinker) + if !ok { + return &os.LinkError{Op: "link", Old: oldname, New: newname, Err: ErrNoHardLink} + } + + return linker.LinkIfPossible(oldname, newname) +} + +// Symlink creates a symbolic link. It uses afero's SymlinkIfPossible +// which is supported by OsFs and any FS implementing afero.Linker. +func Symlink(fs FS, oldname, newname string) error { + linker, ok := fs.(afero.Linker) + if !ok { + return &os.LinkError{Op: "symlink", Old: oldname, New: newname, Err: afero.ErrNoSymlink} + } + + return linker.SymlinkIfPossible(oldname, newname) +} + +// Lock acquires a blocking lock for the given name on the filesystem. +func Lock(fs FS, name string) (Unlocker, error) { + locker, ok := fs.(Locker) + if !ok { + return nil, ErrNoLock + } + + return locker.Lock(name) +} + +// TryLock attempts a non-blocking lock for the given name on the filesystem. +func TryLock(fs FS, name string) (Unlocker, bool, error) { + locker, ok := fs.(Locker) + if !ok { + return nil, false, ErrNoLock + } + + return locker.TryLock(name) +} + +// WalkDir walks the file tree rooted at root, calling fn for each file or +// directory in the tree, including root. The fn callback receives an fs.DirEntry +// instead of os.FileInfo, which can be more efficient since it does not require +// a stat call for every visited file. +// +// All errors that arise visiting files and directories are filtered by fn: +// see the fs.WalkDirFunc documentation for details. 
+// +// The files are walked in lexical order, which makes the output deterministic +// but means that for very large directories WalkDir can be inefficient. +// WalkDir does not follow symbolic links. +// +// Adapted from spf13/afero#571 — replace with afero.WalkDir once merged. +func WalkDir(fsys FS, root string, fn fs.WalkDirFunc) error { + info, err := lstatIfPossible(fsys, root) + if err != nil { + err = fn(root, nil, err) + } else { + err = walkDir(fsys, root, FileInfoDirEntry{FileInfo: info}, fn) + } + + if errors.Is(err, filepath.SkipDir) || errors.Is(err, filepath.SkipAll) { + return nil + } + + return err +} + // osFS wraps afero.OsFs with hard link support. type osFS struct { afero.Fs @@ -50,19 +197,36 @@ func (fs *osFS) LstatIfPossible(name string) (os.FileInfo, bool, error) { return info, true, err } -// NewMemMapFS returns an in-memory filesystem for testing purposes. -// The returned filesystem supports symlink operations via an in-memory link table. -func NewMemMapFS() FS { - return &memMapFS{ - Fs: afero.NewMemMapFs(), - symlinks: make(map[string]string), +func (fs *osFS) Lock(name string) (Unlocker, error) { + l := flock.New(name) + if err := l.Lock(); err != nil { + return nil, err } + + return l, nil +} + +func (fs *osFS) TryLock(name string) (Unlocker, bool, error) { + l := flock.New(name) + + acquired, err := l.TryLock() + if err != nil { + return nil, false, err + } + + if !acquired { + return nil, false, nil + } + + return l, true, nil } // memMapFS wraps afero.MemMapFs with in-memory symlink support. 
type memMapFS struct { afero.Fs symlinks map[string]string + locks map[string]*memLock + locksMu sync.Mutex } func (fs *memMapFS) SymlinkIfPossible(oldname, newname string) error { @@ -76,6 +240,10 @@ func (fs *memMapFS) SymlinkIfPossible(oldname, newname string) error { } func (fs *memMapFS) LinkIfPossible(oldname, newname string) error { + if _, err := fs.Fs.Stat(newname); err == nil { + return &os.LinkError{Op: "link", Old: oldname, New: newname, Err: os.ErrExist} + } + data, err := afero.ReadFile(fs.Fs, oldname) if err != nil { return &os.LinkError{Op: "link", Old: oldname, New: newname, Err: err} @@ -110,65 +278,44 @@ func (fs *memMapFS) LstatIfPossible(name string) (os.FileInfo, bool, error) { return info, false, err } -// FileExists checks if a path exists using the given filesystem. -// Returns (true, nil) if the file exists, (false, nil) if it does not exist, -// and (false, error) for other errors (e.g., permission denied). -func FileExists(vfs FS, path string) (bool, error) { - _, err := vfs.Stat(path) - if err == nil { - return true, nil - } - - if errors.Is(err, fs.ErrNotExist) { - return false, nil - } +func (fs *memMapFS) Lock(name string) (Unlocker, error) { + l := fs.getOrCreateLock(name) + l.mu.Lock() - return false, err + return l, nil } -// WriteFile writes data to a file on the given filesystem. -func WriteFile(fs FS, filename string, data []byte, perm os.FileMode) error { - return afero.WriteFile(fs, filename, data, perm) -} +func (fs *memMapFS) TryLock(name string) (Unlocker, bool, error) { + l := fs.getOrCreateLock(name) -// ReadFile reads the contents of a file from the given filesystem. -func ReadFile(fs FS, filename string) ([]byte, error) { - return afero.ReadFile(fs, filename) -} - -// MkdirTemp creates a temporary directory on the given filesystem. 
-func MkdirTemp(fs FS, dir, pattern string) (string, error) { - return afero.TempDir(fs, dir, pattern) -} + if !l.mu.TryLock() { + return nil, false, nil + } -// HardLinker is an optional interface for filesystems that support hard links. -type HardLinker interface { - LinkIfPossible(oldname, newname string) error + return l, true, nil } -// ErrNoHardLink is returned when a filesystem does not support hard links. -var ErrNoHardLink = errors.New("hard link not supported") +func (fs *memMapFS) getOrCreateLock(name string) *memLock { + fs.locksMu.Lock() + defer fs.locksMu.Unlock() -// Link creates a hard link. It delegates to LinkIfPossible for filesystems -// that implement the HardLinker interface. -func Link(fs FS, oldname, newname string) error { - linker, ok := fs.(HardLinker) + l, ok := fs.locks[name] if !ok { - return &os.LinkError{Op: "link", Old: oldname, New: newname, Err: ErrNoHardLink} + l = &memLock{} + fs.locks[name] = l } - return linker.LinkIfPossible(oldname, newname) + return l } -// Symlink creates a symbolic link. It uses afero's SymlinkIfPossible -// which is supported by OsFs and any FS implementing afero.Linker. -func Symlink(fs FS, oldname, newname string) error { - linker, ok := fs.(afero.Linker) - if !ok { - return &os.LinkError{Op: "symlink", Old: oldname, New: newname, Err: afero.ErrNoSymlink} - } +// memLock is an in-memory lock backed by a mutex. +type memLock struct { + mu sync.Mutex +} - return linker.SymlinkIfPossible(oldname, newname) +func (l *memLock) Unlock() error { + l.mu.Unlock() + return nil } // ZipDecompressor handles zip archive extraction with configurable limits. @@ -270,34 +417,6 @@ func (z *ZipDecompressor) Unzip(l log.Logger, fs FS, dst, src string, umask os.F return nil } -// containsDotDot checks if a path contains ".." as a path component. -// This is more precise than strings.Contains(name, "..") which would -// reject legitimate files like "file..txt". 
-func containsDotDot(v string) bool { - if !strings.Contains(v, "..") { - return false - } - - return slices.Contains(strings.FieldsFunc(v, func(r rune) bool { - return r == '/' || r == '\\' - }), "..") -} - -// sanitizeZipPath validates and sanitizes a zip entry path to prevent ZipSlip attacks. -func sanitizeZipPath(dst, name string) (string, error) { - if containsDotDot(name) { - return "", fmt.Errorf("illegal file path in zip: %s", name) - } - - destPath := filepath.Join(dst, filepath.Clean(name)) - - if !strings.HasPrefix(destPath, filepath.Clean(dst)+string(os.PathSeparator)) { - return "", fmt.Errorf("illegal destination path in zip: %s", destPath) - } - - return destPath, nil -} - // extractZipFile extracts a single file from a zip archive. func (z *ZipDecompressor) extractZipFile(l log.Logger, fs FS, dst string, zipFile *zip.File, umask os.FileMode, totalSize *int64) error { destPath, err := sanitizeZipPath(dst, zipFile.Name) @@ -322,57 +441,6 @@ func (z *ZipDecompressor) extractZipFile(l log.Logger, fs FS, dst string, zipFil return z.extractRegularFile(l, fs, destPath, zipFile, umask, totalSize) } -// validateSymlinkTarget validates that a symlink target doesn't escape the destination directory. -func validateSymlinkTarget(dst, linkPath, target string) error { - // Resolve the target relative to the link's directory - absTarget := target - if !filepath.IsAbs(target) { - absTarget = filepath.Join(filepath.Dir(linkPath), target) - } - - absTarget = filepath.Clean(absTarget) - cleanDst := filepath.Clean(dst) - - // Ensure it stays within dst - if !strings.HasPrefix(absTarget, cleanDst+string(os.PathSeparator)) && absTarget != cleanDst { - return fmt.Errorf("symlink target escapes destination: %s -> %s", linkPath, target) - } - - return nil -} - -// extractSymlink extracts a symlink from a zip file. 
-func extractSymlink(l log.Logger, fs FS, dst, destPath string, zipFile *zip.File) error { - rc, err := zipFile.Open() - if err != nil { - return fmt.Errorf("failed to open file %q: %w", zipFile.Name, err) - } - - defer func() { - if closeErr := rc.Close(); closeErr != nil { - l.Warnf("Error closing file %q: %v", zipFile.Name, closeErr) - } - }() - - targetBytes, err := io.ReadAll(rc) - if err != nil { - return fmt.Errorf("failed to read file %q: %w", zipFile.Name, err) - } - - target := string(targetBytes) - - // Validate symlink target doesn't escape destination - if err := validateSymlinkTarget(dst, destPath, target); err != nil { - return err - } - - if err := fs.MkdirAll(filepath.Dir(destPath), os.ModePerm); err != nil { - return fmt.Errorf("failed to create directory %q: %w", filepath.Dir(destPath), err) - } - - return Symlink(fs, target, destPath) -} - // extractRegularFile extracts a regular file from a zip file. func (z *ZipDecompressor) extractRegularFile( l log.Logger, @@ -438,6 +506,17 @@ func (z *ZipDecompressor) extractRegularFile( return nil } +// FileInfoDirEntry wraps os.FileInfo to implement fs.DirEntry. +// Adapted from spf13/afero#571 — replace with afero equivalent once merged. +type FileInfoDirEntry struct { + FileInfo os.FileInfo +} + +func (d FileInfoDirEntry) Name() string { return d.FileInfo.Name() } +func (d FileInfoDirEntry) IsDir() bool { return d.FileInfo.IsDir() } +func (d FileInfoDirEntry) Type() fs.FileMode { return d.FileInfo.Mode().Type() } +func (d FileInfoDirEntry) Info() (fs.FileInfo, error) { return d.FileInfo, nil } + // limitedReader wraps a reader and enforces a size limit. type limitedReader struct { reader io.Reader @@ -459,6 +538,171 @@ func (r *limitedReader) Read(p []byte) (int, error) { return n, err } +// lstatIfPossible calls Lstat if the filesystem supports it, otherwise Stat. 
+func lstatIfPossible(fsys FS, path string) (os.FileInfo, error) { + if lstater, ok := fsys.(afero.Lstater); ok { + info, _, err := lstater.LstatIfPossible(path) + return info, err + } + + return fsys.Stat(path) +} + +// walkDir recursively descends path, calling walkDirFn. +// Adapted from https://go.dev/src/path/filepath/path.go +func walkDir(fsys FS, path string, d fs.DirEntry, walkDirFn fs.WalkDirFunc) error { + if err := walkDirFn(path, d, nil); err != nil || !d.IsDir() { + if errors.Is(err, filepath.SkipDir) && d.IsDir() { + err = nil + } + + return err + } + + entries, err := readDirEntries(fsys, path) + if err != nil { + err = walkDirFn(path, d, err) + if err != nil { + if errors.Is(err, filepath.SkipDir) && d.IsDir() { + err = nil + } + + return err + } + } + + for _, entry := range entries { + name := filepath.Join(path, entry.Name()) + if err := walkDir(fsys, name, entry, walkDirFn); err != nil { + if errors.Is(err, filepath.SkipDir) { + break + } + + return err + } + } + + return nil +} + +// readDirEntries reads the directory named by dirname and returns +// a sorted list of directory entries. +func readDirEntries(fsys FS, dirname string) ([]fs.DirEntry, error) { + f, err := fsys.Open(dirname) + if err != nil { + return nil, err + } + + defer func() { + _ = f.Close() + }() + + if rdf, ok := f.(fs.ReadDirFile); ok { + entries, err := rdf.ReadDir(-1) + if err != nil { + return nil, err + } + + sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) + + return entries, nil + } + + infos, err := f.Readdir(-1) + if err != nil { + return nil, err + } + + entries := make([]fs.DirEntry, len(infos)) + + for i, info := range infos { + entries[i] = FileInfoDirEntry{FileInfo: info} + } + + sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) + + return entries, nil +} + +// containsDotDot checks if a path contains ".." as a path component. 
+// This is more precise than strings.Contains(name, "..") which would +// reject legitimate files like "file..txt". +func containsDotDot(v string) bool { + if !strings.Contains(v, "..") { + return false + } + + return slices.Contains(strings.FieldsFunc(v, func(r rune) bool { + return r == '/' || r == '\\' + }), "..") +} + +// sanitizeZipPath validates and sanitizes a zip entry path to prevent ZipSlip attacks. +func sanitizeZipPath(dst, name string) (string, error) { + if containsDotDot(name) { + return "", fmt.Errorf("illegal file path in zip: %s", name) + } + + destPath := filepath.Join(dst, filepath.Clean(name)) + + if !strings.HasPrefix(destPath, filepath.Clean(dst)+string(os.PathSeparator)) { + return "", fmt.Errorf("illegal destination path in zip: %s", destPath) + } + + return destPath, nil +} + +// validateSymlinkTarget validates that a symlink target doesn't escape the destination directory. +func validateSymlinkTarget(dst, linkPath, target string) error { + // Resolve the target relative to the link's directory + absTarget := target + if !filepath.IsAbs(target) { + absTarget = filepath.Join(filepath.Dir(linkPath), target) + } + + absTarget = filepath.Clean(absTarget) + cleanDst := filepath.Clean(dst) + + // Ensure it stays within dst + if !strings.HasPrefix(absTarget, cleanDst+string(os.PathSeparator)) && absTarget != cleanDst { + return fmt.Errorf("symlink target escapes destination: %s -> %s", linkPath, target) + } + + return nil +} + +// extractSymlink extracts a symlink from a zip file. 
+func extractSymlink(l log.Logger, fs FS, dst, destPath string, zipFile *zip.File) error { + rc, err := zipFile.Open() + if err != nil { + return fmt.Errorf("failed to open file %q: %w", zipFile.Name, err) + } + + defer func() { + if closeErr := rc.Close(); closeErr != nil { + l.Warnf("Error closing file %q: %v", zipFile.Name, closeErr) + } + }() + + targetBytes, err := io.ReadAll(rc) + if err != nil { + return fmt.Errorf("failed to read file %q: %w", zipFile.Name, err) + } + + target := string(targetBytes) + + // Validate symlink target doesn't escape destination + if err := validateSymlinkTarget(dst, destPath, target); err != nil { + return err + } + + if err := fs.MkdirAll(filepath.Dir(destPath), os.ModePerm); err != nil { + return fmt.Errorf("failed to create directory %q: %w", filepath.Dir(destPath), err) + } + + return Symlink(fs, target, destPath) +} + // applyUmask applies a umask to a file mode. func applyUmask(mode, umask os.FileMode) os.FileMode { return mode &^ umask diff --git a/internal/vfs/vfs_test.go b/internal/vfs/vfs_test.go index 9676ffaaf5..a714bbe030 100644 --- a/internal/vfs/vfs_test.go +++ b/internal/vfs/vfs_test.go @@ -3,6 +3,7 @@ package vfs_test import ( "archive/zip" "bytes" + "io/fs" "os" "path/filepath" "testing" @@ -452,143 +453,6 @@ func TestUnzipWithSymlinks(t *testing.T) { assert.Equal(t, []byte("target content"), linkData) } -// createZipArchive creates a zip archive in memory with the given files. -func createZipArchive(t *testing.T, files map[string][]byte) []byte { - t.Helper() - - var buf bytes.Buffer - - w := zip.NewWriter(&buf) - - for name, content := range files { - f, err := w.Create(name) - require.NoError(t, err) - - _, err = f.Write(content) - require.NoError(t, err) - } - - require.NoError(t, w.Close()) - - return buf.Bytes() -} - -// createZipArchiveWithDirs creates a zip archive that includes directory entries. 
-func createZipArchiveWithDirs(t *testing.T, files map[string][]byte) []byte { - t.Helper() - - var buf bytes.Buffer - - w := zip.NewWriter(&buf) - - for name, content := range files { - if content == nil { - _, err := w.Create(name) - require.NoError(t, err) - - continue - } - - f, err := w.Create(name) - require.NoError(t, err) - - _, err = f.Write(content) - require.NoError(t, err) - } - - require.NoError(t, w.Close()) - - return buf.Bytes() -} - -// createZipArchiveUnsafe creates a zip archive with potentially malicious paths (for testing ZipSlip). -func createZipArchiveUnsafe(t *testing.T, files map[string][]byte) []byte { - t.Helper() - - var buf bytes.Buffer - - w := zip.NewWriter(&buf) - - for name, content := range files { - header := &zip.FileHeader{ - Name: name, - Method: zip.Deflate, - } - - f, err := w.CreateHeader(header) - require.NoError(t, err) - - _, err = f.Write(content) - require.NoError(t, err) - } - - require.NoError(t, w.Close()) - - return buf.Bytes() -} - -// createZipArchiveWithMode creates a zip archive with a single file with specific permissions. -func createZipArchiveWithMode(t *testing.T, name string, content []byte, mode os.FileMode) []byte { - t.Helper() - - var buf bytes.Buffer - - w := zip.NewWriter(&buf) - - header := &zip.FileHeader{ - Name: name, - Method: zip.Deflate, - } - header.SetMode(mode) - - f, err := w.CreateHeader(header) - require.NoError(t, err) - - _, err = f.Write(content) - require.NoError(t, err) - - require.NoError(t, w.Close()) - - return buf.Bytes() -} - -// createZipArchiveWithSymlink creates a zip archive with a regular file and a symlink to it. 
-func createZipArchiveWithSymlink(t *testing.T, targetName string, targetContent []byte, linkName, linkTarget string) []byte { - t.Helper() - - var buf bytes.Buffer - - w := zip.NewWriter(&buf) - - targetHeader := &zip.FileHeader{ - Name: targetName, - Method: zip.Deflate, - } - targetHeader.SetMode(0644) - - f, err := w.CreateHeader(targetHeader) - require.NoError(t, err) - - _, err = f.Write(targetContent) - require.NoError(t, err) - - linkHeader := &zip.FileHeader{ - Name: linkName, - Method: zip.Deflate, - } - linkHeader.SetMode(os.ModeSymlink | 0777) - - linkFile, err := w.CreateHeader(linkHeader) - require.NoError(t, err) - - _, err = linkFile.Write([]byte(linkTarget)) - require.NoError(t, err) - - require.NoError(t, w.Close()) - - return buf.Bytes() -} - func TestContainsDotDot(t *testing.T) { t.Parallel() @@ -870,6 +734,367 @@ func TestUnzipSymlinkEscape(t *testing.T) { }) } +func TestWalkDir(t *testing.T) { + t.Parallel() + + t.Run("walks files in lexical order", func(t *testing.T) { + t.Parallel() + + memFs := vfs.NewMemMapFS() + require.NoError(t, vfs.WriteFile(memFs, "/root/b.txt", []byte("b"), 0644)) + require.NoError(t, vfs.WriteFile(memFs, "/root/a.txt", []byte("a"), 0644)) + require.NoError(t, vfs.WriteFile(memFs, "/root/c.txt", []byte("c"), 0644)) + + var names []string + + err := vfs.WalkDir(memFs, "/root", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + names = append(names, d.Name()) + + return nil + }) + + require.NoError(t, err) + assert.Equal(t, []string{"root", "a.txt", "b.txt", "c.txt"}, names) + }) + + t.Run("walks nested directories", func(t *testing.T) { + t.Parallel() + + memFs := vfs.NewMemMapFS() + require.NoError(t, vfs.WriteFile(memFs, "/root/dir/nested.txt", []byte("n"), 0644)) + require.NoError(t, vfs.WriteFile(memFs, "/root/top.txt", []byte("t"), 0644)) + + var paths []string + + err := vfs.WalkDir(memFs, "/root", func(path string, d fs.DirEntry, err error) error { + if err != nil { + 
return err + } + + paths = append(paths, path) + + return nil + }) + + require.NoError(t, err) + assert.Equal(t, []string{"/root", "/root/dir", "/root/dir/nested.txt", "/root/top.txt"}, paths) + }) + + t.Run("empty directory", func(t *testing.T) { + t.Parallel() + + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/empty", 0755)) + + var paths []string + + err := vfs.WalkDir(memFs, "/empty", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + paths = append(paths, path) + + return nil + }) + + require.NoError(t, err) + assert.Equal(t, []string{"/empty"}, paths) + }) + + t.Run("SkipDir skips directory", func(t *testing.T) { + t.Parallel() + + memFs := vfs.NewMemMapFS() + require.NoError(t, vfs.WriteFile(memFs, "/root/skip/hidden.txt", []byte("h"), 0644)) + require.NoError(t, vfs.WriteFile(memFs, "/root/keep/visible.txt", []byte("v"), 0644)) + + var paths []string + + err := vfs.WalkDir(memFs, "/root", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if d.IsDir() && d.Name() == "skip" { + return filepath.SkipDir + } + + paths = append(paths, path) + + return nil + }) + + require.NoError(t, err) + assert.Equal(t, []string{"/root", "/root/keep", "/root/keep/visible.txt"}, paths) + }) + + t.Run("nonexistent root returns error to callback", func(t *testing.T) { + t.Parallel() + + memFs := vfs.NewMemMapFS() + + var callbackErr error + + err := vfs.WalkDir(memFs, "/nonexistent", func(path string, d fs.DirEntry, err error) error { + callbackErr = err + return err + }) + + require.Error(t, err) + require.Error(t, callbackErr) + }) + + t.Run("DirEntry reports correct types", func(t *testing.T) { + t.Parallel() + + memFs := vfs.NewMemMapFS() + require.NoError(t, memFs.MkdirAll("/root/subdir", 0755)) + require.NoError(t, vfs.WriteFile(memFs, "/root/file.txt", []byte("f"), 0644)) + + var dirs, files int + + err := vfs.WalkDir(memFs, "/root", func(path string, d fs.DirEntry, err error) 
error { + if err != nil { + return err + } + + if d.IsDir() { + dirs++ + } else { + files++ + } + + return nil + }) + + require.NoError(t, err) + assert.Equal(t, 2, dirs) // /root, /root/subdir + assert.Equal(t, 1, files) // /root/file.txt + }) +} + +func TestWalkDir_OSFS(t *testing.T) { + t.Parallel() + + osFs := vfs.NewOSFS() + root := t.TempDir() + + require.NoError(t, os.MkdirAll(filepath.Join(root, "sub"), 0755)) + require.NoError(t, os.WriteFile(filepath.Join(root, "a.txt"), []byte("a"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(root, "sub", "b.txt"), []byte("b"), 0644)) + + var paths []string + + err := vfs.WalkDir(osFs, root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + rel, relErr := filepath.Rel(root, path) + require.NoError(t, relErr) + + paths = append(paths, rel) + + return nil + }) + + require.NoError(t, err) + assert.Equal(t, []string{".", "a.txt", "sub", filepath.Join("sub", "b.txt")}, paths) +} + +func TestLock(t *testing.T) { + t.Parallel() + + t.Run("memMapFS lock and unlock", func(t *testing.T) { + t.Parallel() + + memFs := vfs.NewMemMapFS() + + lock, err := vfs.Lock(memFs, "test.lock") + require.NoError(t, err) + require.NotNil(t, lock) + + // TryLock should fail while held + _, acquired, err := vfs.TryLock(memFs, "test.lock") + require.NoError(t, err) + assert.False(t, acquired) + + require.NoError(t, lock.Unlock()) + + // TryLock should succeed after unlock + lock2, acquired, err := vfs.TryLock(memFs, "test.lock") + require.NoError(t, err) + assert.True(t, acquired) + require.NoError(t, lock2.Unlock()) + }) + + t.Run("osFS lock and unlock", func(t *testing.T) { + t.Parallel() + + osFs := vfs.NewOSFS() + lockPath := filepath.Join(t.TempDir(), "test.lock") + + lock, err := vfs.Lock(osFs, lockPath) + require.NoError(t, err) + require.NotNil(t, lock) + require.NoError(t, lock.Unlock()) + }) + + t.Run("unsupported filesystem returns error", func(t *testing.T) { + t.Parallel() + + 
readOnlyFs := afero.NewReadOnlyFs(vfs.NewMemMapFS()) + + _, err := vfs.Lock(readOnlyFs, "test.lock") + require.ErrorIs(t, err, vfs.ErrNoLock) + + _, _, err = vfs.TryLock(readOnlyFs, "test.lock") + require.ErrorIs(t, err, vfs.ErrNoLock) + }) +} + +// createZipArchive creates a zip archive in memory with the given files. +func createZipArchive(t *testing.T, files map[string][]byte) []byte { + t.Helper() + + var buf bytes.Buffer + + w := zip.NewWriter(&buf) + + for name, content := range files { + f, err := w.Create(name) + require.NoError(t, err) + + _, err = f.Write(content) + require.NoError(t, err) + } + + require.NoError(t, w.Close()) + + return buf.Bytes() +} + +// createZipArchiveWithDirs creates a zip archive that includes directory entries. +func createZipArchiveWithDirs(t *testing.T, files map[string][]byte) []byte { + t.Helper() + + var buf bytes.Buffer + + w := zip.NewWriter(&buf) + + for name, content := range files { + if content == nil { + _, err := w.Create(name) + require.NoError(t, err) + + continue + } + + f, err := w.Create(name) + require.NoError(t, err) + + _, err = f.Write(content) + require.NoError(t, err) + } + + require.NoError(t, w.Close()) + + return buf.Bytes() +} + +// createZipArchiveUnsafe creates a zip archive with potentially malicious paths (for testing ZipSlip). +func createZipArchiveUnsafe(t *testing.T, files map[string][]byte) []byte { + t.Helper() + + var buf bytes.Buffer + + w := zip.NewWriter(&buf) + + for name, content := range files { + header := &zip.FileHeader{ + Name: name, + Method: zip.Deflate, + } + + f, err := w.CreateHeader(header) + require.NoError(t, err) + + _, err = f.Write(content) + require.NoError(t, err) + } + + require.NoError(t, w.Close()) + + return buf.Bytes() +} + +// createZipArchiveWithMode creates a zip archive with a single file with specific permissions. 
+func createZipArchiveWithMode(t *testing.T, name string, content []byte, mode os.FileMode) []byte { + t.Helper() + + var buf bytes.Buffer + + w := zip.NewWriter(&buf) + + header := &zip.FileHeader{ + Name: name, + Method: zip.Deflate, + } + header.SetMode(mode) + + f, err := w.CreateHeader(header) + require.NoError(t, err) + + _, err = f.Write(content) + require.NoError(t, err) + + require.NoError(t, w.Close()) + + return buf.Bytes() +} + +// createZipArchiveWithSymlink creates a zip archive with a regular file and a symlink to it. +func createZipArchiveWithSymlink(t *testing.T, targetName string, targetContent []byte, linkName, linkTarget string) []byte { + t.Helper() + + var buf bytes.Buffer + + w := zip.NewWriter(&buf) + + targetHeader := &zip.FileHeader{ + Name: targetName, + Method: zip.Deflate, + } + targetHeader.SetMode(0644) + + f, err := w.CreateHeader(targetHeader) + require.NoError(t, err) + + _, err = f.Write(targetContent) + require.NoError(t, err) + + linkHeader := &zip.FileHeader{ + Name: linkName, + Method: zip.Deflate, + } + linkHeader.SetMode(os.ModeSymlink | 0777) + + linkFile, err := w.CreateHeader(linkHeader) + require.NoError(t, err) + + _, err = linkFile.Write([]byte(linkTarget)) + require.NoError(t, err) + + require.NoError(t, w.Close()) + + return buf.Bytes() +} + // createZipArchiveWithNestedSymlink creates a zip with a symlink in a subdirectory. func createZipArchiveWithNestedSymlink(t *testing.T) []byte { t.Helper()