Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions keras_tuner/engine/base_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


import copy
import ntpath
import os
import traceback
import warnings
Expand Down Expand Up @@ -115,13 +116,15 @@ def __init__(
# Ops and metadata
self.directory = directory or "."
self.project_name = project_name or "untitled_project"
self._validate_project_path(self.directory, self.project_name)
self.oracle._set_project_dir(self.directory, self.project_name)

if overwrite and backend.io.exists(self.project_dir):
backend.io.rmtree(self.project_dir)

# To support tuning distribution.
self.tuner_id = os.environ.get("KERASTUNER_TUNER_ID", "tuner0")
self._validate_tuner_id(self.tuner_id)

# Reloading state.
if not overwrite and backend.io.exists(self._get_tuner_fname()):
Expand Down Expand Up @@ -471,5 +474,43 @@ def get_trial_dir(self, trial_id):
utils.create_directory(dirname)
return dirname

@staticmethod
def _validate_project_path(directory, project_name):
    """Rejects `directory`/`project_name` values that could escape the workspace.

    Both values are converted with `str()` before inspection, so
    `pathlib.Path` objects are accepted.

    Args:
        directory: Root directory given to the tuner. Only checked for
            `".."`; absolute directories are allowed.
        project_name: Name of the project sub-directory. Must not contain
            `".."` and must not be absolute, rooted or drive-prefixed in
            either POSIX or Windows syntax.

    Raises:
        ValueError: If path traversal sequences or absolute paths are
            detected.
    """
    candidates = (("directory", directory), ("project_name", project_name))
    for name, value in candidates:
        segment = str(value)
        if ".." in segment:
            raise ValueError(
                f"Path traversal is not allowed in {name}. "
                f"Received: {segment!r}"
            )
        if name != "project_name":
            continue
        # Since Python 3.13 `ntpath.isabs` no longer treats a path that
        # starts with exactly one slash or backslash as absolute, so
        # "\\Windows" (on POSIX) and "/etc" (on Windows) would slip
        # through. Check the leading separator explicitly, independent of
        # the host platform and interpreter version.
        is_rooted = segment.startswith(("/", "\\"))
        if (
            is_rooted
            or os.path.isabs(segment)
            or ntpath.isabs(segment)
            or ntpath.splitdrive(segment)[0]
        ):
            raise ValueError(
                f"Absolute paths are not allowed in {name}. "
                f"Received: {segment!r}"
            )
Comment on lines +478 to +499
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

Security Vulnerability: Absolute Path Bypass in project_name

While checking for ".." in segment prevents directory traversal using relative path segments, it does not prevent project_name from being an absolute path (e.g., /etc/passwd or C:\Windows).

In Python, os.path.join(directory, project_name) will discard the directory prefix entirely if project_name is an absolute path. When overwrite=True is set, this can lead to arbitrary directory deletion or unauthorized file access outside the intended workspace.

To resolve this, we should explicitly validate that project_name is neither an absolute path nor a drive-relative path.

    def _validate_project_path(directory, project_name):\n        """Validates that directory and project_name do not contain path traversal.\n\n        Raises:\n            ValueError: If path traversal sequences are detected.\n        """\n        for segment in (str(directory), str(project_name)):\n            if ".." in segment:\n                raise ValueError(\n                    f"Path traversal is not allowed. Received: {segment!r}"\n                )\n        proj_str = str(project_name)\n        if (\n            os.path.isabs(proj_str)\n            or os.path.splitdrive(proj_str)[0]\n            or proj_str.startswith(("/", "\\"))\n        ):\n            raise ValueError(\n                f"project_name cannot be an absolute or drive-relative path. "\n                f"Received: {project_name!r}"\n            )

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in the current PR head dbc21de. _validate_project_path() now rejects POSIX absolute paths, Windows rooted paths, and Windows drive-prefixed paths via ntpath, with regression coverage for C:\Windows, C:Windows, and \Windows project names.


@staticmethod
def _validate_tuner_id(tuner_id):
    """Validates that `tuner_id` is safe to use as a file-name stem.

    Args:
        tuner_id: Identifier of this tuner. It is converted with `str()`
            before inspection.

    Raises:
        ValueError: If path traversal sequences, path separators or a
            Windows drive prefix are detected.
    """
    tuner_id_str = str(tuner_id)
    forbidden = ("..", "/", "\\")
    if any(token in tuner_id_str for token in forbidden):
        raise ValueError(
            f"tuner_id cannot contain path separators or traversal sequences. "
            f"Received: {tuner_id_str!r}"
        )
    # A drive-qualified id such as "D:evil" contains no separator, yet on
    # Windows `os.path.join` drops the preceding segments when a later one
    # names a different drive. Reject it on every platform, mirroring the
    # drive check applied to project names.
    if ntpath.splitdrive(tuner_id_str)[0]:
        raise ValueError(
            f"tuner_id cannot start with a drive prefix. "
            f"Received: {tuner_id_str!r}"
        )
Comment on lines +502 to +513
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

Security Vulnerability: Windows Path Separator Bypass in tuner_id

On Windows, os.path.sep is \. Checking os.path.sep in tuner_id_str will not detect forward slashes (/), which are still treated as path separators by many Windows APIs and Python's path manipulation functions.

An attacker could supply a tuner_id containing / (e.g., subdir/tuner_id) to bypass this check on Windows, leading to path traversal or arbitrary file creation.

To fix this, we should explicitly check for both / and \ path separators regardless of the host operating system.

    def _validate_tuner_id(tuner_id):\n        """Validates that tuner_id does not contain path traversal sequences.\n\n        Raises:\n            ValueError: If path traversal sequences or path separators are detected.\n        """\n        tuner_id_str = str(tuner_id)\n        if (\n            ".." in tuner_id_str\n            or "/" in tuner_id_str\n            or "\\" in tuner_id_str\n        ):\n            raise ValueError(\n                f"tuner_id cannot contain path separators or traversal sequences. "\n                f"Received: {tuner_id_str!r}"\n            )

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in the current PR head dbc21de. _validate_tuner_id() checks both / and \ explicitly, and the PR includes regression coverage for a forward-slash KERASTUNER_TUNER_ID.


def _get_tuner_fname(self):
    """Returns the path of this tuner's `<tuner_id>.json` file.

    The file name is built from `self.tuner_id` and placed directly under
    `self.project_dir`.
    """
    json_name = str(self.tuner_id) + ".json"
    return os.path.join(str(self.project_dir), json_name)
114 changes: 114 additions & 0 deletions keras_tuner/engine/base_tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,117 @@ def _the_func():
_the_func, num_workers=2, wait_for_chief=True
)
oracle_client.TIMEOUT = timeout



