Skip to content

Commit 3b82d11

Browse files
authored
fix: path traversal when uploading files (#220)
1 parent 3b187e1 commit 3b82d11

File tree

2 files changed

+111
-1
lines changed

2 files changed

+111
-1
lines changed

skynet/modules/ttt/rag/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,17 @@ async def save_files(folder: str, files: list[UploadFile]) -> list[str]:
3131
os.makedirs(folder, exist_ok=True)
3232

3333
for file in files:
34-
file_path = f'{folder}/{file.filename}'
34+
if not file.filename:
35+
raise ValueError("File must have a filename")
36+
37+
# Construct the file path and validate it stays within the intended folder
38+
file_path = os.path.join(folder, file.filename)
39+
resolved_path = os.path.abspath(os.path.realpath(file_path))
40+
resolved_folder = os.path.abspath(os.path.realpath(folder))
41+
42+
if not resolved_path.startswith(resolved_folder + os.sep):
43+
raise ValueError(f"Invalid file path: {file.filename}")
44+
3545
async with aiofiles.open(file_path, 'wb') as f:
3646
await f.write(await file.read())
3747
file_paths.append(file_path)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import os
2+
import tempfile
3+
from unittest.mock import AsyncMock, MagicMock
4+
5+
import pytest
6+
from fastapi import UploadFile
7+
8+
from skynet.modules.ttt.rag.utils import save_files
9+
10+
11+
class TestSaveFiles:
12+
@pytest.fixture
13+
def temp_dir(self):
14+
"""Create a temporary directory for testing"""
15+
with tempfile.TemporaryDirectory() as temp_dir:
16+
yield temp_dir
17+
18+
@pytest.fixture
19+
def mock_upload_file(self):
20+
"""Create a mock UploadFile"""
21+
file = MagicMock(spec=UploadFile)
22+
file.filename = "test.txt"
23+
file.read = AsyncMock(return_value=b"test content")
24+
return file
25+
26+
@pytest.mark.asyncio
27+
async def test_save_files_empty_list(self, temp_dir):
28+
"""Test saving empty file list"""
29+
result = await save_files(temp_dir, [])
30+
assert result == []
31+
32+
@pytest.mark.asyncio
33+
async def test_save_files_valid_file(self, temp_dir, mock_upload_file):
34+
"""Test saving a valid file"""
35+
result = await save_files(temp_dir, [mock_upload_file])
36+
37+
assert len(result) == 1
38+
assert result[0] == os.path.join(temp_dir, "test.txt")
39+
assert os.path.exists(result[0])
40+
41+
with open(result[0], 'rb') as f:
42+
assert f.read() == b"test content"
43+
44+
@pytest.mark.asyncio
45+
async def test_save_files_path_traversal_attack(self, temp_dir):
46+
"""Test prevention of path traversal attacks"""
47+
malicious_file = MagicMock(spec=UploadFile)
48+
malicious_file.filename = "../../../etc/passwd"
49+
malicious_file.read = AsyncMock(return_value=b"malicious content")
50+
51+
# Path traversal should be rejected
52+
with pytest.raises(ValueError, match="Invalid file path"):
53+
await save_files(temp_dir, [malicious_file])
54+
55+
@pytest.mark.asyncio
56+
async def test_save_files_no_filename(self, temp_dir):
57+
"""Test handling of file without filename"""
58+
file_without_name = MagicMock(spec=UploadFile)
59+
file_without_name.filename = None
60+
file_without_name.read = AsyncMock(return_value=b"content")
61+
62+
with pytest.raises(ValueError, match="File must have a filename"):
63+
await save_files(temp_dir, [file_without_name])
64+
65+
@pytest.mark.asyncio
66+
async def test_save_files_multiple_files(self, temp_dir):
67+
"""Test saving multiple files"""
68+
files = []
69+
for i in range(3):
70+
file = MagicMock(spec=UploadFile)
71+
file.filename = f"file_{i}.txt"
72+
file.read = AsyncMock(return_value=f"content {i}".encode())
73+
files.append(file)
74+
75+
result = await save_files(temp_dir, files)
76+
77+
assert len(result) == 3
78+
for i, file_path in enumerate(result):
79+
assert file_path == os.path.join(temp_dir, f"file_{i}.txt")
80+
assert os.path.exists(file_path)
81+
82+
with open(file_path, 'rb') as f:
83+
assert f.read() == f"content {i}".encode()
84+
85+
@pytest.mark.asyncio
86+
async def test_save_files_creates_directory(self):
87+
"""Test that save_files creates the directory if it doesn't exist"""
88+
with tempfile.TemporaryDirectory() as temp_dir:
89+
new_dir = os.path.join(temp_dir, "new_folder")
90+
assert not os.path.exists(new_dir)
91+
92+
mock_file = MagicMock(spec=UploadFile)
93+
mock_file.filename = "test.txt"
94+
mock_file.read = AsyncMock(return_value=b"test content")
95+
96+
result = await save_files(new_dir, [mock_file])
97+
98+
assert os.path.exists(new_dir)
99+
assert len(result) == 1
100+
assert os.path.exists(result[0])

0 commit comments

Comments
 (0)