Skip to content

Commit e99a726

Browse files
authored
Merge pull request #3 from muxi-ai/feat/comprehensive-input-validation
feat: add comprehensive input validation for API parameters
2 parents 824c7ea + 1e307d9 commit e99a726

File tree

3 files changed

+467
-2
lines changed

3 files changed

+467
-2
lines changed

onellm/files.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ def upload(
113113
max_size=max_size,
114114
name="file data"
115115
)
116-
filename = getattr(file, "name", kwargs.pop("filename", "file.bin"))
116+
# Check for file.name first, only pop from kwargs if needed
117+
filename = getattr(file, "name", None)
118+
if not filename:
119+
filename = kwargs.pop("filename", "file.bin")
117120

118121
else:
119122
raise InvalidRequestError(
@@ -205,7 +208,10 @@ async def aupload(
205208
max_size=max_size,
206209
name="file data"
207210
)
208-
filename = getattr(file, "name", kwargs.pop("filename", "file.bin"))
211+
# Check for file.name first, only pop from kwargs if needed
212+
filename = getattr(file, "name", None)
213+
if not filename:
214+
filename = kwargs.pop("filename", "file.bin")
209215

210216
else:
211217
raise InvalidRequestError(

onellm/utils/file_validator.py

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Unified interface for LLM providers using OpenAI format
5+
# https://github.com/muxi-ai/onellm
6+
#
7+
# Copyright (C) 2025 Ran Aroussi
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
"""
22+
File validation utilities for OneLLM.
23+
24+
This module provides security-focused file validation to prevent common attacks
25+
like directory traversal, and to enforce size and type constraints.
26+
"""
27+
28+
import mimetypes
29+
from pathlib import Path
30+
from typing import Optional, Set
31+
32+
from ..errors import InvalidRequestError
33+
34+
# Default maximum file size: 100MB
DEFAULT_MAX_FILE_SIZE = 100 * 1024 * 1024

# Default allowed file extensions.
# These are common file types used with LLM APIs. Every entry is lowercase
# and includes the leading dot, because validation compares against
# ``Path.suffix.lower()``.
DEFAULT_ALLOWED_EXTENSIONS: Set[str] = {
    # Audio formats (for transcription, translation)
    ".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm", ".ogg", ".flac",
    # Image formats (for vision models)
    ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".tiff", ".tif",
    # Document formats
    ".pdf", ".txt", ".json", ".jsonl", ".csv", ".tsv",
    ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
    # Code and data formats
    ".py", ".js", ".ts", ".java", ".cpp", ".c", ".h",
    ".xml", ".yaml", ".yml", ".toml", ".ini",
    # Archive formats
    ".zip", ".tar", ".gz", ".bz2", ".7z",
    # Video formats (for future support); ".mp4" is already listed with the
    # audio formats above, so it is not repeated here.
    ".avi", ".mov", ".mkv", ".wmv", ".flv",
}
55+
56+
57+
class FileValidator:
    """
    Validates file paths and contents for security and compliance.

    This class provides static methods to:
    - Resolve and validate file paths (existence, regular-file check)
    - Enforce file size limits to prevent DoS attacks
    - Restrict uploads to an allow-list of file extensions
    - Read file contents with a size cap
    """

    @staticmethod
    def validate_file_path(
        file_path: str,
        max_size: Optional[int] = None,
        allowed_extensions: Optional[Set[str]] = None,
        validate_mime: bool = True,
    ) -> Path:
        """
        Validate and normalize a file path for security.

        The checks run in this order:
        - Input type check (non-empty ``str``)
        - Path resolution (must exist; symlinks are followed)
        - Regular-file check
        - Extension allow-list check
        - Non-empty and maximum-size check
        - File type (MIME) recognisability check

        Args:
            file_path: Path to the file to validate
            max_size: Maximum allowed file size in bytes (default: 100MB).
                A falsy value such as 0 disables the size limit.
            allowed_extensions: Set of allowed, lowercase, dot-prefixed file
                extensions (default: common types). An empty set disables
                the extension check.
            validate_mime: Whether to require that a file type can be
                inferred from the file name. Note that this is a name-based
                check (``mimetypes.guess_type``); file contents are not
                inspected.

        Returns:
            Validated, absolute, symlink-resolved Path object

        Raises:
            InvalidRequestError: If any validation check fails

        Example:
            >>> path = FileValidator.validate_file_path("data/file.txt")
            >>> with open(path, 'rb') as f:
            ...     data = f.read()
        """
        # Validate input type
        if not file_path or not isinstance(file_path, str):
            raise InvalidRequestError(
                "file_path must be a non-empty string"
            )

        # Set defaults
        if max_size is None:
            max_size = DEFAULT_MAX_FILE_SIZE
        if allowed_extensions is None:
            allowed_extensions = DEFAULT_ALLOWED_EXTENSIONS

        try:
            # Convert to Path and resolve to absolute path.
            # This follows symlinks and normalizes the path.
            path = Path(file_path).resolve(strict=True)
        except FileNotFoundError as e:
            raise InvalidRequestError(
                f"File not found: {file_path}"
            ) from e
        except (OSError, RuntimeError, ValueError) as e:
            # ValueError covers paths with an embedded NUL byte, which would
            # otherwise escape as a raw exception on untrusted input.
            raise InvalidRequestError(
                f"Invalid file path: {e}"
            ) from e

        # Verify it's a regular file (not a directory, device, etc.)
        if not path.is_file():
            if path.is_dir():
                raise InvalidRequestError(
                    f"Path is a directory, not a file: {file_path}"
                )
            raise InvalidRequestError(
                f"Path is not a regular file: {file_path}"
            )

        # Defensive check only: resolve() has already collapsed every ".."
        # component, so this cannot fire for a successfully resolved path.
        # It does NOT confine the path to a base directory; callers that
        # need confinement must compare the returned path against their own
        # allowed root.
        if ".." in path.parts:
            raise InvalidRequestError(
                f"Directory traversal detected in path: {file_path}"
            )

        # Validate file extension if restrictions are set
        if allowed_extensions:
            file_extension = path.suffix.lower()

            # Empty extension check
            if not file_extension:
                raise InvalidRequestError(
                    f"File has no extension: {path.name}. "
                    f"Allowed extensions: {', '.join(sorted(allowed_extensions))}"
                )

            # Check if extension is allowed
            if file_extension not in allowed_extensions:
                # Create a helpful error message with (at most 10) allowed types
                allowed_list = ', '.join(sorted(allowed_extensions)[:10])
                if len(allowed_extensions) > 10:
                    allowed_list += f", ... ({len(allowed_extensions)} total)"

                raise InvalidRequestError(
                    f"File type not allowed: {file_extension}. "
                    f"Allowed types: {allowed_list}"
                )

        # Check file size
        try:
            file_size = path.stat().st_size
        except OSError as e:
            raise InvalidRequestError(
                f"Cannot access file: {e}"
            ) from e

        # Empty file check
        if file_size == 0:
            raise InvalidRequestError(
                f"File is empty: {path.name}"
            )

        # Size limit check
        if max_size and file_size > max_size:
            # Convert to human-readable format
            max_mb = max_size / (1024 * 1024)
            actual_mb = file_size / (1024 * 1024)

            raise InvalidRequestError(
                f"File too large: {actual_mb:.2f}MB exceeds limit of {max_mb:.2f}MB. "
                f"File: {path.name}"
            )

        # Require that the file type is recognisable from its name
        if validate_mime:
            mime_type, encoding = mimetypes.guess_type(str(path))

            # guess_type() reports compressed streams such as "x.gz" or
            # "x.bz2" as (None, "gzip"/"bzip2"): the type is recognised even
            # though no MIME type is returned, so a non-None encoding counts
            # as a successful identification.
            if mime_type is None and encoding is None:
                suffix = path.suffix.lower()
                # The MIME registry differs between platforms and Python
                # versions, so extensions this module itself vets are never
                # rejected merely because the registry lacks them.
                if (
                    suffix not in {'.txt', '.json', '.jsonl', '.csv'}
                    and suffix not in DEFAULT_ALLOWED_EXTENSIONS
                ):
                    raise InvalidRequestError(
                        f"Cannot determine file type for: {path.name}. "
                        f"Extension: {path.suffix}"
                    )

        return path

    @staticmethod
    def safe_read_file(
        path: Path,
        max_size: Optional[int] = None,
        chunk_size: int = 8192,
    ) -> bytes:
        """
        Safely read file contents with a size cap.

        The file is read chunk by chunk and the running total is checked
        against ``max_size`` after every chunk, so a file that grows while
        being read is rejected instead of being loaded without bound. The
        complete contents are still returned as a single ``bytes`` object.

        Args:
            path: Validated Path object to read
            max_size: Maximum size to read in bytes (default: no limit)
            chunk_size: Size of chunks to read (default: 8KB); must not be 0

        Returns:
            File contents as bytes

        Raises:
            InvalidRequestError: If ``path`` is not a Path, ``chunk_size`` is
                0, or the file cannot be read or is too large

        Example:
            >>> path = FileValidator.validate_file_path("data.bin")
            >>> data = FileValidator.safe_read_file(path)
        """
        if not isinstance(path, Path):
            raise InvalidRequestError(
                "path must be a Path object (use validate_file_path first)"
            )

        # read(0) always returns b"", which the loop below would mistake for
        # end-of-file and silently return empty contents.
        if chunk_size == 0:
            raise InvalidRequestError(
                "chunk_size must not be zero"
            )

        # Get file size
        try:
            file_size = path.stat().st_size
        except OSError as e:
            raise InvalidRequestError(
                f"Cannot access file: {e}"
            ) from e

        # Check against max_size if provided
        if max_size and file_size > max_size:
            max_mb = max_size / (1024 * 1024)
            actual_mb = file_size / (1024 * 1024)
            raise InvalidRequestError(
                f"File too large to read: {actual_mb:.2f}MB exceeds {max_mb:.2f}MB"
            )

        # Read file in chunks
        try:
            chunks = []
            bytes_read = 0

            with open(path, "rb") as f:
                while True:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        break

                    chunks.append(chunk)
                    bytes_read += len(chunk)

                    # Double-check we haven't exceeded max_size
                    # (in case file was modified during reading)
                    if max_size and bytes_read > max_size:
                        raise InvalidRequestError(
                            f"File size exceeded during read: {path.name}"
                        )

            return b"".join(chunks)

        except OSError as e:
            raise InvalidRequestError(
                f"Error reading file: {e}"
            ) from e
        except MemoryError as e:
            raise InvalidRequestError(
                f"File too large to fit in memory: {path.name}"
            ) from e

    @staticmethod
    def validate_bytes_size(
        data: bytes,
        max_size: Optional[int] = None,
        name: str = "data",
    ) -> None:
        """
        Validate type, non-emptiness and size of byte data.

        Args:
            data: Bytes to validate (must be an instance of ``bytes``)
            max_size: Maximum allowed size in bytes; a falsy value disables
                the limit
            name: Name for error messages

        Raises:
            InvalidRequestError: If data is not bytes, is empty, or is
                larger than ``max_size``
        """
        if not isinstance(data, bytes):
            raise InvalidRequestError(
                f"{name} must be bytes, got {type(data).__name__}"
            )

        if len(data) == 0:
            raise InvalidRequestError(
                f"{name} is empty"
            )

        if max_size and len(data) > max_size:
            max_mb = max_size / (1024 * 1024)
            actual_mb = len(data) / (1024 * 1024)
            raise InvalidRequestError(
                f"{name} too large: {actual_mb:.2f}MB exceeds {max_mb:.2f}MB"
            )

0 commit comments

Comments
 (0)