def test_directory_path_traversal_raises_value_error(tmp_path):
    """A `directory` containing `..` is rejected when the tuner is built."""

    def build_model(hp):
        hp.Boolean("a")

    tuner_kwargs = {
        "directory": "../evil",
        "hypermodel": build_model,
        "max_trials": 1,
    }
    with pytest.raises(ValueError, match="Path traversal"):
        gridsearch.GridSearch(**tuner_kwargs)


def test_project_name_path_traversal_raises_value_error(tmp_path):
    """A `project_name` containing `..` is rejected when the tuner is built."""

    def build_model(hp):
        hp.Boolean("a")

    tuner_kwargs = {
        "directory": tmp_path,
        "project_name": "../../evil",
        "hypermodel": build_model,
        "max_trials": 1,
    }
    with pytest.raises(ValueError, match="Path traversal"):
        gridsearch.GridSearch(**tuner_kwargs)


def test_tuner_id_path_traversal_raises_value_error(tmp_path):
    """A `KERASTUNER_TUNER_ID` containing `..` is rejected at construction."""
    import os
    from unittest import mock

    def build_model(hp):
        hp.Boolean("a")

    # `patch.dict` restores `os.environ` to its prior state on exit (including
    # removing the key if it was unset), replacing the manual save/restore.
    evil_env = {"KERASTUNER_TUNER_ID": "../../../evil"}
    with mock.patch.dict(os.environ, evil_env):
        with pytest.raises(ValueError, match="tuner_id"):
            gridsearch.GridSearch(
                directory=tmp_path,
                hypermodel=build_model,
                max_trials=1,
            )


def test_valid_directory_and_project_name_succeeds(tmp_path):
    """A plain directory and project name pass validation unchanged."""

    def build_model(hp):
        hp.Boolean("a")

    # These should not raise
    tuner = gridsearch.GridSearch(
        directory=tmp_path,
        project_name="my_project",
        hypermodel=build_model,
        max_trials=1,
    )
    # `directory` is passed as a `pathlib.Path`, and a `Path` never compares
    # equal to a `str`, so compare the string forms of both sides.
    assert str(tuner.directory) == str(tmp_path)
    assert tuner.project_name == "my_project"


def test_project_name_absolute_path_raises_value_error(tmp_path):
    """A POSIX absolute `project_name` is rejected when the tuner is built."""

    def build_model(hp):
        hp.Boolean("a")

    tuner_kwargs = {
        "directory": tmp_path,
        "project_name": "/etc",
        "hypermodel": build_model,
        "max_trials": 1,
    }
    with pytest.raises(ValueError, match="Absolute paths"):
        gridsearch.GridSearch(**tuner_kwargs)


@pytest.mark.parametrize(
    "project_name", ["C:\\Windows", "C:Windows", "\\Windows"]
)
def test_project_name_windows_absolute_path_raises_value_error(
    tmp_path, project_name
):
    """Windows-style project names are rejected when the tuner is built.

    The cases cover a drive-absolute path, a drive-relative path and a
    rooted path without a drive.
    """

    def build_model(hp):
        hp.Boolean("a")

    tuner_kwargs = {
        "directory": tmp_path,
        "project_name": project_name,
        "hypermodel": build_model,
        "max_trials": 1,
    }
    with pytest.raises(ValueError, match="Absolute paths"):
        gridsearch.GridSearch(**tuner_kwargs)


def test_tuner_id_forward_slash_raises_value_error(tmp_path):
    """A `KERASTUNER_TUNER_ID` containing `/` is rejected at construction."""
    import os
    from unittest import mock

    def build_model(hp):
        hp.Boolean("a")

    # `patch.dict` restores `os.environ` to its prior state on exit (including
    # removing the key if it was unset), replacing the manual save/restore.
    evil_env = {"KERASTUNER_TUNER_ID": "evil/tuner"}
    with mock.patch.dict(os.environ, evil_env):
        with pytest.raises(ValueError, match="tuner_id"):
            gridsearch.GridSearch(
                directory=tmp_path,
                hypermodel=build_model,
                max_trials=1,
            )