Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"net/http/httputil"
Expand Down Expand Up @@ -175,6 +176,23 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa
// Default: nil
Views Views `json:"-"`

// RootDir specifies the base directory for SaveFile/SaveFileToStorage uploads.
// Relative paths are resolved against this directory.
//
// Optional. Default: ""
RootDir string `json:"root_dir"`

// RootPerms specifies the permissions used when creating RootDir or RootFs prefixes.
//
// Optional. Default: 0o750
RootPerms fs.FileMode `json:"root_perms"`

// RootFs specifies the filesystem used for SaveFile/SaveFileToStorage uploads.
// When set, RootDir is treated as a relative prefix within the filesystem.
//
// Optional. Default: nil
RootFs fs.FS `json:"-"`

// Views Layout is the global layout for all template render until override on Render function.
//
// Default: ""
Expand Down Expand Up @@ -437,6 +455,15 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa
//
// Optional. Default: a provider that returns context.Background()
ServicesShutdownContextProvider func() context.Context

uploadRootDir string
uploadRootEval string
uploadRootPath string
uploadRootFSPrefix string
uploadRootFSWriter interface {
fs.FS
OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error)
}
}

// Default TrustProxyConfig
Expand Down Expand Up @@ -605,6 +632,9 @@ func New(config ...Config) *App {
"zstd": ".fiber.zst",
}
}
if app.config.RootPerms == 0 {
app.config.RootPerms = 0o750
}

if app.config.Immutable {
app.toBytes, app.toString = toBytesImmutable, toStringImmutable
Expand Down Expand Up @@ -642,6 +672,8 @@ func New(config ...Config) *App {
app.config.RequestMethods = DefaultMethods
}

app.configureUploads()

app.config.TrustProxyConfig.ips = make(map[string]struct{}, len(app.config.TrustProxyConfig.Proxies))
for _, ipAddress := range app.config.TrustProxyConfig.Proxies {
app.handleTrustedProxy(ipAddress)
Expand Down
144 changes: 144 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"mime/multipart"
"net"
"net/http"
Expand Down Expand Up @@ -81,6 +82,149 @@ func testErrorResponse(t *testing.T, err error, resp *http.Response, expectedBod
require.Equal(t, expectedBodyError, string(body), "Response body")
}

type testUploadFS struct {
mkdirPath string
mkdirPerm fs.FileMode
}

func (tfs *testUploadFS) Open(_ string) (fs.File, error) {
_ = tfs
return nil, fs.ErrNotExist
}

func (tfs *testUploadFS) OpenFile(_ string, _ int, _ fs.FileMode) (fs.File, error) {
_ = tfs
return &testUploadFile{buf: &bytes.Buffer{}}, nil
}

func (tfs *testUploadFS) MkdirAll(path string, perm fs.FileMode) error {
tfs.mkdirPath = path
tfs.mkdirPerm = perm
return nil
}

func (tfs *testUploadFS) Remove(_ string) error {
_ = tfs
return nil
}

type testUploadFile struct {
buf *bytes.Buffer
}

func (tf *testUploadFile) Read(p []byte) (int, error) {
//nolint:wrapcheck // test helper passthrough
return tf.buf.Read(p)
}

func (tf *testUploadFile) Write(p []byte) (int, error) {
//nolint:wrapcheck // test helper passthrough
return tf.buf.Write(p)
}

func (tf *testUploadFile) Close() error {
_ = tf
return nil
}

func (tf *testUploadFile) Stat() (fs.FileInfo, error) {
_ = tf
return testUploadFileInfo{name: "upload"}, nil
}

type testUploadFileInfo struct {
name string
}

func (fi testUploadFileInfo) Name() string { return fi.name }
func (fi testUploadFileInfo) Size() int64 {
_ = fi
return 0
}

func (fi testUploadFileInfo) Mode() fs.FileMode {
_ = fi
return 0
}

func (fi testUploadFileInfo) ModTime() time.Time {
_ = fi
return time.Time{}
}

func (fi testUploadFileInfo) IsDir() bool {
_ = fi
return false
}

func (fi testUploadFileInfo) Sys() any {
_ = fi
return nil
}

func TestRootPermsRootFs(t *testing.T) {
t.Parallel()

if runtime.GOOS == "windows" {
t.Skip("root perms are not validated on Windows in this test")
}

tests := []struct {
name string
rootPerm fs.FileMode
wantPerm fs.FileMode
}{
{
name: "default",
rootPerm: 0,
wantPerm: 0o750,
},
{
name: "custom",
rootPerm: 0o700,
wantPerm: 0o700,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

tfs := &testUploadFS{}
New(Config{
RootDir: "uploads",
RootFs: tfs,
RootPerms: tt.rootPerm,
})

if tfs.mkdirPath != "uploads" {
t.Fatalf("expected RootFs prefix %q, got %q", "uploads", tfs.mkdirPath)
}
if tfs.mkdirPerm != tt.wantPerm {
t.Fatalf("expected RootPerms %o, got %o", tt.wantPerm, tfs.mkdirPerm)
}
})
}
}

func TestValidateUploadPathPreservesLeadingDot(t *testing.T) {
t.Parallel()

path := filepath.Join(".hidden", "file.txt")

normalized, err := validateUploadPath(path)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}

if !strings.HasPrefix(normalized.osPath, ".") {
t.Fatalf("expected os path %q to preserve leading dot", normalized.osPath)
}
if normalized.slashPath != ".hidden/file.txt" {
t.Fatalf("expected slash path %q, got %q", ".hidden/file.txt", normalized.slashPath)
}
}

