Enhance checkpoint loading with error handling #583
Conversation
- Downloads all checkpoint files (as before)
- Subsequent runs: skips existing files, proceeds directly to loading
- No more redundant downloads: each file downloaded only once per rank
- Faster startup: after the initial download, checkpoint loading should be much faster
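As a rough illustration of the skip-if-exists behavior described above, here is a minimal sketch; the helper `_download_one` and the `repo_root`/`keys` shapes are hypothetical stand-ins, not the actual `dcp_checkpoint.py` API:

```python
from pathlib import Path

async def download_all(self, keys: list[str]) -> Path:
    """Download each checkpoint file at most once; skip files already on disk."""
    downloaded, skipped = 0, 0
    for key in keys:
        dst = self.repo_root / key  # hypothetical local layout mirroring S3 keys
        if dst.exists():
            skipped += 1  # warm start: reuse the file from a previous run
            continue
        dst.parent.mkdir(parents=True, exist_ok=True)
        await self._download_one(key, dst)  # hypothetical fetch helper
        downloaded += 1
    tplr.logger.info(f"[DCP][download-all] downloaded={downloaded} skipped={skipped}")
    return self.repo_root
```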
Walkthrough
Refactors checkpoint handling in src/tplr/dcp_checkpoint.py: adds per-file existence checks during downloads, tracks skipped/downloaded files, enhances logging, and consolidates download/load logic to always use download_all. Introduces explicit FileNotFoundError pre-checks and structured error handling for load_local and download_and_load.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant U as Caller
    participant D as download_and_load
    participant DA as download_all
    participant FS as Filesystem
    participant LL as load_local
    Note over D,DA: Start download+load flow (consolidated path)
    U->>D: download_and_load()
    D->>DA: download_all()
    alt File exists locally
        DA->>FS: check file existence (per-key)
        FS-->>DA: exists -> skip (count)
    else Needs download
        DA->>FS: write downloaded file
        FS-->>DA: bytes written
    end
    DA-->>D: return checkpoint_dir
    D->>FS: read extra_metadata.json
    D->>LL: load_local(checkpoint_dir)
    alt Directory or files missing
        LL-->>D: raise FileNotFoundError
        D-->>U: warn and re-raise
    else Other error
        LL-->>D: raise Exception
        D-->>U: error and re-raise
    else Success
        LL-->>D: model/state
        D-->>U: return model/state
    end
    rect rgba(220,245,255,0.5)
        Note right of DA: New: per-file skip + counters
    end
```
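A minimal Python sketch of the warn/error-and-re-raise flow in the diagram, assuming illustrative names (the real signatures in `dcp_checkpoint.py` may differ):

```python
async def download_and_load(self, window: int):
    # Consolidated path: always download (with per-file skip), then load locally.
    ckpt_dir = await self.download_all(window=window)
    try:
        return self.load_local(ckpt_dir)
    except FileNotFoundError:
        tplr.logger.warning(f"[DCP] checkpoint files missing under {ckpt_dir}")
        raise
    except Exception as e:
        tplr.logger.error(f"[DCP] failed to load checkpoint from {ckpt_dir}: {e}")
        raise
```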
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks (1 passed, 2 warnings)
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/tplr/dcp_checkpoint.py (1)
784-799: Don’t hard-code ranks 0–3; derive expected ranks from sidecar.

This will misreport completeness when world_size_at_save != 4. Prior feedback noted this should be dynamic.
Proposed fix:
```diff
-        # Check that rank files exist for validator world size of 4 (ranks 0-3)
-        has_rank_0 = any("__0_0.distcp" in obj.get("Key", "") for obj in objects)
-        has_rank_1 = any("__1_0.distcp" in obj.get("Key", "") for obj in objects)
-        has_rank_2 = any("__2_0.distcp" in obj.get("Key", "") for obj in objects)
-        has_rank_3 = any("__3_0.distcp" in obj.get("Key", "") for obj in objects)
-        all_ranks_present = has_rank_0 and has_rank_1 and has_rank_2 and has_rank_3
+        # Derive expected ranks at save time from sidecar
+        world_size = 4
+        try:
+            sidecar = await self.comms.s3_get_object(
+                key=f"{prefix}extra_metadata.json", bucket=bucket
+            )
+            world_size = int(sidecar.get("world_size_at_save", world_size))
+        except Exception:
+            tplr.logger.debug(
+                f"[DCP][exists] fallback world_size_at_save={world_size} (sidecar read failed)"
+            )
+
+        def rank_present(i: int) -> bool:
+            needle = f"__{i}_0.distcp"
+            return any(needle in obj.get("Key", "") for obj in objects)
+
+        all_ranks_present = all(rank_present(i) for i in range(world_size))
```

Optional: paginate list_objects_v2 like other methods; MaxKeys=100 risks truncation for larger checkpoints.
🧹 Nitpick comments (3)
src/tplr/dcp_checkpoint.py (3)
542-551: Minor logging/flag polish.
- Consider renaming got_any → saw_remote and downloaded_any → downloaded_any_remote to avoid confusion.
- When got_any is False, add a short info/debug log so operators see why None is returned.
Example:
```diff
-        if got_any:
-            return local_dir
-        return None
+        if got_any:
+            return local_dir
+        tplr.logger.info(f"[DCP][download-all] no remote objects under {layout.prefix}/")
+        return None
```
682-707: Pre-check directory is redundant; validate checkpoint integrity instead.

ckpt_dir is created by _local_dir(), so the existence check rarely triggers. More useful: assert presence of DCP metadata and sidecar before load, and include a concise file summary on failure.
Apply:
```diff
-        # Check if checkpoint directory exists
-        if not ckpt_dir.exists():
-            raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_dir}")
-
-        # List available files for debugging
-        files = list(ckpt_dir.iterdir()) if ckpt_dir.exists() else []
+        files = list(ckpt_dir.iterdir())
+        names = {p.name for p in files}
+        required = {".metadata", "extra_metadata.json"}
+        missing = sorted(required - names)
+        if missing:
+            raise FileNotFoundError(
+                f"Checkpoint missing required files {missing} in {ckpt_dir}"
+            )
```
526-534: Throughput log is great; consider clamping dt to avoid inf and formatting spikes.

Minor guard to keep logs clean on very fast downloads.
- f"{_mb(size):.2f} MiB in {dt:.2f}s ({_mb(size) / dt if dt > 0 else 0:.2f} MiB/s)" + f"{_mb(size):.2f} MiB in {max(dt,1e-6):.2f}s " + f"({_mb(size) / max(dt,1e-6):.2f} MiB/s)"
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/tplr/dcp_checkpoint.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: joellidin
PR: tplr-ai/templar#575
File: src/tplr/dcp_checkpoint.py:730-745
Timestamp: 2025-09-09T14:09:02.032Z
Learning: In src/tplr/dcp_checkpoint.py, the current implementation hard-codes validation for ranks 0-3 in the check_checkpoint_exists method, but the user joellidin has acknowledged this should be changed to dynamically read world_size_at_save from the sidecar and validate accordingly. This is deferred for later implementation.
🧬 Code graph analysis (1)
src/tplr/dcp_checkpoint.py (2)
tests/test_dcp_checkpoint.py (4)
debug (114-115), info (111-112), error (123-124), warning (117-118)
src/tplr/comms.py (1)
get (1374-1485)
🪛 GitHub Actions: CI
src/tplr/dcp_checkpoint.py
[error] 1-1: Ruff format check failed. 1 file would be reformatted (src/tplr/dcp_checkpoint.py) by 'uv run ruff format --check'. 122 files already formatted. Run 'uv run ruff format' or 'ruff format' to fix.
🔇 Additional comments (4)
src/tplr/dcp_checkpoint.py (4)
1-1: Fix Ruff formatting to unblock CI.

CI reports this file would be reformatted. Please run uv run ruff format (or ruff format) and commit the result.
513-521: LGTM: Idempotent “skip-if-exists” prevents redundant downloads.

Per-file existence checks with debug logging are a solid improvement for warm starts.
515-521: Guard against path traversal in S3 keys.

dst = self.repo_root / key assumes keys are well-formed. If a compromised bucket produced keys with ../ segments, a downloader that writes to this path could escape repo_root.
Please confirm s3_get_object:
- normalizes and rejects keys containing .. or absolute paths, and
- writes via a tmp file + atomic rename in the target directory.
If not, I can provide a small sanitizer and atomic write helper.
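A small sketch of what that sanitizer and atomic-write helper could look like (both function names are hypothetical; requires Python 3.9+ for `Path.is_relative_to`):

```python
import os
import tempfile
from pathlib import Path

def safe_dst(repo_root: Path, key: str) -> Path:
    """Resolve an S3 key under repo_root, rejecting absolute paths and '..' traversal."""
    if Path(key).is_absolute():
        raise ValueError(f"absolute S3 key rejected: {key!r}")
    dst = (repo_root / key).resolve()
    if not dst.is_relative_to(repo_root.resolve()):
        raise ValueError(f"S3 key escapes repo root: {key!r}")
    return dst

def atomic_write(dst: Path, data: bytes) -> None:
    """Write via a tmp file in the target directory, then atomically rename."""
    dst.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp = tempfile.mkstemp(dir=dst.parent, prefix=f".{dst.name}.")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, dst)  # atomic within the same filesystem
    except BaseException:
        os.unlink(tmp)
        raise
```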
718-746: Honor shared_fs in download_and_load to avoid redundant downloads on shared filesystems.

download_and_load currently always calls download_all(), ignoring shared_fs; this risks N× downloads and partial writes on shared volumes. Call sites found: tests/test_dcp_checkpoint.py:388 (shared_fs=True), src/tplr/neurons.py:558, src/tplr/neurons.py:596, neurons/evaluator.py:869.
Consider restoring the distributed-path conditional:
```diff
-        local_dir = await self.download_all(
-            window=window, prefer_highest_staked=prefer_highest_staked
-        )
+        if shared_fs and _world() > 1:
+            local_dir = await self.download_distributed(
+                window=window, prefer_highest_staked=prefer_highest_staked
+            )
+        else:
+            local_dir = await self.download_all(
+                window=window, prefer_highest_staked=prefer_highest_staked
+            )
```

If distributed was intentionally disabled due to loader issues, log a clear warning when shared_fs=True so operators understand the trade-off.
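If the team opts to keep the consolidated path, that warning could be as simple as this sketch (`_world()` is assumed to return the process world size, as in the diff above):

```python
# Illustrative only: make the trade-off visible when shared_fs is requested
# but every rank still downloads its own copy via download_all().
if shared_fs and _world() > 1:
    tplr.logger.warning(
        "[DCP] shared_fs=True but using download_all(); "
        "each rank will download its own copy of the checkpoint"
    )
```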