diff --git a/keras_tuner/engine/base_tuner.py b/keras_tuner/engine/base_tuner.py index 9bba6b91c..d03a98323 100644 --- a/keras_tuner/engine/base_tuner.py +++ b/keras_tuner/engine/base_tuner.py @@ -15,6 +15,7 @@ import copy +import ntpath import os import traceback import warnings @@ -115,6 +116,7 @@ def __init__( # Ops and metadata self.directory = directory or "." self.project_name = project_name or "untitled_project" + self._validate_project_path(self.directory, self.project_name) self.oracle._set_project_dir(self.directory, self.project_name) if overwrite and backend.io.exists(self.project_dir): @@ -122,6 +124,7 @@ def __init__( # To support tuning distribution. self.tuner_id = os.environ.get("KERASTUNER_TUNER_ID", "tuner0") + self._validate_tuner_id(self.tuner_id) # Reloading state. if not overwrite and backend.io.exists(self._get_tuner_fname()): @@ -471,5 +474,43 @@ def get_trial_dir(self, trial_id): utils.create_directory(dirname) return dirname + @staticmethod + def _validate_project_path(directory, project_name): + """Validates that directory and project_name do not contain path traversal. + + Raises: + ValueError: If path traversal sequences or absolute paths are detected. + """ + for name, segment in (("directory", str(directory)), ("project_name", str(project_name))): + if ".." in segment: + raise ValueError( + f"Path traversal is not allowed in {name}. Received: {segment!r}" + ) + if ( + name == "project_name" + and ( + os.path.isabs(segment) + or ntpath.isabs(segment) + or ntpath.splitdrive(segment)[0] + ) + ): + raise ValueError( + f"Absolute paths are not allowed in {name}. Received: {segment!r}" + ) + + @staticmethod + def _validate_tuner_id(tuner_id): + """Validates that tuner_id does not contain path traversal sequences. + + Raises: + ValueError: If path traversal sequences or path separators are detected. + """ + tuner_id_str = str(tuner_id) + if ".." in tuner_id_str or "/" in tuner_id_str or "\\" in tuner_id_str: + raise ValueError( + f"tuner_id cannot contain path separators or traversal sequences. " + f"Received: {tuner_id_str!r}" + ) + def _get_tuner_fname(self): return os.path.join(str(self.project_dir), f"{str(self.tuner_id)}.json") diff --git a/keras_tuner/engine/base_tuner_test.py b/keras_tuner/engine/base_tuner_test.py index b7c137bee..a627bcf39 100644 --- a/keras_tuner/engine/base_tuner_test.py +++ b/keras_tuner/engine/base_tuner_test.py @@ -248,3 +248,117 @@ def _the_func(): _the_func, num_workers=2, wait_for_chief=True ) oracle_client.TIMEOUT = timeout + + + +def test_directory_path_traversal_raises_value_error(tmp_path): + def build_model(hp): + hp.Boolean("a") + + with pytest.raises(ValueError, match="Path traversal"): + gridsearch.GridSearch( + directory="../evil", + hypermodel=build_model, + max_trials=1, + ) + + +def test_project_name_path_traversal_raises_value_error(tmp_path): + def build_model(hp): + hp.Boolean("a") + + with pytest.raises(ValueError, match="Path traversal"): + gridsearch.GridSearch( + directory=tmp_path, + project_name="../../evil", + hypermodel=build_model, + max_trials=1, + ) + + +def test_tuner_id_path_traversal_raises_value_error(tmp_path): + def build_model(hp): + hp.Boolean("a") + + import os + original_tuner_id = os.environ.get("KERASTUNER_TUNER_ID") + try: + os.environ["KERASTUNER_TUNER_ID"] = "../../../evil" + with pytest.raises(ValueError, match="tuner_id"): + gridsearch.GridSearch( + directory=tmp_path, + hypermodel=build_model, + max_trials=1, + ) + finally: + if original_tuner_id is not None: + os.environ["KERASTUNER_TUNER_ID"] = original_tuner_id + else: + os.environ.pop("KERASTUNER_TUNER_ID", None) + + +def test_valid_directory_and_project_name_succeeds(tmp_path): + def build_model(hp): + hp.Boolean("a") + + # These should not raise + tuner = gridsearch.GridSearch( + directory=tmp_path, + project_name="my_project", + hypermodel=build_model, + max_trials=1, + ) + assert tuner.directory == str(tmp_path) + assert tuner.project_name == "my_project" + + +def test_project_name_absolute_path_raises_value_error(tmp_path): + def build_model(hp): + hp.Boolean("a") + + with pytest.raises(ValueError, match="Absolute paths"): + gridsearch.GridSearch( + directory=tmp_path, + project_name="/etc", + hypermodel=build_model, + max_trials=1, + ) + + +@pytest.mark.parametrize( + "project_name", ["C:\\Windows", "C:Windows", "\\Windows"] +) +def test_project_name_windows_absolute_path_raises_value_error( + tmp_path, project_name +): + def build_model(hp): + hp.Boolean("a") + + with pytest.raises(ValueError, match="Absolute paths"): + gridsearch.GridSearch( + directory=tmp_path, + project_name=project_name, + hypermodel=build_model, + max_trials=1, + ) + + +def test_tuner_id_forward_slash_raises_value_error(tmp_path): + def build_model(hp): + hp.Boolean("a") + + import os + original_tuner_id = os.environ.get("KERASTUNER_TUNER_ID") + try: + os.environ["KERASTUNER_TUNER_ID"] = "evil/tuner" + with pytest.raises(ValueError, match="tuner_id"): + gridsearch.GridSearch( + directory=tmp_path, + hypermodel=build_model, + max_trials=1, + ) + finally: + if original_tuner_id is not None: + os.environ["KERASTUNER_TUNER_ID"] = original_tuner_id + else: + os.environ.pop("KERASTUNER_TUNER_ID", None)