Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import (
)

var (
ErrNotFound = errors.New("file not found")
ErrNotFound = errors.New("file not found")
ErrInvalidStorageKey = errors.New("invalid storage key")
)

type StorageOptions struct {
Expand Down
59 changes: 53 additions & 6 deletions internal/storage/storage_fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,29 @@ func NewFSStorage(root string) (storage Storage, err error) {
return &fsStorage{root: root}, nil
}

// joinRootSafe returns filepath.Join(fs.root, key) only when key is
// filepath.IsLocal, so lexical ".." segments cannot escape the storage root.
func (fs *fsStorage) joinRootSafe(key string) (filename string, err error) {
if key == "" {
return "", errors.New("key is required")
}
if strings.Contains(key, "\x00") {
return "", ErrInvalidStorageKey
}
k := filepath.FromSlash(strings.Trim(filepath.ToSlash(key), "/"))
if !filepath.IsLocal(k) {
return "", ErrInvalidStorageKey
}
root := filepath.Clean(fs.root)
full := filepath.Join(root, k)
return full, nil
}

func (fs *fsStorage) Stat(key string) (stat Stat, err error) {
filename := filepath.Join(fs.root, key)
filename, err := fs.joinRootSafe(key)
if err != nil {
return nil, err
}
fi, err := os.Lstat(filename)
if err != nil {
if os.IsNotExist(err) || strings.HasSuffix(err.Error(), "not a directory") {
Expand All @@ -43,7 +64,10 @@ func (fs *fsStorage) Stat(key string) (stat Stat, err error) {
}

func (fs *fsStorage) Get(key string) (content io.ReadCloser, stat Stat, err error) {
filename := filepath.Join(fs.root, key)
filename, err := fs.joinRootSafe(key)
if err != nil {
return
}
file, err := os.Open(filename)
if err != nil && (os.IsNotExist(err) || strings.HasSuffix(err.Error(), "not a directory")) {
err = ErrNotFound
Expand All @@ -57,7 +81,10 @@ func (fs *fsStorage) Get(key string) (content io.ReadCloser, stat Stat, err erro
}

func (fs *fsStorage) Put(key string, content io.Reader) (err error) {
filename := filepath.Join(fs.root, key)
filename, err := fs.joinRootSafe(key)
if err != nil {
return err
}
err = ensureDir(filepath.Dir(filename))
if err != nil {
return
Expand All @@ -77,24 +104,44 @@ func (fs *fsStorage) Put(key string, content io.Reader) (err error) {
}

func (fs *fsStorage) Delete(key string) (err error) {
return os.Remove(filepath.Join(fs.root, key))
filename, err := fs.joinRootSafe(key)
if err != nil {
return err
}
return os.Remove(filename)
}

func (fs *fsStorage) List(prefix string) (keys []string, err error) {
dir := strings.TrimSuffix(utils.NormalizePathname(prefix)[1:], "/")
return findFiles(filepath.Join(fs.root, dir), dir)
dir = strings.Trim(strings.TrimSpace(dir), "/")
scanRoot := filepath.Clean(fs.root)
parentKey := ""
if dir != "" {
var ferr error
scanRoot, ferr = fs.joinRootSafe(dir)
if ferr != nil {
return nil, ferr
Comment on lines +121 to +123
}
parentKey = dir
}
return findFiles(scanRoot, parentKey)
}

func (fs *fsStorage) DeleteAll(prefix string) (deletedKeys []string, err error) {
dir := strings.TrimSuffix(utils.NormalizePathname(prefix)[1:], "/")
dir = strings.Trim(strings.TrimSpace(dir), "/")
if dir == "" {
return nil, errors.New("prefix is required")
}
keys, err := fs.List(prefix)
if err != nil {
return
}
err = os.RemoveAll(filepath.Join(fs.root, dir))
absDir, err := fs.joinRootSafe(dir)
if err != nil {
return nil, err
Comment on lines +140 to +142
}
err = os.RemoveAll(absDir)
if err != nil {
return
}
Expand Down
46 changes: 46 additions & 0 deletions internal/storage/storage_fs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,49 @@ func TestFSStorage(t *testing.T) {
t.Fatalf("invalid keys count(%d), shoud be 0", len(keys))
}
}

func TestFSStorageRejectPathTraversal(t *testing.T) {
root := path.Join(os.TempDir(), "storage_traversal_"+rand.Hex.String(8))
fs, err := NewFSStorage(root)
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(root)

attackKeys := []string{
"../outside.txt",
"legacy/../../../tmp/pwned",
`legacy/v111/react@19.2.0/esnext/../../../gh/a/exp@cafe/foo.md#/../../../../../../../../../../tmp/pwned`,
"safe/../../etc/passwd",
"bad\x00surprise",
}
for _, k := range attackKeys {
err = fs.Put(k, bytes.NewBufferString("evil"))
if err != ErrInvalidStorageKey {
t.Fatalf("Put(%q): want ErrInvalidStorageKey, got %v", k, err)
}
if _, err = fs.Stat(k); err != ErrInvalidStorageKey {
t.Fatalf("Stat(%q): want ErrInvalidStorageKey, got %v", k, err)
}
if _, _, err = fs.Get(k); err != ErrInvalidStorageKey {
t.Fatalf("Get(%q): want ErrInvalidStorageKey, got %v", k, err)
}
if err = fs.Delete(k); err != ErrInvalidStorageKey {
t.Fatalf("Delete(%q): want ErrInvalidStorageKey, got %v", k, err)
}
}

err = fs.Put("ok/sub/file.txt", bytes.NewBufferString("hi"))
if err != nil {
t.Fatal(err)
}
got, _, err := fs.Get("ok/sub/file.txt")
if err != nil {
t.Fatal(err)
}
defer got.Close()
b, err := io.ReadAll(got)
if err != nil || string(b) != "hi" {
t.Fatalf("Get ok path: content %q err %v", string(b), err)
}
}
Loading