Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 102 additions & 45 deletions apps/Portal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import stat
import subprocess
import sys
import tempfile
import time
import threading
import tomllib
Expand Down Expand Up @@ -100,6 +101,11 @@
TRAINPILOT_BIN = Path("/opt/pilot/apps/TrainPilot/trainpilot.sh")
TRAINPILOT_BUNDLED_TOML = Path("/opt/pilot/apps/TrainPilot/newlora.toml")
TRAINPILOT_PERSISTENT_TOML = WORKSPACE_ROOT / "config" / "trainpilot" / "newlora.toml"
# Directory roots that local paths read from the TrainPilot TOML must resolve
# inside. The tuple is passed as ``roots`` to _resolve_local_path_from_roots by
# the model-check endpoint. Every entry is resolved once, when this module is
# imported, so a later change to HOME is not picked up.
_TRAINPILOT_LOCAL_PATH_ROOTS = (
    WORKSPACE_ROOT.resolve(),
    Path("/opt").resolve(),
    # Falls back to /root when HOME is not set in the environment.
    Path(os.environ.get("HOME", "/root")).expanduser().resolve(),
)
_tp_proc: Optional[subprocess.Popen] = None
_tp_logs: deque[str] = deque(maxlen=4000)
_tp_output_dir: Optional[Path] = None
Expand Down Expand Up @@ -170,15 +176,56 @@ def _update_model_pull_job(job: ModelPullJob, line: str) -> None:
pass


def _run_model_pull_job(job: ModelPullJob, cmd: list[str]) -> None:
def _cleanup_temp_paths(paths: list[Path]) -> None:
for path in paths:
try:
path.unlink(missing_ok=True)
except Exception:
pass


def _build_model_pull_command(name: str) -> tuple[list[str], dict[str, str], list[Path]]:
    """Build the command, environment and cleanup list for pulling one model.

    The manifest line for *name* is written to a one-line temporary manifest
    under ``CONFIG_DIR / "model-pulls"``, and ``MODELS_MANIFEST`` in the
    returned environment points at that file (presumably so that
    ``get-models.sh pull-all`` only processes this entry; confirm against the
    script).

    Args:
        name: Model name to look up via ``models_service.manifest_line_for_name``.

    Returns:
        A ``(cmd, env, cleanup_paths)`` tuple: the argv list to execute, a copy
        of the process environment with ``MODELS_MANIFEST`` set, and the
        temporary files the caller must delete once the command has finished.

    Raises:
        Whatever ``models_service.manifest_line_for_name`` raises; the callers
        in this file treat ``KeyError`` as "unknown model".
    """
    manifest_line = models_service.manifest_line_for_name(
        name,
        MANIFEST,
        DEFAULT_MANIFEST,
        MODELS_DIR,
        CONFIG_DIR,
    )
    temp_dir = CONFIG_DIR / "model-pulls"
    temp_dir.mkdir(parents=True, exist_ok=True)

    manifest_path: Optional[Path] = None
    try:
        with tempfile.NamedTemporaryFile(
            "w",
            dir=temp_dir,
            prefix="pull-",
            suffix=".manifest",
            encoding="utf-8",
            delete=False,
        ) as handle:
            manifest_path = Path(handle.name)
            handle.write(manifest_line + "\n")
    except Exception:
        # delete=False leaves the file on disk, and the caller never receives
        # the path when writing/flushing fails, so remove it here.
        if manifest_path is not None:
            _cleanup_temp_paths([manifest_path])
        raise

    env = os.environ.copy()
    env["MODELS_MANIFEST"] = str(manifest_path)
    env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0")
    return ["/opt/pilot/get-models.sh", "pull-all"], env, [manifest_path]


def _run_model_pull_job(
job: ModelPullJob,
cmd: list[str],
env: Optional[dict[str, str]] = None,
cleanup_paths: Optional[list[Path]] = None,
) -> None:
merged_env = os.environ.copy()
merged_env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0")
if env:
merged_env.update(env)
try:
env = os.environ.copy()
env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0")
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env=env,
env=merged_env,
bufsize=0,
)
job.pid = proc.pid
Expand Down Expand Up @@ -215,6 +262,7 @@ def _run_model_pull_job(job: ModelPullJob, cmd: list[str]) -> None:
job.error = str(e)
job.updated_at = time.time()
finally:
_cleanup_temp_paths(cleanup_paths or [])
with _model_pull_lock:
_model_pull_jobs[job.name] = job

