Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 18 additions & 36 deletions dream-server/bin/dream-host-agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,40 +185,6 @@ def _precreate_data_dirs(service_id: str):
logger.warning("Failed to pre-create %s: %s", dir_path, e)


def _resolve_setup_hook(ext_dir: Path) -> Path | None:
"""Read manifest to find setup_hook path. Returns None if no hook defined."""
manifest_path = None
for name in ("manifest.yaml", "manifest.yml"):
candidate = ext_dir / name
if candidate.exists():
manifest_path = candidate
break
if manifest_path is None:
return None
try:
import yaml
manifest = yaml.safe_load(manifest_path.read_text(encoding="utf-8"))
except (ImportError, OSError):
return None
if not isinstance(manifest, dict):
return None
service_def = manifest.get("service", {})
if not isinstance(service_def, dict):
return None
setup_hook = service_def.get("setup_hook", "")
if not isinstance(setup_hook, str) or not setup_hook:
return None
hook_path = (ext_dir / setup_hook).resolve()
try:
hook_path.relative_to(ext_dir.resolve())
except ValueError:
logger.warning("Path traversal attempt in setup_hook for %s: %s", ext_dir.name, setup_hook)
return None
if not hook_path.is_file():
return None
return hook_path


def docker_compose_action(service_id: str, action: str) -> tuple:
flags = resolve_compose_flags()
if action == "start":
Expand Down Expand Up @@ -931,11 +897,27 @@ def _run_install():
if run_setup_hook:
_write_progress(service_id, "setup_hook", "Running setup...")
ext_dir = USER_EXTENSIONS_DIR / service_id
hook_path = _resolve_setup_hook(ext_dir)
hook_path = _resolve_hook(ext_dir, "post_install")
if hook_path:
# Minimal allowlist env — mirror _execute_hook (L856-866)
# to prevent leaking host-agent secrets to extension scripts.
manifest = _read_manifest(ext_dir)
service_def = manifest.get("service", {}) if manifest else {}
if not isinstance(service_def, dict):
service_def = {}
hook_env = {
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
"HOME": os.environ.get("HOME", ""),
"SERVICE_ID": service_id,
"SERVICE_PORT": str(service_def.get("port", 0)),
"SERVICE_DATA_DIR": str(DATA_DIR / service_id),
"DREAM_VERSION": DREAM_VERSION,
"GPU_BACKEND": GPU_BACKEND,
"HOOK_NAME": "post_install",
}
result = subprocess.run(
["bash", str(hook_path), str(INSTALL_DIR), GPU_BACKEND],
cwd=str(ext_dir),
cwd=str(ext_dir), env=hook_env,
capture_output=True, text=True,
timeout=SUBPROCESS_TIMEOUT_START,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,50 @@ def test_missing_cache_file_is_noop(self, tmp_path, monkeypatch):
monkeypatch.setattr(_mod, "INSTALL_DIR", install_dir)

invalidate_compose_cache() # must not raise


# --- Install setup-hook env allowlist (regression) ---
#
# Locks in the fix that strips host-agent secrets from the env passed to
# extension setup hooks during _handle_install. A source-level check is used
# because the subprocess.run call lives inside a nested closure started on a
# daemon thread, which makes dynamic mocking fragile.


class TestInstallHookEnvAllowlist:

def _install_source(self):
import inspect
return inspect.getsource(_mod.AgentHandler._handle_install)

def test_setup_hook_subprocess_run_passes_env_kwarg(self):
src = self._install_source()
assert "env=hook_env" in src, (
"setup_hook subprocess.run must pass env=hook_env "
"(regression: do not fall back to inheriting os.environ)"
)

def test_setup_hook_env_excludes_host_agent_secrets(self):
src = self._install_source()
for secret in ("AGENT_API_KEY", "DREAM_AGENT_KEY", "DASHBOARD_API_KEY"):
assert secret not in src, (
f"_handle_install must not reference {secret}; "
"extension setup hooks must not receive host-agent secrets"
)

def test_setup_hook_env_contains_allowlist_keys(self):
src = self._install_source()
for key in (
"PATH", "HOME", "SERVICE_ID", "SERVICE_PORT",
"SERVICE_DATA_DIR", "DREAM_VERSION", "GPU_BACKEND", "HOOK_NAME",
):
assert f'"{key}"' in src, (
f"setup_hook env allowlist missing required key {key}"
)

def test_setup_hook_uses_resolve_hook_with_post_install(self):
src = self._install_source()
assert '_resolve_hook(ext_dir, "post_install")' in src, (
"setup_hook must use _resolve_hook(..., 'post_install'); "
"the legacy _resolve_setup_hook has been removed"
)
Loading