Skip to content

Commit 1b0e519

Browse files
committed
Implement scoped root validation in pathsec sentinel
1 parent 7d8684a commit 1b0e519

File tree

4 files changed

+122
-81
lines changed

4 files changed

+122
-81
lines changed

nltk/corpus/reader/api.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -221,21 +221,27 @@ def raw(self, fileids=None):
221221

222222
def open(self, file):
223223
"""
224-
Return an open stream that can be used to read the given file.
224+
Return an open stream for the given file.
225+
Security patched: prevents path traversal and scoped escapes.
225226
"""
226-
# -------- SECURITY PATCH START --------
227-
file = str(file)
227+
# Layer 1: Lexical guard
228+
if os.path.isabs(file) or ".." in file.replace("\\", "/"):
229+
raise ValueError(f"CorpusReader paths must be relative: {file}")
228230

229-
if os.path.isabs(file):
230-
raise ValueError("Absolute paths are not allowed")
231+
path = self._root.join(file)
231232

232-
if ".." in file.replace("\\", "/").split("/"):
233-
raise ValueError("Path traversal attempt blocked")
234-
# -------- SECURITY PATCH END --------
233+
# Layer 2: Scoped resolved guard (Fixes symlink escape test)
234+
from nltk.pathsec import validate_path
235235

236-
encoding = self.encoding(file)
237-
stream = self._root.join(file).open(encoding)
238-
return stream
236+
validate_path(path, context="CorpusReader", required_root=self._root)
237+
238+
# --- FIX: Handle dict-based encodings (e.g., UDHR corpus) ---
239+
encoding = self._encoding
240+
if isinstance(encoding, dict):
241+
encoding = encoding.get(file)
242+
243+
# Layer 3: Global sentinel check happens inside path.open()
244+
return path.open(encoding=encoding)
239245

