Save next_inputs.pt only for last two layers for layerwise save/restore #1431

sugunav14 wants to merge 4 commits into
Conversation
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
📝 Walkthrough

Adds automatic pruning of stale per-layer activation checkpoints during layerwise quantization calibration.

Changes: Checkpoint Pruning Feature
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes. Pre-merge checks: ✅ 6 passed.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/quantization/utils/layerwise_calib.py`:
- Around line 694-700: The prune step currently suppresses only
FileNotFoundError, but other filesystem OSError conditions (permission errors,
transient FS failures) should also be treated as harmless; update the
contextlib.suppress call around os.remove to suppress OSError instead of
FileNotFoundError so that removing next_inputs.pt in the block using layer_idx,
_layer_dir(self.checkpoint_dir, stale), and os.remove is best-effort and won't
crash after a successful commit.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: bef7b418-3b23-4958-8c11-b317211322c6
📒 Files selected for processing (3)
- modelopt/torch/quantization/model_calib.py
- modelopt/torch/quantization/utils/layerwise_calib.py
- tests/unit/torch/quantization/test_sequential_checkpoint.py
```python
# Prune the next_inputs.pt for the layer two back — we only need the
# last two to support resume from the most recent commit point. Done
# after the manifest write so a crash here is harmless.
stale = layer_idx - 2
if stale >= 0:
    with contextlib.suppress(FileNotFoundError):
        os.remove(os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt"))
```
Make pruning truly best-effort for all filesystem failures

At line 699, only `FileNotFoundError` is suppressed. Other `OSError` cases (e.g., permission/transient FS errors) will still fail the run after checkpoint commit, which contradicts the "harmless" intent in the comment.

Suggested fix:

```diff
-        if stale >= 0:
-            with contextlib.suppress(FileNotFoundError):
-                os.remove(os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt"))
+        if stale >= 0:
+            stale_path = os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt")
+            try:
+                os.remove(stale_path)
+            except FileNotFoundError:
+                pass
+            except OSError as e:
+                print_rank_0(f"Checkpoint: failed to prune stale next_inputs at {stale_path}: {e}")
```
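As a standalone illustration of the best-effort pattern this fix describes (a hypothetical `best_effort_remove` helper, not the PR's code; plain `print` stands in for `print_rank_0`):

```python
import os
import tempfile

def best_effort_remove(path: str) -> None:
    """Remove a file without ever failing the caller."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass  # already gone: the common, fully harmless case
    except OSError as e:
        # Permission or transient filesystem errors: log and move on.
        print(f"Checkpoint: failed to prune stale next_inputs at {path}: {e}")

tmp = tempfile.NamedTemporaryFile(delete=False)
tmp.close()
best_effort_remove(tmp.name)               # removes the file
assert not os.path.exists(tmp.name)
best_effort_remove(tmp.name)               # missing file: silently ignored
best_effort_remove(tempfile.gettempdir())  # directory: logged, not raised
```

The point of the two `except` arms is that a missing file is fully expected (e.g. after a resume) while any other filesystem error is surprising enough to be worth a log line, but neither should abort a run whose checkpoint already committed.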
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1431      +/-   ##
==========================================
- Coverage   76.95%   67.12%    -9.84%
==========================================
  Files         478      478
  Lines       51648    51653       +5
==========================================
- Hits        39747    34673    -5074
- Misses      11901    16980    +5079
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
♻️ Duplicate comments (1)

modelopt/torch/quantization/utils/layerwise_calib.py (1)

690-696: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make stale `next_inputs.pt` pruning truly best-effort

Line 695 only suppresses `FileNotFoundError`. Other `OSError` cases can still fail the run after a successful checkpoint commit, which conflicts with the "harmless" pruning intent in lines 691-693.

Suggested minimal fix:

```diff
     stale = layer_idx - 2
     if stale >= 0:
-        with contextlib.suppress(FileNotFoundError):
-            os.remove(os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt"))
+        stale_path = os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt")
+        with contextlib.suppress(OSError):
+            os.remove(stale_path)
```
♻️ Duplicate comments (1)

modelopt/torch/quantization/utils/layerwise_calib.py (1)

643-648: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make pruning best-effort and bound stale index.

At lines 645-647, prune can still fail with `OSError` (permission/transient FS race) after checkpoint commit, which turns cleanup into a run-breaking path. Also, explicitly skipping negative stale indices avoids probing unintended paths for early layers.

Suggested patch:

```diff
 def prune_old_next_inputs(self, layer_idx: int) -> None:
     """Delete the next_inputs.pt two layers back, keeping only the last two on disk."""
-    old = os.path.join(_layer_dir(self.checkpoint_dir, layer_idx - 2), "next_inputs.pt")
-    if os.path.isfile(old):
-        os.remove(old)
+    stale = layer_idx - 2
+    if stale < 0:
+        return
+    old = os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt")
+    try:
+        os.remove(old)
+    except FileNotFoundError:
+        pass
+    except OSError as e:
+        print_rank_0(f"Checkpoint: failed to prune stale next_inputs at {old}: {e}")
```
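A runnable sketch of the patched behavior, under stated assumptions: `_layer_dir` is inlined here as a `layer_{i}` subdirectory (the real helper is not shown in this thread), and `print` stands in for `print_rank_0`:

```python
import os
import tempfile

def prune_old_next_inputs(checkpoint_dir: str, layer_idx: int) -> None:
    """Best-effort prune of the activation copy two layers back."""
    stale = layer_idx - 2
    if stale < 0:
        return  # layers 0 and 1 have nothing two layers back
    old = os.path.join(checkpoint_dir, f"layer_{stale}", "next_inputs.pt")
    try:
        os.remove(old)
    except FileNotFoundError:
        pass  # already pruned, or the layer never saved activations
    except OSError as e:
        print(f"Checkpoint: failed to prune stale next_inputs at {old}: {e}")

ckpt = tempfile.mkdtemp()
os.makedirs(os.path.join(ckpt, "layer_0"))
open(os.path.join(ckpt, "layer_0", "next_inputs.pt"), "w").close()

prune_old_next_inputs(ckpt, 1)  # stale = -1: skipped, no path probed
prune_old_next_inputs(ckpt, 2)  # stale = 0: file removed
assert not os.path.exists(os.path.join(ckpt, "layer_0", "next_inputs.pt"))
prune_old_next_inputs(ckpt, 2)  # idempotent: missing file is fine
```

Replacing the original `os.path.isfile` check with try/except also closes the small race where the file disappears between the check and the remove.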
Context:

```python
print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers")
```

On the new helper:

```python
def prune_old_next_inputs(self, layer_idx: int) -> None:
```

Suggested change (underscore prefix to mark the helper private):

```diff
-    def prune_old_next_inputs(self, layer_idx: int) -> None:
+    def _prune_old_next_inputs(self, layer_idx: int) -> None:
```
What does this PR do?
Type of change: Optimization (no API or on-disk-format change)
Reduces disk usage of layerwise calibration checkpoints by pruning stale next_inputs.pt files after each layer save. Without this change, every non-final transformer layer's next_inputs.pt accumulates on disk, dominating checkpoint size for long calibration runs (e.g. ~250 MB × ~80 layers ≈ 20 GB for a 70B model). On resume we only ever read the most recent one, so older copies are dead weight.
The fix is small and localized: after each layer save commits its manifest, delete the next_inputs.pt for the layer two back. The disk footprint of activation copies becomes O(1) (~2 files) instead of O(N_layers). The static per-layer artifacts (weights.pt, quantizer_state.pt, output_meta.pt) are kept for every layer as before because full_restore needs them.
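To make the O(1) claim concrete, here is a small standalone simulation of the scheme (the `layer_{i}` directory layout and empty placeholder files are assumptions for illustration; this is not the checkpoint code itself):

```python
import os
import tempfile

ckpt = tempfile.mkdtemp()

def save_layer(layer_idx: int) -> None:
    d = os.path.join(ckpt, f"layer_{layer_idx}")
    os.makedirs(d, exist_ok=True)
    # Static per-layer artifacts are kept forever (full_restore needs them).
    for name in ("weights.pt", "quantizer_state.pt", "output_meta.pt"):
        open(os.path.join(d, name), "w").close()
    # Activation copy used to resume at the next layer.
    open(os.path.join(d, "next_inputs.pt"), "w").close()
    # After the "commit", prune the activation copy from two layers back.
    stale = layer_idx - 2
    if stale >= 0:
        try:
            os.remove(os.path.join(ckpt, f"layer_{stale}", "next_inputs.pt"))
        except OSError:
            pass  # best-effort, as in the PR

for i in range(10):
    save_layer(i)

remaining = sorted(
    int(d.split("_")[1])
    for d in os.listdir(ckpt)
    if os.path.exists(os.path.join(ckpt, d, "next_inputs.pt"))
)
assert remaining == [8, 9]  # only the last two activation copies survive
```

All ten layer directories still hold their static artifacts; only the activation copies are bounded at two, regardless of layer count.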
Usage
No API change required — pruning happens automatically when `checkpoint_dir` is used with `layerwise_calibrate`.
Testing
Before your PR is "Ready for review"

- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: N/A

Additional Information
Summary by CodeRabbit

- New Features
- Refactor
- Tests