Skip to content

Commit 02a3a91

Browse files
committed
Final hardening
1 parent a634475 commit 02a3a91

File tree

1 file changed

+5
-18
lines changed

1 file changed

+5
-18
lines changed

nltk/pathsec.py

Lines changed: 5 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ def _get_allowed_roots():
4040
return _ALLOWED_ROOTS_CACHE
4141

4242
roots = set()
43-
# FIX: Use os.pathsep for environment variables to prevent directory leakage
43+
# FIX: Use os.pathsep for environment variables (Copilot High)
4444
for p in current_paths + env_paths.split(os.pathsep):
4545
if p:
4646
try:
@@ -49,7 +49,6 @@ def _get_allowed_roots():
4949
except:
5050
continue
5151

52-
# Trust standard data locations and the system TEMP directory
5352
import tempfile
5453

5554
for loc in ["~/nltk_data", "/usr/share/nltk_data", tempfile.gettempdir()]:
@@ -70,26 +69,22 @@ def validate_path(path_input, context="NLTK"):
7069
if isinstance(path_input, int) or not path_input or not str(path_input).strip():
7170
return
7271
try:
73-
# 1. Handle NLTK Pointers
7472
raw = path_input.path if hasattr(path_input, "path") else str(path_input)
7573

76-
# 2. URL Handling
7774
if "://" in raw:
7875
parsed = urlparse(raw)
7976
if parsed.scheme in ("http", "https", "ftp"):
8077
return
8178
if parsed.scheme == "file":
8279
raw = unquote(parsed.path)
8380

84-
# 3. ZIP TRANSPARENCY: Truncate to the archive file
8581
lower_raw = raw.lower()
8682
if ".zip" in lower_raw:
8783
zip_idx = lower_raw.find(".zip") + 4
8884
raw = raw[:zip_idx]
8985

9086
target = Path(raw).resolve()
9187

92-
# 4. Containment Check against authorized roots
9388
allowed_roots = _get_allowed_roots()
9489
if any(target == root or target.is_relative_to(root) for root in allowed_roots):
9590
return
@@ -98,7 +93,6 @@ def validate_path(path_input, context="NLTK"):
9893
try:
9994
cwd = Path(os.getcwd()).resolve()
10095
if target == cwd or target.is_relative_to(cwd):
101-
# Only allow if CWD is explicitly in the search path (Explicit Opt-In)
10296
if any(cwd == root for root in allowed_roots):
10397
return
10498

@@ -145,10 +139,8 @@ def _audit(zf):
145139
)
146140
for name in members_to_check:
147141
name_str = name.filename if hasattr(name, "filename") else str(name)
148-
149142
if "\0" in name_str:
150143
raise ValueError(f"Null byte in ZIP member: {name_str}")
151-
152144
member_path_str = os.path.abspath(os.path.join(target_str, name_str))
153145
if (
154146
not member_path_str.startswith(target_str + os.sep)
@@ -174,7 +166,7 @@ def _audit(zf):
174166

175167
@lru_cache(maxsize=256)
176168
def _resolve_hostname(hostname):
177-
"""Cached hostname resolution to mitigate DNS rebinding."""
169+
"""Cached hostname resolution to mitigate DNS rebinding (Copilot Medium)."""
178170
try:
179171
return socket.getaddrinfo(hostname, None, proto=socket.IPPROTO_TCP)
180172
except:
@@ -188,14 +180,11 @@ def validate_network_url(url_input, context="NetworkIO"):
188180
try:
189181
parsed = urlparse(str(url_input))
190182

191-
# 1. Block file:// in a network context to prevent local file disclosure
183+
# FIX: Cross-route file scheme to path validation (Copilot High)
192184
if parsed.scheme == "file":
193-
# Direct to path validation instead of allowing it as a "URL"
194-
local_path = unquote(parsed.path)
195-
validate_path(local_path, context=f"{context}.file_scheme")
185+
validate_path(unquote(parsed.path), context=f"{context}.file_scheme")
196186
return
197187

198-
# 2. Strict scheme check for true network calls
199188
if parsed.scheme not in ("http", "https"):
200189
msg = (
201190
f"Security Violation [{context}]: Unsupported scheme '{parsed.scheme}'."
@@ -206,7 +195,6 @@ def validate_network_url(url_input, context="NetworkIO"):
206195
warnings.warn(msg, RuntimeWarning, stacklevel=3)
207196
return
208197

209-
# 3. SSRF / DNS Rebinding checks (existing logic)
210198
for result in _resolve_hostname(parsed.hostname or ""):
211199
ip = ipaddress.ip_address(result[4][0])
212200
if ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_private:
@@ -223,7 +211,7 @@ def validate_network_url(url_input, context="NetworkIO"):
223211

224212

225213
class _ValidatingRedirectHandler(urllib.request.HTTPRedirectHandler):
226-
"""Ensures that every step of a redirect chain is re-validated against SSRF."""
214+
"""Ensures that every step of a redirect chain is re-validated against SSRF (Copilot High)."""
227215

228216
def redirect_request(self, req, fp, code, msg, headers, newurl):
229217
validate_network_url(newurl, context="NetworkRedirect")
@@ -234,7 +222,6 @@ def urlopen(url, *args, **kwargs):
234222
"""Secure wrapper for urllib.request.urlopen with redirect validation."""
235223
url_str = url.full_url if hasattr(url, "full_url") else str(url)
236224
validate_network_url(url_str, context="pathsec.urlopen")
237-
# Use custom opener to enforce validation on 30x redirects
238225
opener = urllib.request.build_opener(_ValidatingRedirectHandler())
239226
return opener.open(url, *args, **kwargs)
240227

0 commit comments

Comments (0)