240246
def encoding(self, file):
241247
"""

nltk/data.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,18 +1122,20 @@ def _open(resource_url):
11221122
protocol, path_ = resource_url.split(":", 1)
11231123

11241124
if protocol == "nltk":
1125+
# If find() or .open() raises a ValueError (security) or LookupError,
1126+
# let it bubble up or handle it based on load() logic.
11251127
return find(path_).open()
11261128
elif protocol == "file":
1127-
import urllib.request
1128-
1129-
local_path = urllib.request.url2pathname(path_)
1129+
local_path = url2pathname(path_)
11301130
try:
1131+
# 1. Attempt to use NLTK's standard search paths (Safe/Normalized)
11311132
return find(local_path).open()
1132-
except LookupError:
1133-
# FIX: Use _secure_open to ensure the sentinel validates
1134-
# paths that find() cannot resolve.
1133+
except (LookupError, ValueError):
1134+
# 2. Fallback for absolute paths (e.g., file:///etc/passwd)
1135+
# This ensures even direct file access hits the pathsec sentinel.
11351136
return _secure_open(local_path, "rb")
11361137
else:
1138+
# Network protocols (http, https, ftp)
11371139
return _secure_urlopen(resource_url)
11381140

11391141

nltk/pathsec.py

Lines changed: 67 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#
88
"""Centralized I/O security sentinel for NLTK."""
99

10+
"""Centralized I/O security sentinel for NLTK."""
1011
import builtins
1112
import ipaddress
1213
import os
@@ -32,6 +33,7 @@ def _get_allowed_roots():
3233

3334
current_paths = []
3435
if "nltk.data" in sys.modules:
36+
# Accessing nltk.data.path via sys.modules to avoid top-level circularity
3537
current_paths = list(getattr(sys.modules["nltk.data"], "path", []))
3638

3739
env_paths = os.environ.get("NLTK_DATA", "")
@@ -41,13 +43,13 @@ def _get_allowed_roots():
4143
return _ALLOWED_ROOTS_CACHE
4244

4345
roots = set()
44-
# FIX: Use os.pathsep for environment variables (Copilot High)
4546
for p in current_paths + env_paths.split(os.pathsep):
4647
if p:
4748
try:
49+
# Handle both string paths and PathPointer objects
4850
raw_p = p.path if hasattr(p, "path") else p
4951
roots.add(Path(str(raw_p)).resolve())
50-
except Exception:
52+
except (OSError, ValueError, RuntimeError):
5153
continue
5254

5355
import tempfile
@@ -57,16 +59,23 @@ def _get_allowed_roots():
5759
p = Path(loc).expanduser().resolve()
5860
if p.exists():
5961
roots.add(p)
60-
except Exception:
62+
except (OSError, ValueError, RuntimeError):
6163
continue
6264

6365
_ALLOWED_ROOTS_CACHE = roots
6466
_LAST_DATA_PATHS = current_state
6567
return roots
6668

6769

68-
def validate_path(path_input, context="NLTK"):
69-
"""Ensures file access is restricted to allowed data directories."""
70+
def validate_path(path_input, context="NLTK", required_root=None):
71+
"""
72+
Ensures file access is restricted to allowed data directories.
73+
74+
:param path_input: The path to validate.
75+
:param context: Diagnostic context for warnings/errors.
76+
:param required_root: If provided, enforces that the path is strictly
77+
within this specific directory (scoped sandbox).
78+
"""
7079
if isinstance(path_input, int) or not path_input or not str(path_input).strip():
7180
return
7281
try:
@@ -77,26 +86,50 @@ def validate_path(path_input, context="NLTK"):
7786
if parsed.scheme in ("http", "https", "ftp"):
7887
return
7988
if parsed.scheme == "file":
80-
raw = urllib.request.url2pathname(parsed.path)
89+
raw = unquote(parsed.path)
8190

82-
target = Path(raw).resolve()
91+
# Resolve path to catch symlink escapes
92+
try:
93+
target = Path(raw).resolve()
94+
except (OSError, ValueError):
95+
# Fallback for virtual paths inside ZIPs (e.g. corpora/foo.zip/file.txt)
96+
lower_raw = raw.lower()
97+
if ".zip" in lower_raw:
98+
zip_idx = lower_raw.find(".zip") + 4
99+
target = Path(raw[:zip_idx]).resolve()
100+
else:
101+
target = Path(raw)
102+
103+
# LAYER 1: Scoped Sandbox (PR #3528 Integration)
104+
# This resolves both target and root to block symlink-based escapes.
105+
if required_root:
106+
root_raw = (
107+
required_root.path
108+
if hasattr(required_root, "path")
109+
else str(required_root)
110+
)
111+
scoped_root = Path(root_raw).resolve()
112+
if not (target == scoped_root or target.is_relative_to(scoped_root)):
113+
# Raise ValueError to match NLTK's historical CorpusReader error type
114+
raise ValueError(
115+
f"Security Violation [{context}]: Path {target} escapes root {scoped_root}"
116+
)
83117

118+
# LAYER 2: Global NLTK_DATA Sandbox
84119
allowed_roots = _get_allowed_roots()
85120
if any(target == root or target.is_relative_to(root) for root in allowed_roots):
86121
return
87122

88-
# 5. CWD Fallback (Explicit Opt-In for ENFORCE mode)
123+
# CWD Fallback (Explicit Opt-In for ENFORCE mode)
89124
try:
90125
cwd = Path(os.getcwd()).resolve()
91126
if target == cwd or target.is_relative_to(cwd):
92127
if any(cwd == root for root in allowed_roots):
93128
return
94-
95129
msg = (
96-
f"Security Violation [{context}]: CWD access is restricted in ENFORCE mode. "
97-
"To allow local data, use: nltk.data.path.append('.')"
130+
f"Security Violation [{context}]: CWD access restricted in ENFORCE mode. "
131+
"Authorize via: nltk.data.path.append('.')"
98132
)
99-
100133
if ENFORCE:
101134
raise PermissionError(msg)
102135
else:
@@ -106,7 +139,7 @@ def validate_path(path_input, context="NLTK"):
106139
stacklevel=3,
107140
)
108141
return
109-
except Exception:
142+
except (OSError, ValueError):
110143
pass
111144

112145
msg = f"Security Violation [{context}]: Unauthorized path {target}"
@@ -116,37 +149,29 @@ def validate_path(path_input, context="NLTK"):
116149
warnings.warn(msg, RuntimeWarning, stacklevel=3)
117150
except (PermissionError, ValueError):
118151
raise
119-
except Exception as e:
152+
except Exception:
120153
if ENFORCE:
121-
raise PermissionError(f"Path validation failed [{context}]: {e}")
154+
raise
122155

123156

124157
def validate_zip_archive(
125158
zip_obj_or_path, target_root, specific_member=None, context="ZipAudit"
126159
):
127-
"""Enhanced Zip-Slip protection with null-byte detection."""
160+
"""Enhanced Zip-Slip protection using Pathlib for cross-platform safety."""
128161
try:
129162
target = Path(target_root).resolve()
130-
target_str = str(target)
131-
# Normalize target paths for cross-platform, case-insensitive comparison
132-
target_norm_eq = os.path.normcase(target_str)
133-
# Ensure trailing separator for prefix check (e.g., 'C:\\data\\')
134-
target_norm_prefix = os.path.normcase(os.path.join(target_str, ""))
135163

136164
def _audit(zf):
137-
members_to_check = (
165+
members = (
138166
[specific_member] if specific_member is not None else zf.namelist()
139167
)
140-
for name in members_to_check:
168+
for name in members:
141169
name_str = name.filename if hasattr(name, "filename") else str(name)
142170
if "\0" in name_str:
143171
raise ValueError(f"Null byte in ZIP member: {name_str}")
144-
member_path_str = os.path.abspath(os.path.join(target_str, name_str))
145-
member_norm = os.path.normcase(member_path_str)
146-
if not (
147-
member_norm.startswith(target_norm_prefix)
148-
or member_norm == target_norm_eq
149-
):
172+
173+
member_path = (target / name_str).resolve()
174+
if not (member_path == target or member_path.is_relative_to(target)):
150175
msg = f"Security Violation [{context}]: Traversal member '{name_str}' detected."
151176
if ENFORCE:
152177
raise PermissionError(msg)
@@ -160,17 +185,17 @@ def _audit(zf):
160185
_audit(zf)
161186
except (PermissionError, ValueError):
162187
raise
163-
except (OSError, zipfile.BadZipFile) as e:
188+
except (OSError, zipfile.BadZipFile):
164189
if ENFORCE:
165-
raise PermissionError(f"Zip validation failed [{context}]: {e}") from e
190+
raise PermissionError("Zip validation failed")
166191

167192

168193
@lru_cache(maxsize=256)
169194
def _resolve_hostname(hostname):
170-
"""Cached hostname resolution to mitigate DNS rebinding (Copilot Medium)."""
195+
"""Cached hostname resolution to mitigate DNS rebinding."""
171196
try:
172197
return socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP)
173-
except Exception:
198+
except (OSError, ValueError):
174199
return []
175200

176201

@@ -180,13 +205,8 @@ def validate_network_url(url_input, context="NetworkIO"):
180205
return
181206
try:
182207
parsed = urlparse(str(url_input))
183-
184-
# FIX: Cross-route file scheme to path validation (Copilot High)
185208
if parsed.scheme == "file":
186-
validate_path(
187-
urllib.request.url2pathname(parsed.path),
188-
context=f"{context}.file_scheme",
189-
)
209+
validate_path(unquote(parsed.path), context=f"{context}.file_scheme")
190210
return
191211

192212
if parsed.scheme not in ("http", "https"):
@@ -202,41 +222,31 @@ def validate_network_url(url_input, context="NetworkIO"):
202222
for result in _resolve_hostname(parsed.hostname or ""):
203223
ip = ipaddress.ip_address(result[4][0])
204224
if ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_private:
205-
msg = f"Security Violation [{context}]: Blocked SSRF attempt to restricted IP {ip}"
225+
msg = f"Security Violation [{context}]: SSRF attempt to restricted IP {ip}"
206226
if ENFORCE:
207227
raise PermissionError(msg)
208228
else:
209229
warnings.warn(msg, RuntimeWarning, stacklevel=3)
210230
except (PermissionError, ValueError):
211231
raise
212-
except Exception as e:
232+
except Exception:
213233
if ENFORCE:
214-
raise PermissionError(f"URL validation failed [{context}]: {e}")
234+
raise
215235

216236

217237
class _ValidatingRedirectHandler(urllib.request.HTTPRedirectHandler):
218-
"""Ensures that every step of a redirect chain is re-validated against SSRF (Copilot High)."""
238+
"""Ensures that every step of a redirect chain is re-validated against SSRF."""
219239

220240
def redirect_request(self, req, fp, code, msg, headers, newurl):
221-
validate_network_url(newurl, context="pathsec.urlopen.redirect")
241+
validate_network_url(newurl, context="NetworkRedirect")
222242
return super().redirect_request(req, fp, code, msg, headers, newurl)
223243

224244

225-
_validating_opener = None
226-
227-
228-
def _get_validating_opener():
229-
global _validating_opener
230-
if _validating_opener is None:
231-
_validating_opener = urllib.request.build_opener(_ValidatingRedirectHandler())
232-
return _validating_opener
233-
234-
235245
def urlopen(url, *args, **kwargs):
236246
"""Secure wrapper for urllib.request.urlopen with redirect validation."""
237247
url_str = url.full_url if hasattr(url, "full_url") else str(url)
238248
validate_network_url(url_str, context="pathsec.urlopen")
239-
opener = _get_validating_opener()
249+
opener = urllib.request.build_opener(_ValidatingRedirectHandler())
240250
return opener.open(url, *args, **kwargs)
241251

242252

@@ -255,18 +265,11 @@ def __init__(self, file, *args, **kwargs):
255265
super().__init__(file, *args, **kwargs)
256266

257267
def extract(self, member, path=None, pwd=None):
258-
validate_zip_archive(
259-
self,
260-
path or os.getcwd(),
261-
specific_member=member,
262-
context="pathsec.ZipFile.extract",
263-
)
268+
validate_zip_archive(self, path or os.getcwd(), specific_member=member)
264269
return super().extract(member, path, pwd)
265270

266271
def extractall(self, path=None, members=None, pwd=None):
267-
validate_zip_archive(
268-
self, path or os.getcwd(), context="pathsec.ZipFile.extractall"
269-
)
272+
validate_zip_archive(self, path or os.getcwd())
270273
super().extractall(path, members, pwd)
271274

272275

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import os
2+
3+
import pytest
4+
5+
from nltk.corpus.reader.plaintext import PlaintextCorpusReader
6+
7+
8+
@pytest.mark.skipif(not hasattr(os, "symlink"), reason="requires os.symlink")
9+
def test_corpusreader_open_blocks_symlink_escape(tmp_path):
10+
# Arrange: a corpus root in tempdir
11+
corpus_root = tmp_path / "corpus"
12+
corpus_root.mkdir()
13+
14+
# Arrange: a second directory also in tempdir (so pathsec allowed-roots won't object)
15+
outside_dir = tmp_path / "outside"
16+
outside_dir.mkdir()
17+
18+
# Secret file outside the corpus root
19+
secret = outside_dir / "secret.txt"
20+
secret.write_text("should not be readable via corpus_root", encoding="utf-8")
21+
22+
# Create a symlink inside corpus_root that points outside corpus_root
23+
link = corpus_root / "outside_link"
24+
os.symlink(str(outside_dir), str(link))
25+
26+
reader = PlaintextCorpusReader(str(corpus_root), r".*")
27+
28+
# Act + Assert: opening via the symlinked path must be blocked by corpus-root sandboxing
29+
with pytest.raises((ValueError, PermissionError, OSError)):
30+
reader.open("outside_link/secret.txt").read()

0 commit comments

Comments
 (0)