
Save next_inputs.pt only for last two layers for layerwise save/restore#1431

Open
sugunav14 wants to merge 4 commits into main from svelury/save-last-2-layer-inputs

Conversation


@sugunav14 sugunav14 commented May 11, 2026

What does this PR do?

Type of change: Optimization (no API or on-disk-format change)

Reduces disk usage of layerwise calibration checkpoints by pruning stale next_inputs.pt files after each layer save. Without this change, every non-final transformer layer's next_inputs.pt accumulates on disk, dominating checkpoint size for long calibration runs (e.g. ~250 MB × ~80 layers ≈ 20 GB for a 70B model). On resume we only ever read the most recent one, so older copies are dead weight.

The fix is small and localized: after each layer save commits its manifest, delete the next_inputs.pt for the layer two back. The disk footprint of activation copies becomes O(1) (~2 files) instead of O(N_layers). The static per-layer artifacts (weights.pt, quantizer_state.pt, output_meta.pt) are kept for every layer as before because full_restore needs them.
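The retention pattern can be sketched without any modelopt internals. The `layer_<i>/next_inputs.pt` layout and the helper below are illustrative assumptions, not the PR's actual code:

```python
import contextlib
import os
import tempfile

def prune_old_next_inputs(checkpoint_dir: str, layer_idx: int) -> None:
    # Keep only the last two activation snapshots: once layer_idx is saved,
    # the snapshot from layer_idx - 2 is no longer needed for resume.
    stale = layer_idx - 2
    if stale >= 0:
        with contextlib.suppress(OSError):  # best-effort cleanup
            os.remove(os.path.join(checkpoint_dir, f"layer_{stale}", "next_inputs.pt"))

root = tempfile.mkdtemp()
for i in range(5):  # simulate saving five layers in order
    layer_dir = os.path.join(root, f"layer_{i}")
    os.makedirs(layer_dir)
    open(os.path.join(layer_dir, "next_inputs.pt"), "wb").close()
    prune_old_next_inputs(root, i)

kept = [i for i in range(5)
        if os.path.exists(os.path.join(root, f"layer_{i}", "next_inputs.pt"))]
print(kept)  # [3, 4]: the footprint stays at two files regardless of depth
```

The same loop with pruning disabled would leave all five files on disk, which is exactly the O(N_layers) growth this PR removes.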

Usage

No API change required — pruning happens automatically when checkpoint_dir is used with layerwise_calibrate.

  mtq.calibrate(
      model,
      algorithm={
          "method": "gptq",
          "use_sequential": True,
          "checkpoint_dir": "/path/to/checkpoint",
      },
      forward_loop=forward_loop,
  )

Testing

  • test_full_run_creates_checkpoints — extended to assert layer 0's next_inputs.pt is gone after a 3-layer run while layer 1's remains.
  • test_only_last_two_next_inputs_kept (new) — 5-layer walk-through verifying only the most recent next_inputs.pt is retained and all static files survive for every layer.
  • test_resume_matches_full_run — uses a calib_func that raises during layer 1 to simulate a real crash; asserts the resumed run produces bit-identical weights to a reference full run.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Enhanced resilience for quantization workflows with improved crash recovery and resumption capabilities
  • Refactor

    • Optimized disk storage footprint during quantization calibration by reducing checkpoint file retention
  • Tests

    • Extended test coverage for checkpoint file pruning and storage management
    • Added tests for crash recovery and job resumption in quantization calibration

Review Change Stack

Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
@sugunav14 sugunav14 requested a review from a team as a code owner May 11, 2026 21:34
@sugunav14 sugunav14 requested a review from Fridah-nv May 11, 2026 21:34

coderabbitai Bot commented May 11, 2026

📝 Walkthrough

Adds automatic pruning of stale per-layer activation checkpoints during layerwise quantization calibration. A new prune_old_next_inputs helper method deletes activation snapshots two layers back after checkpoint save completes, ensuring at most two activation next_inputs.pt files remain on disk. Tests validate pruning patterns, verify static checkpoint files are preserved, and confirm crash-and-resume capability with intermediate checkpoints.

Changes

Checkpoint Pruning Feature

  • Checkpoint Save and Pruning Logic (modelopt/torch/quantization/utils/layerwise_calib.py):
    Adds a _CheckpointState.prune_old_next_inputs(layer_idx) method that deletes next_inputs.pt for layer_idx - 2 if it exists. Updates the save method documentation to specify retention of the two most recent layer activations, and calls the pruning helper after the manifest write to enforce disk cleanup.
  • Test Validation of Pruning and Resumption (tests/unit/torch/quantization/test_sequential_checkpoint.py):
    Adds a pytest import for exception assertions. Extends test_full_run_creates_checkpoints with assertions for next_inputs.pt pruning across layer directories. Introduces test_only_last_two_next_inputs_kept to verify the pruning pattern while confirming static checkpoint files persist. Updates test_resume_matches_full_run to simulate a calibration interruption, validate the intermediate checkpoint manifest, and verify resumption completes successfully.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: 6 of 6 passed
  • Description Check ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: the title accurately describes the main change, an optimization that reduces disk usage by pruning old next_inputs.pt files and keeping only the last two layers' activation copies for layerwise save/restore.
  • Docstring Coverage ✅ Passed: docstring coverage is 100.00%, above the required 80.00% threshold.
  • Linked Issues check ✅ Passed: check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅ Passed: check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns ✅ Passed: all security anti-patterns reviewed and cleared. torch.load calls with weights_only=False have proper inline comments per SECURITY.md guidelines. No other patterns found.





@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: bef7b418-3b23-4958-8c11-b317211322c6

📥 Commits

Reviewing files that changed from the base of the PR and between d30ebbd and 8c60a38.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/utils/layerwise_calib.py
  • tests/unit/torch/quantization/test_sequential_checkpoint.py

Comment on lines +694 to +700
        # Prune the next_inputs.pt for the layer two back — we only need the
        # last two to support resume from the most recent commit point. Done
        # after the manifest write so a crash here is harmless.
        stale = layer_idx - 2
        if stale >= 0:
            with contextlib.suppress(FileNotFoundError):
                os.remove(os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt"))

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make pruning truly best-effort for all filesystem failures

At Line 699, only FileNotFoundError is suppressed. Other OSError cases (e.g., permission/transient FS errors) will still fail the run after checkpoint commit, which contradicts the “harmless” intent in the comment.

Suggested fix
-        if stale >= 0:
-            with contextlib.suppress(FileNotFoundError):
-                os.remove(os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt"))
+        if stale >= 0:
+            stale_path = os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt")
+            try:
+                os.remove(stale_path)
+            except FileNotFoundError:
+                pass
+            except OSError as e:
+                print_rank_0(f"Checkpoint: failed to prune stale next_inputs at {stale_path}: {e}")
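For context on why the reviewer broadens the handled exception: FileNotFoundError and PermissionError are both subclasses of OSError, so catching (or suppressing) OSError also covers permission and transient-filesystem failures that a FileNotFoundError-only handler would re-raise. A quick self-contained check:

```python
import contextlib
import errno

# Both concrete errors sit under OSError in Python's exception hierarchy.
assert issubclass(FileNotFoundError, OSError)
assert issubclass(PermissionError, OSError)

# suppress(OSError) swallows a simulated permission failure that
# suppress(FileNotFoundError) would have propagated after commit.
with contextlib.suppress(OSError):
    raise PermissionError(errno.EACCES, "simulated permission failure")
print("pruning stayed best-effort")
```

This is why either variant of the suggested fix (suppress OSError, or catch-and-log it) keeps cleanup from crashing a run whose checkpoint has already committed.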


github-actions Bot commented May 11, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1431/

Built to branch gh-pages at 2026-05-11 23:06 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>

codecov Bot commented May 11, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 67.12%. Comparing base (f591131) to head (92e144d).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1431      +/-   ##
==========================================
- Coverage   76.95%   67.12%   -9.84%     
==========================================
  Files         478      478              
  Lines       51648    51653       +5     
==========================================
- Hits        39747    34673    -5074     
- Misses      11901    16980    +5079     
Flag Coverage Δ
examples 41.62% <20.00%> (+0.74%) ⬆️
gpu 27.27% <20.00%> (-33.16%) ⬇️
regression 15.14% <20.00%> (+0.01%) ⬆️
unit 52.74% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
modelopt/torch/quantization/utils/layerwise_calib.py (1)

690-696: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make stale next_inputs.pt pruning truly best-effort

Line 695 only suppresses FileNotFoundError. Other OSError cases can still fail the run after a successful checkpoint commit, which conflicts with the “harmless” pruning intent in Lines 691-693.

Suggested minimal fix
         stale = layer_idx - 2
         if stale >= 0:
-            with contextlib.suppress(FileNotFoundError):
-                os.remove(os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt"))
+            stale_path = os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt")
+            with contextlib.suppress(OSError):
+                os.remove(stale_path)

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 277248d0-f52a-40b4-9fb5-04d79cb84c77

📥 Commits

Reviewing files that changed from the base of the PR and between 8c60a38 and 256282c.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/utils/layerwise_calib.py

@sugunav14 sugunav14 requested a review from realAsma May 11, 2026 21:56
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
modelopt/torch/quantization/utils/layerwise_calib.py (1)

643-648: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make pruning best-effort and bound stale index.

At Line 645–Line 647, prune can still fail with OSError (permission/transient FS race) after checkpoint commit, which turns cleanup into a run-breaking path. Also, explicitly skipping negative stale indices avoids probing unintended paths for early layers.

Suggested patch
     def prune_old_next_inputs(self, layer_idx: int) -> None:
         """Delete the next_inputs.pt two layers back, keeping only the last two on disk."""
-        old = os.path.join(_layer_dir(self.checkpoint_dir, layer_idx - 2), "next_inputs.pt")
-        if os.path.isfile(old):
-            os.remove(old)
+        stale = layer_idx - 2
+        if stale < 0:
+            return
+        old = os.path.join(_layer_dir(self.checkpoint_dir, stale), "next_inputs.pt")
+        try:
+            os.remove(old)
+        except FileNotFoundError:
+            pass
+        except OSError as e:
+            print_rank_0(f"Checkpoint: failed to prune stale next_inputs at {old}: {e}")

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 57ac582d-1083-4f47-8981-18ac275ce8af

📥 Commits

Reviewing files that changed from the base of the PR and between 256282c and 92e144d.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/utils/layerwise_calib.py


@shengliangxu shengliangxu left a comment


LGTM


print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers")

def prune_old_next_inputs(self, layer_idx: int) -> None:

Suggested change
-    def prune_old_next_inputs(self, layer_idx: int) -> None:
+    def _prune_old_next_inputs(self, layer_idx: int) -> None:
