77#
88"""Centralized I/O security sentinel for NLTK."""
99
10+ """Centralized I/O security sentinel for NLTK."""
1011import builtins
1112import ipaddress
1213import os
@@ -32,6 +33,7 @@ def _get_allowed_roots():
3233
3334 current_paths = []
3435 if "nltk.data" in sys .modules :
36+ # Accessing nltk.data.path via sys.modules to avoid top-level circularity
3537 current_paths = list (getattr (sys .modules ["nltk.data" ], "path" , []))
3638
3739 env_paths = os .environ .get ("NLTK_DATA" , "" )
@@ -41,13 +43,13 @@ def _get_allowed_roots():
4143 return _ALLOWED_ROOTS_CACHE
4244
4345 roots = set ()
44- # FIX: Use os.pathsep for environment variables (Copilot High)
4546 for p in current_paths + env_paths .split (os .pathsep ):
4647 if p :
4748 try :
49+ # Handle both string paths and PathPointer objects
4850 raw_p = p .path if hasattr (p , "path" ) else p
4951 roots .add (Path (str (raw_p )).resolve ())
50- except Exception :
52+ except ( OSError , ValueError , RuntimeError ) :
5153 continue
5254
5355 import tempfile
@@ -57,16 +59,23 @@ def _get_allowed_roots():
5759 p = Path (loc ).expanduser ().resolve ()
5860 if p .exists ():
5961 roots .add (p )
60- except Exception :
62+ except ( OSError , ValueError , RuntimeError ) :
6163 continue
6264
6365 _ALLOWED_ROOTS_CACHE = roots
6466 _LAST_DATA_PATHS = current_state
6567 return roots
6668
6769
68- def validate_path (path_input , context = "NLTK" ):
69- """Ensures file access is restricted to allowed data directories."""
70+ def validate_path (path_input , context = "NLTK" , required_root = None ):
71+ """
72+ Ensures file access is restricted to allowed data directories.
73+
74+ :param path_input: The path to validate.
75+ :param context: Diagnostic context for warnings/errors.
76+ :param required_root: If provided, enforces that the path is strictly
77+ within this specific directory (scoped sandbox).
78+ """
7079 if isinstance (path_input , int ) or not path_input or not str (path_input ).strip ():
7180 return
7281 try :
@@ -77,26 +86,50 @@ def validate_path(path_input, context="NLTK"):
7786 if parsed .scheme in ("http" , "https" , "ftp" ):
7887 return
7988 if parsed .scheme == "file" :
80- raw = urllib . request . url2pathname (parsed .path )
89+ raw = unquote (parsed .path )
8190
82- target = Path (raw ).resolve ()
91+ # Resolve path to catch symlink escapes
92+ try :
93+ target = Path (raw ).resolve ()
94+ except (OSError , ValueError ):
95+ # Fallback for virtual paths inside ZIPs (e.g. corpora/foo.zip/file.txt)
96+ lower_raw = raw .lower ()
97+ if ".zip" in lower_raw :
98+ zip_idx = lower_raw .find (".zip" ) + 4
99+ target = Path (raw [:zip_idx ]).resolve ()
100+ else :
101+ target = Path (raw )
102+
103+ # LAYER 1: Scoped Sandbox (PR #3528 Integration)
104+ # This resolves both target and root to block symlink-based escapes.
105+ if required_root :
106+ root_raw = (
107+ required_root .path
108+ if hasattr (required_root , "path" )
109+ else str (required_root )
110+ )
111+ scoped_root = Path (root_raw ).resolve ()
112+ if not (target == scoped_root or target .is_relative_to (scoped_root )):
113+ # Raise ValueError to match NLTK's historical CorpusReader error type
114+ raise ValueError (
115+ f"Security Violation [{ context } ]: Path { target } escapes root { scoped_root } "
116+ )
83117
118+ # LAYER 2: Global NLTK_DATA Sandbox
84119 allowed_roots = _get_allowed_roots ()
85120 if any (target == root or target .is_relative_to (root ) for root in allowed_roots ):
86121 return
87122
88- # 5. CWD Fallback (Explicit Opt-In for ENFORCE mode)
123+ # CWD Fallback (Explicit Opt-In for ENFORCE mode)
89124 try :
90125 cwd = Path (os .getcwd ()).resolve ()
91126 if target == cwd or target .is_relative_to (cwd ):
92127 if any (cwd == root for root in allowed_roots ):
93128 return
94-
95129 msg = (
96- f"Security Violation [{ context } ]: CWD access is restricted in ENFORCE mode. "
97- "To allow local data, use : nltk.data.path.append('.')"
130+ f"Security Violation [{ context } ]: CWD access restricted in ENFORCE mode. "
131+ "Authorize via : nltk.data.path.append('.')"
98132 )
99-
100133 if ENFORCE :
101134 raise PermissionError (msg )
102135 else :
@@ -106,7 +139,7 @@ def validate_path(path_input, context="NLTK"):
106139 stacklevel = 3 ,
107140 )
108141 return
109- except Exception :
142+ except ( OSError , ValueError ) :
110143 pass
111144
112145 msg = f"Security Violation [{ context } ]: Unauthorized path { target } "
@@ -116,37 +149,29 @@ def validate_path(path_input, context="NLTK"):
116149 warnings .warn (msg , RuntimeWarning , stacklevel = 3 )
117150 except (PermissionError , ValueError ):
118151 raise
119- except Exception as e :
152+ except Exception :
120153 if ENFORCE :
121- raise PermissionError ( f"Path validation failed [ { context } ]: { e } " )
154+ raise
122155
123156
124157def validate_zip_archive (
125158 zip_obj_or_path , target_root , specific_member = None , context = "ZipAudit"
126159):
127- """Enhanced Zip-Slip protection with null-byte detection ."""
160+ """Enhanced Zip-Slip protection using Pathlib for cross-platform safety ."""
128161 try :
129162 target = Path (target_root ).resolve ()
130- target_str = str (target )
131- # Normalize target paths for cross-platform, case-insensitive comparison
132- target_norm_eq = os .path .normcase (target_str )
133- # Ensure trailing separator for prefix check (e.g., 'C:\\data\\')
134- target_norm_prefix = os .path .normcase (os .path .join (target_str , "" ))
135163
136164 def _audit (zf ):
137- members_to_check = (
165+ members = (
138166 [specific_member ] if specific_member is not None else zf .namelist ()
139167 )
140- for name in members_to_check :
168+ for name in members :
141169 name_str = name .filename if hasattr (name , "filename" ) else str (name )
142170 if "\0 " in name_str :
143171 raise ValueError (f"Null byte in ZIP member: { name_str } " )
144- member_path_str = os .path .abspath (os .path .join (target_str , name_str ))
145- member_norm = os .path .normcase (member_path_str )
146- if not (
147- member_norm .startswith (target_norm_prefix )
148- or member_norm == target_norm_eq
149- ):
172+
173+ member_path = (target / name_str ).resolve ()
174+ if not (member_path == target or member_path .is_relative_to (target )):
150175 msg = f"Security Violation [{ context } ]: Traversal member '{ name_str } ' detected."
151176 if ENFORCE :
152177 raise PermissionError (msg )
@@ -160,17 +185,17 @@ def _audit(zf):
160185 _audit (zf )
161186 except (PermissionError , ValueError ):
162187 raise
163- except (OSError , zipfile .BadZipFile ) as e :
188+ except (OSError , zipfile .BadZipFile ):
164189 if ENFORCE :
165- raise PermissionError (f "Zip validation failed [ { context } ]: { e } " ) from e
190+ raise PermissionError ("Zip validation failed" )
166191
167192
168193@lru_cache (maxsize = 256 )
169194def _resolve_hostname (hostname ):
170- """Cached hostname resolution to mitigate DNS rebinding (Copilot Medium) ."""
195+ """Cached hostname resolution to mitigate DNS rebinding."""
171196 try :
172197 return socket .getaddrinfo (hostname , None , proto = socket .IPPROTO_TCP )
173- except Exception :
198+ except ( OSError , ValueError ) :
174199 return []
175200
176201
@@ -180,13 +205,8 @@ def validate_network_url(url_input, context="NetworkIO"):
180205 return
181206 try :
182207 parsed = urlparse (str (url_input ))
183-
184- # FIX: Cross-route file scheme to path validation (Copilot High)
185208 if parsed .scheme == "file" :
186- validate_path (
187- urllib .request .url2pathname (parsed .path ),
188- context = f"{ context } .file_scheme" ,
189- )
209+ validate_path (unquote (parsed .path ), context = f"{ context } .file_scheme" )
190210 return
191211
192212 if parsed .scheme not in ("http" , "https" ):
@@ -202,41 +222,31 @@ def validate_network_url(url_input, context="NetworkIO"):
202222 for result in _resolve_hostname (parsed .hostname or "" ):
203223 ip = ipaddress .ip_address (result [4 ][0 ])
204224 if ip .is_loopback or ip .is_link_local or ip .is_multicast or ip .is_private :
205- msg = f"Security Violation [{ context } ]: Blocked SSRF attempt to restricted IP { ip } "
225+ msg = f"Security Violation [{ context } ]: SSRF attempt to restricted IP { ip } "
206226 if ENFORCE :
207227 raise PermissionError (msg )
208228 else :
209229 warnings .warn (msg , RuntimeWarning , stacklevel = 3 )
210230 except (PermissionError , ValueError ):
211231 raise
212- except Exception as e :
232+ except Exception :
213233 if ENFORCE :
214- raise PermissionError ( f"URL validation failed [ { context } ]: { e } " )
234+ raise
215235
216236
217237class _ValidatingRedirectHandler (urllib .request .HTTPRedirectHandler ):
218- """Ensures that every step of a redirect chain is re-validated against SSRF (Copilot High) ."""
238+ """Ensures that every step of a redirect chain is re-validated against SSRF."""
219239
220240 def redirect_request (self , req , fp , code , msg , headers , newurl ):
221- validate_network_url (newurl , context = "pathsec.urlopen.redirect " )
241+ validate_network_url (newurl , context = "NetworkRedirect " )
222242 return super ().redirect_request (req , fp , code , msg , headers , newurl )
223243
224244
225- _validating_opener = None
226-
227-
228- def _get_validating_opener ():
229- global _validating_opener
230- if _validating_opener is None :
231- _validating_opener = urllib .request .build_opener (_ValidatingRedirectHandler ())
232- return _validating_opener
233-
234-
235245def urlopen (url , * args , ** kwargs ):
236246 """Secure wrapper for urllib.request.urlopen with redirect validation."""
237247 url_str = url .full_url if hasattr (url , "full_url" ) else str (url )
238248 validate_network_url (url_str , context = "pathsec.urlopen" )
239- opener = _get_validating_opener ( )
249+ opener = urllib . request . build_opener ( _ValidatingRedirectHandler () )
240250 return opener .open (url , * args , ** kwargs )
241251
242252
@@ -255,18 +265,11 @@ def __init__(self, file, *args, **kwargs):
255265 super ().__init__ (file , * args , ** kwargs )
256266
257267 def extract (self , member , path = None , pwd = None ):
258- validate_zip_archive (
259- self ,
260- path or os .getcwd (),
261- specific_member = member ,
262- context = "pathsec.ZipFile.extract" ,
263- )
268+ validate_zip_archive (self , path or os .getcwd (), specific_member = member )
264269 return super ().extract (member , path , pwd )
265270
266271 def extractall (self , path = None , members = None , pwd = None ):
267- validate_zip_archive (
268- self , path or os .getcwd (), context = "pathsec.ZipFile.extractall"
269- )
272+ validate_zip_archive (self , path or os .getcwd ())
270273 super ().extractall (path , members , pwd )
271274
272275
0 commit comments