
Conversation

@erfanMhi erfanMhi (Collaborator) commented Dec 17, 2025

Delta Checkpointing: ~90% Bandwidth Reduction for Weight Communication

Implements sparse delta checkpointing to dramatically reduce upload/download bandwidth between trainer and miners/validators.

  • Sparse delta encoding: Transmits only non-zero weight changes in COO format, achieving ~90% bandwidth reduction between base checkpoints (see the sketch after this list)
  • Zstd compression: Additional compression on delta files for further bandwidth savings
  • Intelligent upload: Automatically uploads FULL checkpoints every 10 windows and DELTA checkpoints for intermediate windows, with fallback to FULL on failures
  • Efficient reconstruction: Miners/validators download small deltas and reconstruct full weights locally from cached base checkpoints
  • Hash verification: SHA256 hashing ensures bit-exact reconstruction and corruption detection
  • Enhanced monitoring: UploadResult metrics track sparsity, compression, and throughput; retry logic ensures reliability
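
A minimal sketch of the core idea, assuming single float32 tensors for brevity (the actual compute_sparse_delta / upload_delta code operates on full state dicts and handles dtype conversion, retries, and metadata): store only the indices and values of entries that changed, compress the serialized delta with zstd, and keep a SHA256 of the target weights so reconstruction can be verified bit-exactly.

import hashlib
import io

import torch
import zstandard as zstd


def sparse_delta(base: torch.Tensor, new: torch.Tensor, threshold: float = 0.0):
    """Return flat indices and values for entries that changed by more than threshold."""
    diff = (new.float() - base.float()).flatten()
    idx = torch.nonzero(diff.abs() > threshold, as_tuple=False).flatten()
    return idx.to(torch.int32), diff[idx]


def apply_delta(base: torch.Tensor, idx: torch.Tensor, vals: torch.Tensor) -> torch.Tensor:
    """Rebuild the new tensor from the cached base plus the sparse (COO-style) delta."""
    out = base.float().flatten().clone()
    out[idx.long()] += vals
    return out.view(base.shape).to(base.dtype)


base = torch.full((4, 4), 0.25)
new = base.clone()
new[0, 0] += 0.5                      # only one of 16 weights changes

idx, vals = sparse_delta(base, new)   # 1 index + 1 value instead of 16 values

# Serialize the delta, compress it with zstd, and record a SHA256 of the target weights.
buf = io.BytesIO()
torch.save({"idx": idx, "vals": vals}, buf)
compressed = zstd.ZstdCompressor(level=3).compress(buf.getvalue())
target_hash = hashlib.sha256(new.numpy().tobytes()).hexdigest()

# Receiver side: decompress, apply to the cached base, and verify bit-exact reconstruction.
payload = torch.load(io.BytesIO(zstd.ZstdDecompressor().decompress(compressed)))
rebuilt = apply_delta(base, payload["idx"], payload["vals"])
assert hashlib.sha256(rebuilt.numpy().tobytes()).hexdigest() == target_hash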

Summary by CodeRabbit

Release Notes

  • New Features

    • Added delta checkpoint support for incremental model uploads with automatic sparse encoding and fallback to full checkpoints.
    • Introduced parameter change tracking to monitor parameter updates during training, including per-layer and per-component metrics.
    • Added sparse quality analysis to validate the fidelity of sparse weight updates.
  • Enhancements

    • Improved checkpoint management with per-subprocess monitoring labels for better training visibility.
    • Extended configuration with new environment-driven parameters for parameter tracking behavior and delta checkpoint settings.
    • Enhanced dependencies with zstandard for checkpoint compression.


…; now trainer runs as long as replay buffer is not empty; pause confirmation is done as soon as GPU memory is freed now.
…parse weight updates and reconstruction; includes new utilities for computing and applying sparse deltas, along with integration tests for the delta upload and download flow.
…locks; improve comments for clarity and ensure delta checkpointing logic is consistent across components
…istent monitoring setup across subprocesses; update TrainerNeuron and upload_worker to utilize this new helper for improved logging and error handling.
…lities; implement robust error handling for model weight loading in CheckpointManager, CheckpointPublisher, and upload_worker, ensuring better resilience against missing or malformed checkpoint files.
…ints; enhance upload metrics with compression details for improved efficiency and monitoring.
…pdate ruff configuration to exclude research directory.
… enhanced upload metrics; refactor upload methods to return detailed results including timing and size metrics, improving monitoring and error handling during checkpoint uploads.
…to FULL; enhance error handling and logging for upload processes, improving reliability and monitoring of checkpoint uploads.
@cursor cursor bot commented Dec 17, 2025

You have run out of free Bugbot PR reviews for this billing cycle. This will reset on January 4.

To receive reviews on all of your PRs, visit the Cursor dashboard to activate Pro and start your 14-day free trial.

@coderabbitai coderabbitai bot (Contributor) commented Dec 17, 2025

Walkthrough

This PR introduces delta checkpoint support enabling efficient model updates through sparse weight deltas, adds parameter-change tracking across training steps, implements sparse quality analysis for update fidelity, and extends monitoring infrastructure for subprocess coordination and metrics logging.

Changes

Cohort and file(s), each followed by a summary of the changes:
Delta checkpoint core
grail/infrastructure/delta_checkpoint.py, grail/infrastructure/checkpoint_consumer.py
New delta_checkpoint module provides sparse delta computation, application, hashing, and size estimation. CheckpointMetadata extended with checkpoint_type, base_window, weights_hash fields and is_delta() helper. get_checkpoint enhanced to handle delta reconstruction via _handle_delta_checkpoint and _reconstruct_from_delta.
Checkpoint publishing & upload
grail/trainer/checkpoint_publisher.py, grail/trainer/upload_worker.py
Introduced UploadError, UploadTiming, UploadResult dataclasses for structured upload outcomes with granular timing. Added upload_delta method for sparse delta publishing. _compute_keep_windows extended for delta base retention. upload_worker_loop now supports FULL/DELTA selection with base-interval logic and delta state caching.
Parameter and quality tracking
grail/trainer/param_tracker.py, grail/trainer/sparse_quality.py
New param_tracker module provides ParamChangeTracker for snapshot-based parameter-change measurement with per-layer/component breakdown and multi-threshold sparsity analysis. New sparse_quality module adds SparseQualityAnalyzer for evaluating fidelity of sparse updates via KL divergence, cosine similarity, MSE, and top-1 agreement metrics.
Monitoring infrastructure
grail/monitoring/__init__.py, grail/monitoring/manager.py, grail/monitoring/backends/wandb_backend.py
Added initialize_subprocess_monitoring async helper for robust subprocess monitoring setup with optional W&B run attachment. Extended WandB backend with optimizer_step metric and per-label worker directories (wandb_<safe_label>). Added param_change/* metric prefixes mapped to optimizer_step.
Training integration
grail/trainer/algorithms/grpo.py, grail/neurons/trainer.py, grail/trainer/training_process.py, grail/trainer/evaluator.py
GRPOAlgorithm augmented with ParamChangeTracker and SparseQualityAnalyzer integrated into training loop with periodic metrics logging. TrainingService._prepare_monitor_config now accepts subprocess_label. training_process.py refactored to use initialize_subprocess_monitoring and enhanced pause/resume with improved IPC coordination. evaluator.run_cycle extended with optional window_number parameter.
Configuration & utilities
grail/shared/constants.py, grail/trainer/config.py, grail/shared/safetensors_utils.py
New constants for parameter-change tracking (PARAM_CHANGE_MEASURE_INTERVAL, PARAM_CHANGE_THRESHOLD, etc.), sparse quality (SPARSE_QUALITY_ENABLED), and delta checkpointing (DELTA_BASE_INTERVAL, DELTA_THRESHOLD, DELTA_CHECKPOINT_ENABLED). TrainingConfig extended with corresponding fields. New load_model_state_dict utility supports unsharded and sharded safetensors layouts.
Dependencies & build
pyproject.toml
Moved zstandard>=0.21.0 from dev optional-dependencies to main dependencies. Extended Ruff exclude list to include research directory.
Tests
tests/integration/infrastructure/test_delta_checkpoint_flow.py, tests/unit/infrastructure/test_delta_checkpoint.py, tests/integration/test_train_cli_smoke.py, tests/unit/infrastructure/__init__.py
Comprehensive integration and unit test coverage for delta checkpoint end-to-end flows, persistence, reconstruction, hash verification, and retention policies. Updated train CLI smoke test to accept monitor_config parameter.
Research utilities
research/data/kernelbook_analysis.py
New exploratory script for KernelBook dataset analysis with test_module_from_code and compare_pytorch_vs_triton helper functions for kernel verification.

Sequence Diagram(s)

sequenceDiagram
    participant Trainer as Training Loop
    participant ParamTracker as ParamChangeTracker
    participant SparseQA as SparseQualityAnalyzer
    participant Monitor as MonitoringManager
    participant WandB as W&B Backend

    Trainer->>ParamTracker: capture_snapshot(model)
    Note over ParamTracker: Store weights to CPU
    
    Trainer->>Trainer: optimizer.step()
    Trainer->>ParamTracker: compute_metrics(model)
    Note over ParamTracker: Compute deltas, sparsity, sign-flips
    ParamTracker-->>Trainer: ParamChangeMetrics
    
    Trainer->>Monitor: log_param_change_metrics()
    Monitor->>WandB: log gauge metrics
    WandB-->>Monitor: ✓
    
    Trainer->>SparseQA: analyze(model, inputs)
    SparseQA->>SparseQA: compute sparse deltas
    SparseQA->>SparseQA: forward pass (sparse vs full)
    SparseQA->>SparseQA: compute KL/cosine/MSE/top-1
    SparseQA-->>Trainer: SparseQualityMetrics
    
    Trainer->>Monitor: log_sparse_quality_metrics()
    Monitor->>WandB: log per-threshold metrics
    WandB-->>Monitor: ✓
    
    Note over Trainer,WandB: Snapshot cleared after interval
    Trainer->>ParamTracker: clear_snapshot()
sequenceDiagram
    participant UploadWorker as upload_worker_loop
    participant Publisher as CheckpointPublisher
    participant Storage as Checkpoint Storage
    participant Base as Base Checkpoint Cache
    
    UploadWorker->>UploadWorker: determine FULL vs DELTA
    alt is base-interval or first checkpoint
        UploadWorker->>Publisher: upload_from_staging(full=True)
        Publisher->>Storage: upload checkpoint files
        Storage-->>Publisher: ✓
        Publisher-->>UploadWorker: UploadResult (FULL)
        UploadWorker->>Base: cache state as delta base
        Note over Base: base_window, base_state stored
    else intermediate checkpoint
        UploadWorker->>Publisher: upload_delta(base_state)
        Publisher->>Publisher: compute_sparse_delta()
        Publisher->>Publisher: compress delta (zstd)
        Publisher->>Storage: upload delta artifacts
        Storage-->>Publisher: ✓
        Publisher-->>UploadWorker: UploadResult (DELTA)
        Note over UploadWorker: Sparse ratio & metrics in result
    else delta fails
        UploadWorker->>Publisher: upload_from_staging(full=True)
        Note over UploadWorker: Fallback to FULL
    end
    
    UploadWorker->>UploadWorker: log metrics to monitor
sequenceDiagram
    participant CheckpointManager as CheckpointManager
    participant MetadataFetch as Fetch Metadata
    participant DeltaHandler as _handle_delta_checkpoint
    participant BaseCheckpoint as Base Checkpoint
    participant Reconstruct as _reconstruct_from_delta
    participant Storage as Local Storage
    
    CheckpointManager->>MetadataFetch: _fetch_metadata(window)
    MetadataFetch-->>CheckpointManager: CheckpointMetadata
    
    alt checkpoint_type == "DELTA"
        CheckpointManager->>DeltaHandler: _handle_delta_checkpoint(metadata)
        DeltaHandler->>BaseCheckpoint: get_checkpoint(base_window)
        alt base is DELTA
            BaseCheckpoint->>DeltaHandler: recursively download base
        else base is FULL
            BaseCheckpoint-->>DeltaHandler: base path
        end
        
        DeltaHandler->>DeltaHandler: download delta files to temp
        DeltaHandler->>DeltaHandler: verify delta integrity
        DeltaHandler->>Reconstruct: _reconstruct_from_delta(base, delta)
        
        Reconstruct->>Reconstruct: load base weights
        Reconstruct->>Reconstruct: apply_sparse_delta()
        Reconstruct->>Reconstruct: verify weights_hash
        Reconstruct->>Storage: save reconstructed model
        Reconstruct->>Storage: copy supplementary files
        Reconstruct->>Storage: write metadata (marked FULL)
        Reconstruct-->>DeltaHandler: output path
        DeltaHandler-->>CheckpointManager: checkpoint path
    else checkpoint_type == "FULL"
        CheckpointManager->>Storage: retrieve checkpoint directly
        Storage-->>CheckpointManager: checkpoint path
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Delta reconstruction logic (checkpoint_consumer.py, delta_checkpoint.py): Multi-step IO with recursive base handling, sparse operations, and hash verification require careful validation of edge cases (missing bases, integrity failures, mismatch scenarios).
  • Checkpoint publishing & upload (checkpoint_publisher.py, upload_worker.py): Delta vs FULL selection logic, base caching state machine, and fallback paths warrant thorough review of state consistency.
  • Parameter tracking integration (param_tracker.py, sparse_quality.py, grpo.py): Memory-safety assumptions (CPU offloading), snapshot lifecycle management, and metric computation correctness need verification.
  • Monitoring infrastructure (monitoring/manager.py, wandb_backend.py, training_process.py): Subprocess initialization, IPC coordination, and label sanitization patterns should be reviewed for robustness.
  • Configuration propagation (shared/constants.py, trainer/config.py, training_process.py): Environment variable plumbing and default value consistency across modules.

Possibly related PRs

  • PR #52: Introduces initial delta checkpoint infrastructure (CheckpointMetadata, CheckpointManager, delta upload/download foundations) that this PR extends with full reconstruction pipeline and sparse operations.
  • PR #48: Modifies TrainingConfig, GRPO algorithm, and trainer integration points that this PR augments with parameter-change tracking and sparse-quality analysis.
  • PR #47: Touches checkpoint producer/consumer/upload code paths that this PR significantly extends with delta handling and retention policy updates.

Poem

🐰 Delta hops through checkpoints lean,
Sparse weights where deltas gleam,
Parameter whispers, quality shine—
Base and delta dance in time!
Monitoring maps the rabbit's way, 🥕

Pre-merge checks

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.
  • Title check: ✅ Passed. The PR title 'feat: add delta checkpointing' directly and concisely describes the main feature added, delta checkpointing for sparse checkpoint updates, which aligns with the substantial implementation across multiple modules.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🧹 Nitpick comments (13)
grail/monitoring/backends/wandb_backend.py (1)

40-41: Minor redundancy: param_change/sparse/ prefix is already covered by param_change/.

The "param_change/sparse/" entry is redundant since "param_change/" already matches and both map to "optimizer_step". The _get_step_metric_for_name method returns on first prefix match, so "param_change/sparse/foo" would match "param_change/" first. Consider removing the redundant entry for clarity.

 STEP_METRIC_PREFIXES = {
     "training/epoch/": "epoch",
     "training/batch/": "batch_step",
     "training/block/": "block_number",
     "training/prefilter/": "block_number",
     "param_change/": "optimizer_step",
-    "param_change/sparse/": "optimizer_step",
     "eval/": "block_number",
research/data/kernelbook_analysis.py (1)

27-42: Script executes on import due to top-level code.

The script loads datasets and prints output immediately when imported. This prevents importing the helper functions (test_module_from_code, compare_pytorch_vs_triton) without side effects. Consider wrapping the executable portions in an if __name__ == "__main__": guard.

+if __name__ == "__main__":
     print("Loading KernelBook dataset...")
     ds = load_dataset("GPUMODE/KernelBook", split="train")
     # ... rest of the script execution ...
grail/trainer/algorithms/grpo.py (1)

2073-2102: Consider extracting duplicated param tracking logic into a helper method.

The parameter change tracking and sparse quality analysis logic (lines 2073-2102) is nearly identical to the code in the main loop (lines 2002-2031). Extracting this into a private helper method would improve maintainability.

async def _log_param_tracking_metrics(
    self,
    model: Any,
    monitor: Any | None,
    input_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
) -> None:
    """Log parameter change and sparse quality metrics if due."""
    if not self.param_tracker.has_snapshot():
        return
    
    try:
        param_metrics = self.param_tracker.compute_metrics(model)
        await log_param_change_metrics(param_metrics, monitor, self.optimizer_step_count)
    except Exception as exc:
        logger.warning("Failed to compute param change metrics: %s", exc)

    if self.sparse_analyzer.enabled and input_ids is not None:
        try:
            sparse_metrics = self.sparse_analyzer.analyze(model, input_ids, attention_mask)
            await log_sparse_quality_metrics(sparse_metrics, monitor, self.optimizer_step_count)
        except Exception as exc:
            logger.warning("Failed to compute sparse quality metrics: %s", exc)
tests/integration/infrastructure/test_delta_checkpoint_flow.py (1)

347-349: Hardcoded interval check may become fragile.

The test assumes DELTA_BASE_INTERVAL == 100 with a conditional assertion. If the constant changes, the test silently skips the assertions rather than failing.

Consider making the test more robust by computing expected base windows dynamically:

-        if DELTA_BASE_INTERVAL == 100:
-            assert 200 * WINDOW_LENGTH in keep, "Base window 200 should be kept"
-            assert 100 * WINDOW_LENGTH in keep, "Base window 100 should be kept"
+        # Verify base windows aligned to DELTA_BASE_INTERVAL are kept
+        base_window_200 = 200 * WINDOW_LENGTH
+        base_window_100 = 100 * WINDOW_LENGTH
+        if base_window_200 % (DELTA_BASE_INTERVAL * WINDOW_LENGTH) == 0:
+            assert base_window_200 in keep, f"Base window {base_window_200} should be kept"
+        if base_window_100 % (DELTA_BASE_INTERVAL * WINDOW_LENGTH) == 0:
+            assert base_window_100 in keep, f"Base window {base_window_100} should be kept"
grail/shared/safetensors_utils.py (1)

60-62: Consider logging a warning for empty or invalid weight_map entries.

The validation correctly raises ValueError for missing/invalid weight_map, but if weight_map contains non-string values, they're silently cast via str(v) on line 64. This is likely fine for valid safetensors but could mask unexpected formats.

     weight_map = index.get("weight_map") if isinstance(index, dict) else None
     if not isinstance(weight_map, dict) or not weight_map:
         raise ValueError("Invalid safetensors index: missing/invalid 'weight_map'")
+
+    # Validate weight_map values are strings
+    non_string_values = [k for k, v in weight_map.items() if not isinstance(v, str)]
+    if non_string_values:
+        logger.warning(
+            "weight_map contains non-string shard references for keys: %s",
+            non_string_values[:5],
+        )
grail/monitoring/manager.py (1)

454-455: Consider moving imports to module level.

The os and time imports inside the function are standard library modules that could be imported at the module level for consistency with the rest of the file. This is a minor style suggestion.

+import os
+import time
 from collections.abc import Awaitable, Callable
 ...

 async def initialize_subprocess_monitoring(
     ...
 ) -> MonitoringManager | None:
-    import os
-    import time
-
     if not monitor_config:
grail/trainer/training_process.py (1)

61-62: Add validation for environment variable parsing.

If these environment variables contain invalid float values, float() will raise ValueError. Consider adding validation or wrapping in try/except with a warning log.

-WARMUP_FRACTION = float(os.getenv("GRAIL_WARMUP_FRACTION", "0.05"))
-SCHEDULER_ETA_MIN = float(os.getenv("GRAIL_SCHEDULER_ETA_MIN", "1e-7"))
+def _parse_float_env(name: str, default: str) -> float:
+    val = os.getenv(name, default)
+    try:
+        return float(val)
+    except ValueError:
+        logger.warning("Invalid %s='%s', using default %s", name, val, default)
+        return float(default)
+
+WARMUP_FRACTION = _parse_float_env("GRAIL_WARMUP_FRACTION", "0.05")
+SCHEDULER_ETA_MIN = _parse_float_env("GRAIL_SCHEDULER_ETA_MIN", "1e-7")
grail/trainer/checkpoint_publisher.py (1)

762-784: Handle edge case where delta contains no changes.

If sparse_tensors is empty (no parameter changes), the safetensors file will be empty. While handled correctly, the compression ratio calculation on line 782-784 could divide by zero if both sizes are 0 (though unlikely in practice).

 compression_ratio = (
-    delta_raw_size / delta_compressed_size if delta_compressed_size > 0 else 1.0
+    delta_raw_size / delta_compressed_size if delta_compressed_size > 0 and delta_raw_size > 0 else 1.0
 )
grail/trainer/upload_worker.py (1)

136-138: Consider memory implications of caching base state.

base_state holds the entire model state dict in memory for delta computation. For large models (e.g., 7B+ parameters), this could be significant. Consider:

  1. Documenting the memory overhead
  2. Adding a mechanism to clear base_state when memory is constrained

This is acceptable for now given the bandwidth savings, but worth monitoring.
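
As a rough illustration of point 2 above, a hypothetical guard (names and the free-memory threshold are assumptions, not code from this PR) could estimate the cached state's footprint and drop it under memory pressure, at the cost of one extra FULL upload later:

import torch


def state_dict_nbytes(state: dict[str, torch.Tensor]) -> int:
    """Approximate CPU memory held by a cached state dict, in bytes."""
    return sum(t.numel() * t.element_size() for t in state.values())


def maybe_drop_base_state(
    base_state: dict[str, torch.Tensor] | None,
    free_bytes: int,
    safety_factor: float = 2.0,
) -> dict[str, torch.Tensor] | None:
    """Drop the cached delta base if free memory is uncomfortably low."""
    if base_state is None:
        return None
    if free_bytes < safety_factor * state_dict_nbytes(base_state):
        return None  # the next upload then falls back to FULL and re-caches a base
    return base_state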

grail/infrastructure/delta_checkpoint.py (3)

55-58: Consider handling parameters missing from current_state.

Parameters present in base_state but absent from current_state are silently ignored. If this represents a model architecture mismatch, it might be worth logging a warning similar to the one for the reverse case (line 57).
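
A hedged sketch of such a check, written as a standalone helper for illustration (the actual placement inside compute_sparse_delta and the logger setup are assumed):

import logging

import torch

logger = logging.getLogger(__name__)


def warn_on_missing_params(
    base_state: dict[str, torch.Tensor],
    current_state: dict[str, torch.Tensor],
) -> None:
    """Log parameters present in the base but absent from the current state."""
    missing = sorted(set(base_state) - set(current_state))
    if missing:
        logger.warning(
            "Base checkpoint has %d parameter(s) missing from current state "
            "(possible architecture mismatch), e.g. %s",
            len(missing),
            missing[:5],
        )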


126-143: Orphaned sparse delta entries are silently ignored.

If sparse_tensors contains {name}.indices for a parameter that doesn't exist in base_state, those deltas will never be applied. This could indicate a mismatch between the delta and base checkpoint. Consider adding a validation check or warning.

+    # Check for orphaned delta entries
+    delta_param_names = {k.rsplit('.', 1)[0] for k in sparse_tensors if k.endswith('.indices')}
+    missing_in_base = delta_param_names - set(base_state.keys())
+    if missing_in_base:
+        logger.warning("Delta contains parameters not in base state: %s", missing_in_base)
+
     for name, base_tensor in base_state.items():

219-220: Minor: Consider caching element sizes to avoid tensor allocation.

Creating empty tensors on each call is slightly wasteful. Since the typical dtypes are fixed, you could use constants or a simple lookup.

+# Element sizes for common dtypes (bytes)
+_DTYPE_SIZES = {
+    torch.int32: 4,
+    torch.int64: 8,
+    torch.float32: 4,
+    torch.float16: 2,
+    torch.bfloat16: 2,
+}
+
+
 def estimate_sparse_size(
     nonzero_params: int,
     index_dtype: torch.dtype = torch.int32,
     value_dtype: torch.dtype = torch.float32,
 ) -> int:
-    index_size = nonzero_params * torch.tensor([], dtype=index_dtype).element_size()
-    value_size = nonzero_params * torch.tensor([], dtype=value_dtype).element_size()
+    index_size = nonzero_params * _DTYPE_SIZES.get(
+        index_dtype, torch.tensor([], dtype=index_dtype).element_size()
+    )
+    value_size = nonzero_params * _DTYPE_SIZES.get(
+        value_dtype, torch.tensor([], dtype=value_dtype).element_size()
+    )
     return index_size + value_size
grail/trainer/sparse_quality.py (1)

497-510: Consider adding a fallback if 1e-6 threshold is not present.

The summary log only fires if a threshold of exactly 1e-6 exists in the metrics. If thresholds are configured differently, no summary will be logged. This is likely fine for the current use case, but consider logging the first threshold as a fallback.
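
One possible fallback, sketched under the assumption that the per-threshold metrics are keyed by float threshold values (the module's real structure may differ):

def pick_summary_threshold(per_threshold: dict[float, object], preferred: float = 1e-6) -> float | None:
    """Prefer the 1e-6 threshold for the summary log, else fall back to the smallest one."""
    if not per_threshold:
        return None
    if preferred in per_threshold:
        return preferred
    return min(per_threshold)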

📜 Review details

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0186154 and c4d2125.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (22)
  • grail/infrastructure/checkpoint_consumer.py (7 hunks)
  • grail/infrastructure/delta_checkpoint.py (1 hunks)
  • grail/monitoring/__init__.py (2 hunks)
  • grail/monitoring/backends/wandb_backend.py (4 hunks)
  • grail/monitoring/manager.py (2 hunks)
  • grail/neurons/trainer.py (7 hunks)
  • grail/shared/constants.py (1 hunks)
  • grail/shared/safetensors_utils.py (1 hunks)
  • grail/trainer/algorithms/grpo.py (4 hunks)
  • grail/trainer/checkpoint_publisher.py (7 hunks)
  • grail/trainer/config.py (1 hunks)
  • grail/trainer/evaluator.py (1 hunks)
  • grail/trainer/param_tracker.py (1 hunks)
  • grail/trainer/sparse_quality.py (1 hunks)
  • grail/trainer/training_process.py (8 hunks)
  • grail/trainer/upload_worker.py (11 hunks)
  • pyproject.toml (2 hunks)
  • research/data/kernelbook_analysis.py (1 hunks)
  • tests/integration/infrastructure/test_delta_checkpoint_flow.py (1 hunks)
  • tests/integration/test_train_cli_smoke.py (1 hunks)
  • tests/unit/infrastructure/__init__.py (1 hunks)
  • tests/unit/infrastructure/test_delta_checkpoint.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (11)
tests/integration/infrastructure/test_delta_checkpoint_flow.py (3)
grail/infrastructure/checkpoint_consumer.py (3)
  • CheckpointMetadata (69-90)
  • is_delta (88-90)
  • _compute_keep_windows (717-747)
grail/infrastructure/delta_checkpoint.py (3)
  • apply_sparse_delta (105-143)
  • compute_sparse_delta (30-102)
  • compute_weights_hash (146-175)
grail/trainer/checkpoint_publisher.py (2)
  • is_delta (148-150)
  • _compute_keep_windows (176-212)
grail/monitoring/__init__.py (1)
grail/monitoring/manager.py (1)
  • initialize_subprocess_monitoring (416-560)
grail/monitoring/manager.py (4)
grail/trainer/training_process.py (1)
  • get_block_context (316-320)
grail/monitoring/backends/wandb_backend.py (1)
  • start_run (719-758)
grail/monitoring/backends/null_backend.py (1)
  • start_run (73-83)
grail/monitoring/base.py (1)
  • start_run (128-138)
grail/trainer/param_tracker.py (1)
grail/monitoring/manager.py (1)
  • log_gauge (158-186)
grail/infrastructure/checkpoint_consumer.py (2)
grail/shared/safetensors_utils.py (1)
  • load_model_state_dict (25-72)
grail/infrastructure/delta_checkpoint.py (2)
  • apply_sparse_delta (105-143)
  • compute_weights_hash (146-175)
tests/unit/infrastructure/test_delta_checkpoint.py (1)
grail/infrastructure/delta_checkpoint.py (5)
  • apply_sparse_delta (105-143)
  • compute_sparse_delta (30-102)
  • compute_weights_hash (146-175)
  • estimate_sparse_size (204-221)
  • verify_weights_hash (178-201)
grail/trainer/sparse_quality.py (2)
grail/trainer/param_tracker.py (4)
  • ParamChangeTracker (153-474)
  • from_config (215-231)
  • has_snapshot (254-260)
  • get_snapshot (262-273)
grail/monitoring/manager.py (1)
  • log_gauge (158-186)
grail/neurons/trainer.py (4)
grail/trainer/upload_worker.py (1)
  • run_upload_worker (436-488)
grail/monitoring/manager.py (1)
  • set_block_context (105-113)
grail/trainer/evaluator.py (1)
  • run_cycle (181-271)
grail/neurons/base.py (1)
  • heartbeat (123-125)
grail/trainer/upload_worker.py (2)
grail/shared/safetensors_utils.py (1)
  • load_model_state_dict (25-72)
grail/trainer/checkpoint_publisher.py (7)
  • UploadError (74-77)
  • UploadResult (126-173)
  • is_delta (148-150)
  • to_dict (112-122)
  • to_dict (152-173)
  • upload_delta (698-983)
  • upload_from_staging (514-696)
grail/trainer/checkpoint_publisher.py (3)
grail/infrastructure/checkpoint_consumer.py (3)
  • CheckpointMetadata (69-90)
  • is_delta (88-90)
  • remote_prefix (85-86)
grail/infrastructure/delta_checkpoint.py (1)
  • compute_sparse_delta (30-102)
grail/shared/safetensors_utils.py (1)
  • load_model_state_dict (25-72)
grail/trainer/training_process.py (3)
grail/monitoring/manager.py (1)
  • initialize_subprocess_monitoring (416-560)
grail/trainer/snapshot_manager.py (2)
  • save_snapshot_atomic (59-127)
  • get_latest_snapshot_path (195-202)
grail/trainer/ipc.py (2)
  • queue_snapshot (122-142)
  • confirm_pause (95-97)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Run Tests (3.10)
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Run Tests (3.10)
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Run Tests (3.11)
🔇 Additional comments (63)
tests/unit/infrastructure/__init__.py (1)

1-1: LGTM!

Standard package initialization for the infrastructure unit tests module.

tests/unit/infrastructure/test_delta_checkpoint.py (8)

1-19: LGTM!

Clean imports with all necessary dependencies for testing the delta checkpoint functionality.


21-135: LGTM!

Comprehensive test coverage for compute_sparse_delta including edge cases (no changes, full changes), threshold filtering, multidimensional tensors, and dtype handling. The sparsity ratio calculations and index assertions are correct.


137-196: LGTM!

Good coverage for apply_sparse_delta including empty delta handling, multidimensional reconstruction, and dtype conversion verification.


198-253: LGTM!

Round-trip tests properly validate exact reconstruction with threshold=0.0 and appropriately account for bf16 precision loss with atol=1e-2.


255-314: LGTM!

Thorough hash function testing including determinism, bf16 support, collision resistance, key-order independence, and format validation.


316-332: LGTM!

Verification tests cover both positive and negative cases appropriately.


334-348: LGTM!

Size estimation tests correctly validate the byte calculations for different dtype combinations.


350-382: LGTM!

Good integration test for safetensors round-trip. Note that this test passes shapes directly from compute_sparse_delta to apply_sparse_delta without serializing it. In production, shapes would need to be persisted alongside the delta (e.g., in JSON metadata as shown in the integration test reference). This is acceptable for a unit test scope.
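
For context, persisting the shapes next to the delta could look roughly like this; file names and layout are assumptions for illustration, not the repository's actual on-disk format:

import json
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file


def save_delta(out_dir: Path, sparse_tensors: dict[str, torch.Tensor], shapes: dict[str, list[int]]) -> None:
    """Write the sparse delta tensors plus a JSON sidecar with the original shapes."""
    out_dir.mkdir(parents=True, exist_ok=True)
    save_file(sparse_tensors, str(out_dir / "delta.safetensors"))
    (out_dir / "delta_shapes.json").write_text(json.dumps(shapes))


def load_delta(out_dir: Path) -> tuple[dict[str, torch.Tensor], dict[str, list[int]]]:
    """Read the delta tensors and the shape sidecar back."""
    sparse_tensors = load_file(str(out_dir / "delta.safetensors"))
    shapes = json.loads((out_dir / "delta_shapes.json").read_text())
    return sparse_tensors, shapes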

pyproject.toml (2)

227-227: LGTM!

Excluding the research directory from Ruff linting is reasonable for experimental/research code that may not adhere to strict project conventions.


56-56: Add zstandard dependency for delta checkpoint compression.

The zstandard library is correctly added as a core dependency to support Zstd compression for delta checkpoints. The package has no known security vulnerabilities and is deemed safe to use.

grail/monitoring/backends/wandb_backend.py (4)

25-32: LGTM!

Adding optimizer_step to RESERVED_TAGS correctly ensures it's treated as an x-axis field rather than being appended to metric names, consistent with its use as a step metric for parameter change tracking.


241-256: LGTM!

Good improvement to use per-label directories (wandb_<safe_label>) instead of a static wandb_training path. The sanitization of the label to alphanumeric/dash/underscore characters prevents invalid directory names and avoids file conflicts between multiple worker processes.
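
The sanitization described above amounts to something like the following sketch (the backend's actual helper name and directory root may differ):

import re
from pathlib import Path


def worker_wandb_dir(root: Path, label: str) -> Path:
    """Map an arbitrary subprocess label to a safe per-worker W&B directory."""
    safe_label = re.sub(r"[^A-Za-z0-9_-]", "_", label) or "worker"
    path = root / f"wandb_{safe_label}"
    path.mkdir(parents=True, exist_ok=True)
    return path


# e.g. worker_wandb_dir(Path("/tmp"), "upload worker #1") -> /tmp/wandb_upload_worker__1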


311-320: LGTM!

Adding optimizer_step as a base metric is required for param_change metrics to use it as their x-axis, consistent with the other changes in this file.


331-333: LGTM!

Properly defines the param_change/* metric family with optimizer_step as the step metric, enabling dedicated visualization for parameter change tracking.

grail/trainer/config.py (1)

78-88: LGTM!

Clean addition of parameter change tracking and sparse quality configuration fields. The fields follow the established pattern of sourcing defaults from constants, enabling environment-driven configurability. The naming is consistent and the inline comments provide helpful context.

research/data/kernelbook_analysis.py (2)

83-85: Security: exec() executes arbitrary code from the dataset.

Using exec() on code loaded from an external dataset introduces security risks. This is acceptable for local research scripts, but ensure this script is never used in production or with untrusted data sources.

If this script could be invoked on untrusted inputs, consider sandboxing or limiting execution scope.


1-57: Verify this file belongs in the delta checkpointing PR.

This file appears to be a KernelBook dataset analysis script unrelated to the delta checkpointing feature described in the PR objectives. It may have been accidentally included.

grail/shared/constants.py (1)

247-284: LGTM! Well-structured configuration constants.

The new constants for parameter change tracking, sparse quality analysis, and delta checkpoint configuration are well-documented with clear comments, consistent naming conventions (GRAIL_ prefix), and sensible defaults. The environment variable fallbacks enable runtime configurability.
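
The pattern looks roughly like this; the environment variable names and defaults shown are placeholders for illustration, not the values actually defined in grail/shared/constants.py:

import os

# Environment-driven configuration with in-code defaults (values illustrative only).
DELTA_CHECKPOINT_ENABLED = os.getenv("GRAIL_DELTA_CHECKPOINT_ENABLED", "1") == "1"
DELTA_BASE_INTERVAL = int(os.getenv("GRAIL_DELTA_BASE_INTERVAL", "10"))
DELTA_THRESHOLD = float(os.getenv("GRAIL_DELTA_THRESHOLD", "0.0"))
PARAM_CHANGE_MEASURE_INTERVAL = int(os.getenv("GRAIL_PARAM_CHANGE_MEASURE_INTERVAL", "100"))
SPARSE_QUALITY_ENABLED = os.getenv("GRAIL_SPARSE_QUALITY_ENABLED", "0") == "1"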

grail/trainer/param_tracker.py (4)

275-296: LGTM! Memory-safe snapshot implementation.

The snapshot implementation correctly clones before moving to CPU (line 291), which prevents issues with in-place modifications. The warning on overwrite (line 282) is helpful for debugging.


345-351: Good: Float32 conversion before delta computation.

Converting bfloat16 tensors to float32 before subtraction (lines 347-348) is critical for detecting small deltas that bfloat16 cannot represent. The inline comment clearly explains the rationale.
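
A toy example of the underlying precision issue (illustrative only, not the tracker's code path): bfloat16 keeps only 8 bits of significand, so near 1.0 adjacent representable values are roughly 0.008 apart and a realistically small update can round away entirely.

import torch

w = torch.tensor([1.0], dtype=torch.bfloat16)
step = 1e-4  # a plausibly small optimizer update

print((w + step) - w)                  # tensor([0.], dtype=torch.bfloat16): the update rounds away in bf16
print((w.float() + step) - w.float())  # ~1e-4: visible when the arithmetic is done in float32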


356-371: Consider making diagnostic logging conditional on verbosity.

The first-parameter diagnostic logging runs unconditionally at DEBUG level. This is fine for development but could add overhead in production with DEBUG enabled. The current approach is acceptable since DEBUG should be off in production.


308-331: Memory consideration: Full model copy resides on CPU during tracking.

The snapshot stores a CPU copy of all trainable parameters. For large models, this approximately doubles CPU memory requirements during the measurement interval. This is documented in the class docstring but worth keeping in mind for very large models.

tests/integration/test_train_cli_smoke.py (1)

104-122: LGTM! Test helper updated to match new function signature.

The _fake_run_upload_worker helper correctly adds the monitor_config parameter to match the updated production signature, maintaining test compatibility with the new monitoring flow.

grail/monitoring/__init__.py (1)

10-25: LGTM! Clean public API extension.

The new initialize_subprocess_monitoring export is properly added to both the import statement and __all__, following the existing pattern. This provides a clean interface for subprocess monitoring initialization.

grail/trainer/evaluator.py (1)

188-204: LGTM!

The window_number parameter is cleanly added with proper documentation. The internal storage in self._current_window_number enables window-aware evaluation context tracking.

grail/trainer/algorithms/grpo.py (1)

1067-1086: LGTM!

The parameter tracking and sparse quality analysis initialization follows a clean pattern with factory methods and conditional logging.

tests/integration/infrastructure/test_delta_checkpoint_flow.py (3)

32-37: LGTM!

The temp_cache fixture correctly uses yield for cleanup and shutil.rmtree with ignore_errors=True for robust teardown.


70-98: LGTM!

The roundtrip test properly validates the compute + apply delta flow with appropriate tolerance checks.


355-398: Good edge case coverage!

The tests for 0% sparsity (all weights change), 100% sparsity (no weights change), and single parameter change are valuable boundary tests.

grail/shared/safetensors_utils.py (1)

25-71: LGTM!

Clean implementation supporting both sharded and unsharded safetensors layouts with proper error handling for missing shards and malformed indices.
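
For reference, the two layouts such a utility has to handle look roughly like this (a simplified sketch, not the actual load_model_state_dict implementation; error handling for missing shards is omitted):

import json
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_state_dict_sketch(checkpoint_dir: Path) -> dict[str, torch.Tensor]:
    """Load weights from either model.safetensors or a sharded *.index.json layout."""
    single = checkpoint_dir / "model.safetensors"
    if single.exists():
        return load_file(str(single))

    index_path = checkpoint_dir / "model.safetensors.index.json"
    index = json.loads(index_path.read_text())
    weight_map = index["weight_map"]  # parameter name -> shard file name

    state: dict[str, torch.Tensor] = {}
    for shard_name in sorted(set(weight_map.values())):
        state.update(load_file(str(checkpoint_dir / shard_name)))
    return state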

grail/monitoring/manager.py (2)

416-560: LGTM!

The initialize_subprocess_monitoring function is well-designed with robust error handling, detailed logging at each stage, and graceful degradation (returning None on failure). The step-by-step approach with timing information will aid debugging.


461-463: Verify environment variable side effects.

Setting WANDB_DISABLE_SERVICE and WANDB_SERVICE environment variables affects the entire process. This is intentional for shared-mode W&B, but worth noting that these settings persist after the function returns.

If this function is called multiple times in the same process (e.g., during testing or retry scenarios), the environment modifications are idempotent. However, if there are scenarios where different subprocess types need different W&B configurations, consider documenting this behavior.

grail/neurons/trainer.py (5)

185-185: LGTM: Subprocess label now passed explicitly.

The keyword-only argument pattern ensures callers explicitly identify the subprocess, improving code clarity and preventing accidental misuse.


207-219: LGTM: Upload worker receives proper subprocess label and monitor config.

The changes correctly pass subprocess_label="upload_worker" and propagate monitor_config to the upload worker process, enabling per-subprocess W&B logging.


336-336: LGTM: Block context now includes window number.

Passing current_window instead of None provides more precise monitoring context for metrics correlation.


807-809: LGTM: Window number forwarded to evaluation cycle.

This enables window-aware evaluation context as noted in the relevant code snippet for evaluator.run_cycle.


827-891: LGTM: Subprocess monitoring configuration with label propagation.

The method correctly:

  • Uses keyword-only argument for subprocess_label
  • Propagates label to wandb_x_label for shared mode workers
  • Includes comprehensive debug logging for troubleshooting
  • Validates critical parameters (entity, project)
grail/trainer/training_process.py (4)

306-327: LGTM: Monitoring initialization properly delegated.

The refactored _initialize_monitoring correctly:

  • Defines a local get_block_context async helper
  • Uses the shared initialize_subprocess_monitoring helper for consistent behavior across subprocesses
  • Passes test_connection=True to verify W&B connectivity

368-410: LGTM: Initial checkpoint upload with enriched metadata and IPC queuing.

The updated flow:

  • Prepares structured snapshot_metadata with training config and parent window
  • Queues snapshot for upload worker via IPC
  • Includes appropriate fallback warning when IPC/path unavailable

458-468: LGTM: GRPO group loading optimized to trigger only on window change.

This reduces redundant trusted miner lookups by checking target_data_window != self.last_loaded_window before querying. Good performance optimization.


592-612: LGTM: Pause confirmation signaled before snapshot save.

The comment explains the rationale well: GPU is freed, so evaluation can start immediately while the snapshot saves in parallel (30-90s). This prevents orchestrator timeout waiting for confirmation.

grail/trainer/checkpoint_publisher.py (4)

74-77: LGTM: Custom exception for upload failures.

Clean separation of upload-related errors from general exceptions.


80-173: LGTM: Well-designed result dataclasses.

Good use of:

  • frozen=True, slots=True for immutability and memory efficiency
  • Computed total_s property that sums timing components
  • is_delta property for type discrimination
  • to_dict() methods for metrics logging

176-212: LGTM: Retention policy extended for delta base windows.

The logic correctly:

  • Computes base_stride_blocks from delta interval and window length
  • Retains base windows that delta checkpoints depend on
  • Ensures delta reconstruction remains possible

698-983: LGTM: Comprehensive delta upload implementation.

The upload_delta method:

  • Loads current weights and computes sparse delta against base
  • Applies zstd compression (level 3 for good balance)
  • Includes SHA256 hash for integrity verification
  • Provides detailed timing breakdown
  • Properly cleans up temp directory in finally block
  • Returns structured UploadResult with delta-specific metrics
grail/trainer/upload_worker.py (4)

97-106: LGTM: Subprocess monitoring initialization.

Correctly initializes monitoring with test_connection=False since subtensor isn't ready yet, then sets block context later in the loop when subtensor is available.


175-237: LGTM: Robust FULL vs DELTA upload decision with fallback.

The logic correctly:

  • Uploads FULL if delta disabled, no base cached, or at base interval boundary
  • Falls back from DELTA to FULL if delta upload fails
  • Tracks did_fallback_to_full to properly cache base state

244-267: LGTM: Base state caching after FULL upload.

Properly caches the checkpoint state after FULL uploads (including fallbacks) with informative logging of tensor count and parameter count.


350-433: LGTM: Upload retry with delta support.

The refactored _upload_with_retry:

  • Accepts is_delta, base_window, base_state parameters
  • Routes to appropriate publisher method
  • Validates delta prerequisites before attempting delta upload
  • Returns UploadResult for structured metrics
  • Properly re-raises UploadError after exhausting retries
grail/infrastructure/checkpoint_consumer.py (6)

80-90: LGTM: CheckpointMetadata extended for delta support.

Clean additions:

  • checkpoint_type defaults to "FULL" for backward compatibility
  • base_window tracks the base checkpoint for deltas
  • weights_hash enables integrity verification
  • is_delta() helper for type checking

203-216: LGTM: Delta checkpoint handling integrated into get_checkpoint.

The flow correctly:

  • Checks metadata.is_delta() before attempting reconstruction
  • Logs base window for debugging
  • Raises CheckpointDownloadError on reconstruction failure

662-679: LGTM: Comprehensive file copying logic.

The exclusion list for copying non-weight files from base is thorough:

  • Excludes model weight files (safetensors, index files)
  • Excludes metadata and signature files
  • Excludes sharded weight files (model-*.safetensors)
  • Excludes READY markers

This ensures only tokenizer, config, and other auxiliary files are copied.


506-559: LGTM: Delta checkpoint handling with recursive base fetching.

The _handle_delta_checkpoint method:

  • Validates base_window is present
  • Recursively fetches base checkpoint (handles chained deltas)
  • Downloads delta files to separate temp directory
  • Verifies delta integrity before reconstruction
  • Cleans up temp directory in finally block

821-840: LGTM: Cache directory now honors GRAIL_CACHE_DIR.

Good addition for flexibility:

  • Supports ~ expansion for home directory paths
  • Documents RAM-disk usage recommendation (/dev/shm/grail)
  • Falls back to ~/.cache/grail when not set

717-747: LGTM: Retention logic extended for delta base windows.

Consistent with the publisher's _compute_keep_windows, ensuring both producer and consumer retain the same base windows needed for delta reconstruction.

grail/infrastructure/delta_checkpoint.py (2)

146-175: LGTM!

The hashing implementation is robust: sorted keys ensure determinism, and including dtype and shape in the hash prevents false matches across different tensor configurations. The view(torch.uint8) approach correctly handles bfloat16 tensors that can't be converted to numpy directly.
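
A simplified sketch of that hashing strategy (the module's actual byte layout and ordering may differ):

import hashlib

import torch


def weights_hash_sketch(state: dict[str, torch.Tensor]) -> str:
    """Deterministic SHA256 over name, dtype, shape, and raw bytes of each tensor."""
    h = hashlib.sha256()
    for name in sorted(state):                           # sorted keys give key-order independence
        t = state[name].detach().contiguous().cpu()
        h.update(name.encode())
        h.update(str(t.dtype).encode())                  # dtype and shape disambiguate layouts
        h.update(str(tuple(t.shape)).encode())
        h.update(t.view(torch.uint8).numpy().tobytes())  # byte view also works for bfloat16
    return h.hexdigest()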


178-201: LGTM!

Clean verification function with appropriate error logging for debugging hash mismatches.

grail/trainer/sparse_quality.py (5)

32-86: LGTM!

Well-structured dataclasses with clear separation between per-threshold metrics and aggregated results. The to_log_dict method provides convenient serialization for monitoring systems.


104-111: LGTM - correct KL divergence computation.

The use of log_target=True with both inputs as log-probabilities correctly computes KL(P_a || P_b). The inline comment on lines 109-110 could be updated to reflect the actual formula with log_target=True, but the code is correct.
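
Concretely, the convention being checked looks like this toy example (not the module's code):

import torch
import torch.nn.functional as F

logits_full = torch.randn(2, 8)
logits_sparse = torch.randn(2, 8)

logp_full = F.log_softmax(logits_full, dim=-1)      # log P_a
logp_sparse = F.log_softmax(logits_sparse, dim=-1)  # log P_b

# KL(P_a || P_b): with log_target=True the *target* argument is log P_a.
kl = F.kl_div(logp_sparse, logp_full, reduction="batchmean", log_target=True)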


121-222: LGTM!

The metric computation functions follow a consistent, correct pattern with proper masking and numerical stability considerations (e.g., eps=1e-8 in normalization, clamp(min=1.0) for valid counts).


237-272: LGTM!

Clean initialization with sensible defaults. The from_config factory method safely handles missing configuration attributes.


445-466: LGTM!

Clean implementation of random masking with uniform sampling.

Comment on lines +636 to +642
# Apply sparse delta (float32 computation, bf16 output)
reconstructed = apply_sparse_delta(
    base_state,
    sparse_tensors,
    shapes,
    target_dtype=torch.bfloat16,
)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# Search for apply_sparse_delta function and dtype handling in the checkpoint files
rg -A 10 "def apply_sparse_delta" grail/infrastructure/
rg -n "dtype" grail/infrastructure/checkpoint_consumer.py grail/infrastructure/delta_checkpoint.py grail/trainer/checkpoint_publisher.py
rg -n "target_dtype" grail/infrastructure/

Repository: one-covenant/grail

Length of output: 2887


🏁 Script executed:

rg -n "apply_sparse_delta" grail/
rg -B 5 "apply_sparse_delta" grail/infrastructure/checkpoint_consumer.py | head -30

Repository: one-covenant/grail

Length of output: 755


Infer target dtype from base_state tensors instead of hardcoding bfloat16.

The apply_sparse_delta function accepts a configurable target_dtype parameter, but the callsite hardcodes torch.bfloat16. Since base_state is dict[str, torch.Tensor], extract the dtype from the tensors to preserve the original checkpoint's precision. This ensures compatibility with models trained in other dtypes and hardware without native bfloat16 support.

🤖 Prompt for AI Agents
In grail/infrastructure/checkpoint_consumer.py around lines 636 to 642, the call
to apply_sparse_delta hardcodes target_dtype=torch.bfloat16; instead infer the
target dtype from base_state tensors so we preserve the checkpoint precision.
Inspect base_state (dict[str, torch.Tensor]) to determine a representative dtype
(e.g., pick the dtype of the first tensor or ensure all tensors share the same
dtype), and pass that dtype as target_dtype to apply_sparse_delta; include a
fallback (e.g., torch.float32) if base_state is empty or heterogeneous, and keep
the rest of the call unchanged.
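
A hedged sketch of the suggested fix (the helper name is hypothetical):

import torch


def infer_target_dtype(base_state: dict[str, torch.Tensor]) -> torch.dtype:
    """Pick the base checkpoint's dtype instead of hardcoding bfloat16."""
    dtypes = {t.dtype for t in base_state.values()}
    if len(dtypes) == 1:
        return next(iter(dtypes))
    return torch.float32  # fallback for empty or mixed-dtype state dicts


# reconstructed = apply_sparse_delta(base_state, sparse_tensors, shapes,
#                                    target_dtype=infer_target_dtype(base_state))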

Comment on lines +2089 to +2097
# Sparse quality analysis (uses same snapshot)
if self.sparse_analyzer.enabled:
    try:
        sparse_metrics = self.sparse_analyzer.analyze(
            model, input_ids_tensor, attention_mask_tensor
        )
        await log_sparse_quality_metrics(
            sparse_metrics, monitor, self.optimizer_step_count
        )

⚠️ Potential issue | 🟠 Major

Potential issue: input_ids_tensor and attention_mask_tensor may be undefined or stale at epoch-end.

In the epoch-end accumulation path (lines 2089-2097), the sparse_analyzer.analyze() call references input_ids_tensor and attention_mask_tensor, but these variables are only defined inside the micro-batch loop. If the loop never executed or if the last iteration was skipped due to a continue statement (e.g., non-finite values), these tensors may be undefined, causing a NameError.

Consider capturing the tensors explicitly before calling analyze:

                        if self.sparse_analyzer.enabled:
                            try:
+                               # Note: uses input tensors from the last valid micro-batch;
+                               # skip when the micro-batch loop never produced them.
+                               if "input_ids_tensor" not in locals():
+                                   logger.warning(
+                                       "No input tensors available for sparse analysis at epoch end"
+                                   )
+                               else:
+                                   sparse_metrics = self.sparse_analyzer.analyze(
+                                       model, input_ids_tensor, attention_mask_tensor
+                                   )

Alternatively, consider extracting the param tracking and sparse analysis logic into a helper method to reduce duplication between the main loop (lines 2002-2031) and epoch-end (lines 2076-2102).

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +294 to +295
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype

⚠️ Potential issue | 🟡 Minor

Handle edge case where model has no parameters.

next(model.parameters()) will raise StopIteration if the model has no parameters. While unlikely in practice, consider adding a guard or using a default.

-        device = next(model.parameters()).device
-        dtype = next(model.parameters()).dtype
+        params = list(model.parameters())
+        if not params:
+            raise ValueError("Model has no parameters to analyze")
+        device = params[0].device
+        dtype = params[0].dtype
🤖 Prompt for AI Agents
In grail/trainer/sparse_quality.py around lines 294-295, the current calls to
next(model.parameters()) will raise StopIteration if the model has no
parameters; change this to safely obtain device and dtype by first trying
next(model.parameters(), None), if that is None try next(model.buffers(), None),
and if still None fall back to explicit defaults (e.g., torch.device('cpu') and
torch.float32); assign device and dtype from the first non-None tensor found and
otherwise use the defaults so the code does not raise on parameterless models.

Comment on lines +413 to +443
        for name, param in model.named_parameters():
            if name not in snapshot:
                continue

            modified_params.add(name)

            # Compute sparse delta (float32 on CPU)
            sparse_delta = deltas[name] * mask_dict[name]

            # Apply: W = W_old + sparse_delta (in float32, then convert)
            snapshot_fp32 = snapshot[name].float()
            new_weight = snapshot_fp32 + sparse_delta
            param.data.copy_(new_weight.to(device=device, dtype=dtype))

        try:
            # Forward pass
            with torch.no_grad():
                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            logits_cpu = logits.cpu().float()
            del logits  # Free GPU memory immediately
        finally:
            # Restore original weights using snapshot + full_delta (no GPU clone needed!)
            for name, param in model.named_parameters():
                if name in modified_params:
                    # W_current = W_old + delta, so restore by computing on CPU and moving to GPU
                    snapshot_fp32 = snapshot[name].float()
                    full_delta = deltas[name]  # Full delta (not sparse)
                    original_weight = snapshot_fp32 + full_delta
                    param.data.copy_(original_weight.to(device=device, dtype=dtype))

        return logits_cpu

⚠️ Potential issue | 🟠 Major

Move weight patching inside try block for exception safety.

The weight modification loop (lines 413-425) sits outside the try block. If an exception occurs during patching (e.g., device mismatch, OOM), the model is left with partially modified weights and the restore logic in the finally block never runs, because the exception is raised before the try is even entered.

         # Track which params we modified (for restoration)
         modified_params: set[str] = set()

-        for name, param in model.named_parameters():
-            if name not in snapshot:
-                continue
-
-            modified_params.add(name)
-
-            # Compute sparse delta (float32 on CPU)
-            sparse_delta = deltas[name] * mask_dict[name]
-
-            # Apply: W = W_old + sparse_delta (in float32, then convert)
-            snapshot_fp32 = snapshot[name].float()
-            new_weight = snapshot_fp32 + sparse_delta
-            param.data.copy_(new_weight.to(device=device, dtype=dtype))
-
         try:
+            for name, param in model.named_parameters():
+                if name not in snapshot:
+                    continue
+
+                modified_params.add(name)
+
+                # Compute sparse delta (float32 on CPU)
+                sparse_delta = deltas[name] * mask_dict[name]
+
+                # Apply: W = W_old + sparse_delta (in float32, then convert)
+                snapshot_fp32 = snapshot[name].float()
+                new_weight = snapshot_fp32 + sparse_delta
+                param.data.copy_(new_weight.to(device=device, dtype=dtype))
+
             # Forward pass
             with torch.no_grad():
🤖 Prompt for AI Agents
In grail/trainer/sparse_quality.py around lines 413 to 443, the loop that
applies sparse deltas to model parameters must be moved inside the try block (or
the try should be expanded to include that loop) so that any exception during
patching is caught and the finally block can reliably restore only the actually
modified parameters; initialize modified_params before the try if not already,
perform the patching inside the try (updating modified_params as you apply each
param), then run the forward under with torch.no_grad() and compute logits_cpu,
and keep the existing finally block to iterate modified_params and restore
original weights from snapshot + full_delta.

@erfanMhi erfanMhi changed the title feat: delta checkpointing feat: add delta checkpointing Dec 17, 2025
@erfanMhi erfanMhi self-assigned this Dec 17, 2025
@erfanMhi erfanMhi merged commit a4bdd30 into main Dec 17, 2025
10 checks passed
@erfanMhi erfanMhi deleted the feat/delta-checkpointing branch December 17, 2025 01:26