Skip to content

Commit aa40aa9

Browse files
security: fix path traversal in BaseTuner directory/project_name/tuner_id
Fixes a path traversal vulnerability where user-supplied directory, project_name, and tuner_id parameters were joined into filesystem paths without validation. An attacker could set directory='../evil' or project_name='../../etc' to escape the intended project directory. When combined with overwrite=True, this could lead to arbitrary directory deletion. The fix adds _validate_project_path() and _validate_tuner_id() static methods that reject path traversal sequences (..) and path separators in tuner_id. Tests cover directory traversal, project_name traversal, and tuner_id traversal with environment variable input.
1 parent 48f6714 commit aa40aa9

2 files changed

Lines changed: 130 additions & 0 deletions

File tree

keras_tuner/engine/base_tuner.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,15 @@ def __init__(
115115
# Ops and metadata
116116
self.directory = directory or "."
117117
self.project_name = project_name or "untitled_project"
118+
self._validate_project_path(self.directory, self.project_name)
118119
self.oracle._set_project_dir(self.directory, self.project_name)
119120

120121
if overwrite and backend.io.exists(self.project_dir):
121122
backend.io.rmtree(self.project_dir)
122123

123124
# To support tuning distribution.
124125
self.tuner_id = os.environ.get("KERASTUNER_TUNER_ID", "tuner0")
126+
self._validate_tuner_id(self.tuner_id)
125127

126128
# Reloading state.
127129
if not overwrite and backend.io.exists(self._get_tuner_fname()):
@@ -471,5 +473,37 @@ def get_trial_dir(self, trial_id):
471473
utils.create_directory(dirname)
472474
return dirname
473475

476+
@staticmethod
477+
def _validate_project_path(directory, project_name):
478+
"""Validates that directory and project_name do not contain path traversal.
479+
480+
Raises:
481+
ValueError: If path traversal sequences or absolute paths are detected.
482+
"""
483+
for name, segment in (("directory", str(directory)), ("project_name", str(project_name))):
484+
if ".." in segment:
485+
raise ValueError(
486+
f"Path traversal is not allowed in {name}. Received: {segment!r}"
487+
)
488+
# Reject absolute paths in project_name to prevent writing outside CWD
489+
if name == "project_name" and os.path.isabs(segment):
490+
raise ValueError(
491+
f"Absolute paths are not allowed in {name}. Received: {segment!r}"
492+
)
493+
494+
@staticmethod
495+
def _validate_tuner_id(tuner_id):
496+
"""Validates that tuner_id does not contain path traversal sequences.
497+
498+
Raises:
499+
ValueError: If path traversal sequences or path separators are detected.
500+
"""
501+
tuner_id_str = str(tuner_id)
502+
if ".." in tuner_id_str or "/" in tuner_id_str or "\\" in tuner_id_str:
503+
raise ValueError(
504+
f"tuner_id cannot contain path separators or traversal sequences. "
505+
f"Received: {tuner_id_str!r}"
506+
)
507+
474508
def _get_tuner_fname(self):
475509
return os.path.join(str(self.project_dir), f"{str(self.tuner_id)}.json")

keras_tuner/engine/base_tuner_test.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,99 @@ def _the_func():
248248
_the_func, num_workers=2, wait_for_chief=True
249249
)
250250
oracle_client.TIMEOUT = timeout
251+
252+
253+
254+
def test_directory_path_traversal_raises_value_error(tmp_path):
255+
def build_model(hp):
256+
hp.Boolean("a")
257+
258+
with pytest.raises(ValueError, match="Path traversal"):
259+
gridsearch.GridSearch(
260+
directory="../evil",
261+
hypermodel=build_model,
262+
max_trials=1,
263+
)
264+
265+
266+
def test_project_name_path_traversal_raises_value_error(tmp_path):
267+
def build_model(hp):
268+
hp.Boolean("a")
269+
270+
with pytest.raises(ValueError, match="Path traversal"):
271+
gridsearch.GridSearch(
272+
directory=tmp_path,
273+
project_name="../../evil",
274+
hypermodel=build_model,
275+
max_trials=1,
276+
)
277+
278+
279+
def test_tuner_id_path_traversal_raises_value_error(tmp_path):
280+
def build_model(hp):
281+
hp.Boolean("a")
282+
283+
import os
284+
original_tuner_id = os.environ.get("KERASTUNER_TUNER_ID")
285+
try:
286+
os.environ["KERASTUNER_TUNER_ID"] = "../../../evil"
287+
with pytest.raises(ValueError, match="tuner_id"):
288+
gridsearch.GridSearch(
289+
directory=tmp_path,
290+
hypermodel=build_model,
291+
max_trials=1,
292+
)
293+
finally:
294+
if original_tuner_id is not None:
295+
os.environ["KERASTUNER_TUNER_ID"] = original_tuner_id
296+
else:
297+
os.environ.pop("KERASTUNER_TUNER_ID", None)
298+
299+
300+
def test_valid_directory_and_project_name_succeeds(tmp_path):
301+
def build_model(hp):
302+
hp.Boolean("a")
303+
304+
# These should not raise
305+
tuner = gridsearch.GridSearch(
306+
directory=tmp_path,
307+
project_name="my_project",
308+
hypermodel=build_model,
309+
max_trials=1,
310+
)
311+
assert tuner.directory == str(tmp_path)
312+
assert tuner.project_name == "my_project"
313+
314+
315+
def test_project_name_absolute_path_raises_value_error(tmp_path):
316+
def build_model(hp):
317+
hp.Boolean("a")
318+
319+
with pytest.raises(ValueError, match="Absolute paths"):
320+
gridsearch.GridSearch(
321+
directory=tmp_path,
322+
project_name="/etc",
323+
hypermodel=build_model,
324+
max_trials=1,
325+
)
326+
327+
328+
def test_tuner_id_forward_slash_raises_value_error(tmp_path):
329+
def build_model(hp):
330+
hp.Boolean("a")
331+
332+
import os
333+
original_tuner_id = os.environ.get("KERASTUNER_TUNER_ID")
334+
try:
335+
os.environ["KERASTUNER_TUNER_ID"] = "evil/tuner"
336+
with pytest.raises(ValueError, match="tuner_id"):
337+
gridsearch.GridSearch(
338+
directory=tmp_path,
339+
hypermodel=build_model,
340+
max_trials=1,
341+
)
342+
finally:
343+
if original_tuner_id is not None:
344+
os.environ["KERASTUNER_TUNER_ID"] = original_tuner_id
345+
else:
346+
os.environ.pop("KERASTUNER_TUNER_ID", None)

0 commit comments

Comments
 (0)