Expand Down Expand Up @@ -575,10 +623,6 @@ class DatasetEntry(BaseModel):
path: str


class TrainPilotModelCheckRequest(BaseModel):
toml_path: str = ""


def _toml_find_first_str(data, key: str) -> Optional[str]:
if isinstance(data, dict):
v = data.get(key)
Expand Down Expand Up @@ -1082,13 +1126,31 @@ def _clean_name(name: str) -> str:
return cleaned or "dataset"


def _resolve_under_root(root: Path, candidate: Path) -> Path:
def _path_is_within_root(candidate: Path, root: Path) -> bool:
root_resolved = os.path.realpath(str(root))
resolved = os.path.realpath(str(candidate))
root_with_sep = os.path.join(root_resolved, "")
if resolved != root_resolved and not resolved.startswith(root_with_sep):
return resolved == root_resolved or resolved.startswith(root_with_sep)


def _resolve_under_root(root: Path, candidate: Path) -> Path:
resolved = Path(os.path.realpath(str(candidate)))
if not _path_is_within_root(resolved, root):
raise HTTPException(status_code=400, detail="Invalid path")
return Path(resolved)
return resolved


def _resolve_local_path_from_roots(raw_value: str, *, field: str, roots: tuple[Path, ...]) -> Path:
    """Validate a configured local path and return its canonical form.

    The value is stripped, ``~`` and environment variables are expanded, and
    the result must be absolute and resolve (via ``os.path.realpath``) to a
    location inside at least one of *roots*.

    Args:
        raw_value: Path string as written in the configuration.
        field: Name of the configuration key, used in error messages.
        roots: Directories the resolved path is allowed to live under.

    Returns:
        The resolved path. It is not checked for existence.

    Raises:
        HTTPException: 400 when the value is empty, contains a NUL byte, is
            not absolute after expansion, or resolves outside every root.
    """
    raw = (raw_value or "").strip()
    if not raw:
        raise HTTPException(status_code=400, detail=f"{field} is required")
    if "\x00" in raw:
        # The path functions below can raise ValueError on an embedded NUL;
        # reject it here so the caller sees a 400 instead of an unhandled error.
        raise HTTPException(status_code=400, detail=f"{field} contains an invalid character")
    expanded = os.path.expandvars(os.path.expanduser(raw))
    if not os.path.isabs(expanded):
        raise HTTPException(status_code=400, detail=f"{field} must be an absolute local path")
    resolved = Path(os.path.realpath(expanded))
    if not any(_path_is_within_root(resolved, root) for root in roots):
        raise HTTPException(status_code=400, detail=f"{field} must stay within approved directories")
    return resolved


def _dataset_dir(name: str) -> Path:
Expand Down Expand Up @@ -1351,19 +1413,18 @@ def tagpilot_save_item(

@app.post("/api/models/{name}/pull")
def pull_model(name: str):
models_service.ensure_manifest(MANIFEST, DEFAULT_MANIFEST, MODELS_DIR, CONFIG_DIR)
entries = models_service.parse_manifest(MANIFEST, DEFAULT_MANIFEST, MODELS_DIR, CONFIG_DIR)
if not any(entry.name == name for entry in entries):
try:
cmd, env, cleanup_paths = _build_model_pull_command(name)
except KeyError:
raise HTTPException(status_code=404, detail="Unknown model")
cmd = ["/opt/pilot/get-models.sh", "pull", name]
print(f"[models] pull start name={name} cmd={' '.join(cmd)}", file=sys.stderr)
# Use existing CLI for consistency
try:
result = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
env=env,
check=True,
)
output = result.stdout or ""
Expand All @@ -1375,15 +1436,17 @@ def pull_model(name: str):
tail = output[-4000:] if len(output) > 4000 else output
print(f"[models] pull failed name={name} output_tail={tail!r}", file=sys.stderr)
raise HTTPException(status_code=500, detail="Model pull failed")
finally:
_cleanup_temp_paths(cleanup_paths)