func Test_App_Test_Goroutine_Leak_Compare(t *testing.T) {
t.Parallel()

Expand Down
94 changes: 91 additions & 3 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ package fiber
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"io/fs"
"maps"
"mime/multipart"
"os"
"path/filepath"
"strconv"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -511,12 +515,57 @@ func (c *DefaultCtx) IsPreflight() bool {
}

// SaveFile saves any multipart file to disk.
func (*DefaultCtx) SaveFile(fileheader *multipart.FileHeader, path string) error {
return fasthttp.SaveMultipartFile(fileheader, path)
func (c *DefaultCtx) SaveFile(fileheader *multipart.FileHeader, path string) error {
normalized, err := validateUploadPath(path)
if err != nil {
return err
}

if c.app.config.RootFs != nil {
fsPath := storageUploadPath(c.app.config.uploadRootFSPrefix, normalized.slashPath)
err = ensureNoSymlinkFS(c.app.config.RootFs, fsPath)
if err != nil {
return err
}
return saveMultipartFileToFS(fileheader, fsPath, c.app.config.uploadRootFSWriter)
}

fullPath := normalized.osPath
if root := c.app.config.uploadRootDir; root != "" {
fullPath = filepath.Join(root, normalized.osPath)
err = ensureUploadPathWithinRoot(c.app.config.uploadRootEval, fullPath)
if err != nil {
return err
}
}

return fasthttp.SaveMultipartFile(fileheader, fullPath)
}

// SaveFileToStorage saves any multipart file to an external storage system.
func (c *DefaultCtx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
normalized, err := validateUploadPath(path)
if err != nil {
return err
}

if c.app.config.RootFs != nil {
fsPath := storageUploadPath(c.app.config.uploadRootFSPrefix, normalized.slashPath)
err = ensureNoSymlinkFS(c.app.config.RootFs, fsPath)
if err != nil {
return err
}
}
if root := c.app.config.uploadRootDir; root != "" {
fullPath := filepath.Join(root, filepath.FromSlash(normalized.slashPath))
err = ensureUploadPathWithinRoot(c.app.config.uploadRootEval, fullPath)
if err != nil {
return err
}
}

storagePath := storageUploadPath(c.app.config.uploadRootPath, normalized.slashPath)

file, err := fileheader.Open()
if err != nil {
return fmt.Errorf("failed to open: %w", err)
Expand Down Expand Up @@ -546,13 +595,52 @@ func (c *DefaultCtx) SaveFileToStorage(fileheader *multipart.FileHeader, path st

data := append([]byte(nil), buf.Bytes()...)

if err := storage.SetWithContext(c.Context(), path, data, 0); err != nil {
if err := storage.SetWithContext(c.Context(), storagePath, data, 0); err != nil {
return fmt.Errorf("failed to store: %w", err)
}

return nil
}

func saveMultipartFileToFS(
fileheader *multipart.FileHeader,
path string,
fsys interface {
fs.FS
OpenFile(name string, flag int, perm fs.FileMode) (fs.File, error)
},
) error {
file, err := fileheader.Open()
if err != nil {
return fmt.Errorf("failed to open multipart file: %w", err)
}
defer file.Close() //nolint:errcheck // not needed

dst, err := fsys.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
return fmt.Errorf("failed to open upload destination: %w", err)
}
writer, ok := dst.(io.Writer)
if !ok {
closeErr := dst.Close()
if closeErr != nil {
return fmt.Errorf("failed to close upload destination: %w", closeErr)
}
return errors.New("failed to open upload destination for write")
}
if _, err = io.Copy(writer, file); err != nil {
closeErr := dst.Close()
if closeErr != nil {
return fmt.Errorf("failed to close upload destination: %w", closeErr)
}
return fmt.Errorf("failed to copy upload data: %w", err)
}
Comment on lines +631 to +637
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Original copy error is lost when close also fails.

When io.Copy fails and dst.Close() also fails, only the close error is returned. The original copy error provides more diagnostic value. Consider using errors.Join or wrapping both.

Proposed fix to preserve both errors
 	if _, err = io.Copy(writer, file); err != nil {
 		closeErr := dst.Close()
 		if closeErr != nil {
-			return fmt.Errorf("failed to close upload destination: %w", closeErr)
+			return fmt.Errorf("failed to copy upload data: %w (also failed to close: %v)", err, closeErr)
 		}
 		return fmt.Errorf("failed to copy upload data: %w", err)
 	}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if _, err = io.Copy(writer, file); err != nil {
closeErr := dst.Close()
if closeErr != nil {
return fmt.Errorf("failed to close upload destination: %w", closeErr)
}
return fmt.Errorf("failed to copy upload data: %w", err)
}
if _, err = io.Copy(writer, file); err != nil {
closeErr := dst.Close()
if closeErr != nil {
return fmt.Errorf("failed to copy upload data: %w (also failed to close: %v)", err, closeErr)
}
return fmt.Errorf("failed to copy upload data: %w", err)
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ctx.go` around lines 631 - 637, The current error handling in the io.Copy
failure path discards the original copy error when dst.Close() also fails;
update the block that handles io.Copy(writer, file) in ctx.go to preserve both
errors by capturing the copy error (err) and the close error (closeErr) and
returning a combined error (e.g., using errors.Join(copyErr, closeErr) or
wrapping them together) instead of returning only closeErr; ensure references to
writer, file, dst.Close(), and io.Copy remain and that the returned error
message includes both failure contexts.

if err := dst.Close(); err != nil {
return fmt.Errorf("failed to close upload destination: %w", err)
}
return nil
}

// Secure returns whether a secure connection was established.
func (c *DefaultCtx) Secure() bool {
return c.Protocol() == schemeHTTPS
Expand Down
Loading
Loading