@app.post("/api/models/{name}/pull/start")
def pull_model_start(name: str):
"""Start a model pull in the background (used by UI for progress updates)."""
_cleanup_model_pull_jobs()
models_service.ensure_manifest(MANIFEST, DEFAULT_MANIFEST, MODELS_DIR, CONFIG_DIR)
entries = models_service.parse_manifest(MANIFEST, DEFAULT_MANIFEST, MODELS_DIR, CONFIG_DIR)
if not any(e.name == name for e in entries):
try:
cmd, env, cleanup_paths = _build_model_pull_command(name)
except KeyError:
raise HTTPException(status_code=404, detail="Unknown model")

with _model_pull_lock:
Expand All @@ -1393,8 +1456,7 @@ def pull_model_start(name: str):
job = ModelPullJob(name=name)
_model_pull_jobs[name] = job

cmd = ["/opt/pilot/get-models.sh", "pull", name]
threading.Thread(target=_run_model_pull_job, args=(job, cmd), daemon=True).start()
threading.Thread(target=_run_model_pull_job, args=(job, cmd, env, cleanup_paths), daemon=True).start()
return _model_pull_job_to_dict(job)


Expand Down Expand Up @@ -2805,7 +2867,6 @@ class TrainPilotRequest(BaseModel):
dataset_name: str
output_name: str
profile: str = "regular"
toml_path: str = ""


def _tp_reader(proc: subprocess.Popen):
Expand All @@ -2831,21 +2892,6 @@ def _ensure_trainpilot_toml() -> Path:
raise HTTPException(status_code=500, detail=f"Bundled TrainPilot TOML not found at {TRAINPILOT_BUNDLED_TOML}")


def _resolve_trainpilot_toml_path(raw_path: str = "") -> Path:
raw = (raw_path or "").strip()
if not raw:
return _ensure_trainpilot_toml()
candidate = Path(raw)
if candidate == TRAINPILOT_BUNDLED_TOML:
candidate = _ensure_trainpilot_toml()
elif not candidate.is_absolute():
candidate = WORKSPACE_ROOT / candidate
candidate = _resolve_under_root(WORKSPACE_ROOT, candidate)
if candidate.suffix.lower() != ".toml":
raise HTTPException(status_code=400, detail="TrainPilot config must be a TOML file")
return candidate


@app.post("/api/trainpilot/start")
def trainpilot_start(req: TrainPilotRequest):
global _tp_proc
Expand All @@ -2865,9 +2911,7 @@ def trainpilot_start(req: TrainPilotRequest):
profile = req.profile.strip() or "regular"
if profile not in ("quick_test", "regular", "high_quality"):
raise HTTPException(status_code=400, detail="Invalid profile")
toml_path = _resolve_trainpilot_toml_path(req.toml_path)
if not toml_path.exists():
raise HTTPException(status_code=400, detail=f"TOML not found: {toml_path}")
toml_path = _ensure_trainpilot_toml()

# Add debugging info to logs
_tp_logs.append(f"=== Starting TrainPilot at {datetime.now().isoformat()} ===")
Expand Down Expand Up @@ -2928,15 +2972,13 @@ def trainpilot_stop():


@app.post("/api/trainpilot/model-check")
def trainpilot_model_check(req: TrainPilotModelCheckRequest):
def trainpilot_model_check():
"""
Parse the selected TrainPilot TOML and check that checkpoint + VAE files exist.
If they are missing and can be mapped to a manifest entry, return the model name
so the UI can offer to download with progress.
"""
toml_path = _resolve_trainpilot_toml_path(req.toml_path)
if not toml_path.exists():
raise HTTPException(status_code=404, detail=f"TOML not found: {toml_path}")
toml_path = _ensure_trainpilot_toml()

try:
raw = toml_path.read_bytes()
Expand Down Expand Up @@ -2968,7 +3010,22 @@ def check_one(kind: str, key: str, value: Optional[str]) -> dict:
"model_name": None,
"reason": "Not a local file path",
}
p = Path(value)
try:
p = _resolve_local_path_from_roots(
value,
field=key,
roots=_TRAINPILOT_LOCAL_PATH_ROOTS,
)
except HTTPException as exc:
return {
"kind": kind,
"key": key,
"value": value,
"is_local_path": True,
"exists": False,
"model_name": None,
"reason": str(exc.detail),
}
exists = p.exists()
model_name = None
if not exists:
Expand Down
Loading
Loading