feat: add delta checkpointing #55
Conversation
…ng in TrainerNeuron
…quality analysis.
…; now trainer runs as long as replay buffer is not empty; pause confirmation is done as soon as GPU memory is freed now.
…parse weight updates and reconstruction; includes new utilities for computing and applying sparse deltas, along with integration tests for the delta upload and download flow.
…locks; improve comments for clarity and ensure delta checkpointing logic is consistent across components
…istent monitoring setup across subprocesses; update TrainerNeuron and upload_worker to utilize this new helper for improved logging and error handling.
…lities; implement robust error handling for model weight loading in CheckpointManager, CheckpointPublisher, and upload_worker, ensuring better resilience against missing or malformed checkpoint files.
…ints; enhance upload metrics with compression details for improved efficiency and monitoring.
…pdate ruff configuration to exclude research directory.
…equent full checkpoint uploads.
… enhanced upload metrics; refactor upload methods to return detailed results including timing and size metrics, improving monitoring and error handling during checkpoint uploads.
…to FULL; enhance error handling and logging for upload processes, improving reliability and monitoring of checkpoint uploads.
Walkthrough

This PR introduces delta checkpoint support enabling efficient model updates through sparse weight deltas, adds parameter-change tracking across training steps, implements sparse quality analysis for update fidelity, and extends monitoring infrastructure for subprocess coordination and metrics logging.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Trainer as Training Loop
participant ParamTracker as ParamChangeTracker
participant SparseQA as SparseQualityAnalyzer
participant Monitor as MonitoringManager
participant WandB as W&B Backend
Trainer->>ParamTracker: capture_snapshot(model)
Note over ParamTracker: Store weights to CPU
Trainer->>Trainer: optimizer.step()
Trainer->>ParamTracker: compute_metrics(model)
Note over ParamTracker: Compute deltas, sparsity, sign-flips
ParamTracker-->>Trainer: ParamChangeMetrics
Trainer->>Monitor: log_param_change_metrics()
Monitor->>WandB: log gauge metrics
WandB-->>Monitor: ✓
Trainer->>SparseQA: analyze(model, inputs)
SparseQA->>SparseQA: compute sparse deltas
SparseQA->>SparseQA: forward pass (sparse vs full)
SparseQA->>SparseQA: compute KL/cosine/MSE/top-1
SparseQA-->>Trainer: SparseQualityMetrics
Trainer->>Monitor: log_sparse_quality_metrics()
Monitor->>WandB: log per-threshold metrics
WandB-->>Monitor: ✓
Note over Trainer,WandB: Snapshot cleared after interval
Trainer->>ParamTracker: clear_snapshot()
```
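The tracking pattern in this first diagram can be summarized in a few lines of plain PyTorch. This is a simplified, self-contained stand-in for illustration only; the real `ParamChangeTracker` in `grail/trainer/param_tracker.py` has a richer API and metric set.

```python
import torch

def capture_snapshot(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Store a CPU copy of trainable parameters before optimizer.step()."""
    return {
        name: p.detach().to("cpu", copy=True)
        for name, p in model.named_parameters()
        if p.requires_grad
    }

def param_change_metrics(
    model: torch.nn.Module,
    snapshot: dict[str, torch.Tensor],
    threshold: float = 1e-8,
) -> dict[str, float]:
    """Compare current weights against the snapshot: change sparsity and sign flips."""
    total = changed = sign_flips = 0
    for name, p in model.named_parameters():
        if name not in snapshot:
            continue
        old = snapshot[name].float()
        new = p.detach().cpu().float()
        delta = new - old
        total += delta.numel()
        changed += int((delta.abs() > threshold).sum())
        sign_flips += int(((old.sign() != new.sign()) & (old != 0) & (new != 0)).sum())
    return {
        "changed_fraction": changed / max(total, 1),
        "sign_flip_fraction": sign_flips / max(total, 1),
    }

# Usage: snap = capture_snapshot(model); optimizer.step()
# metrics = param_change_metrics(model, snap); then clear the snapshot.
```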
```mermaid
sequenceDiagram
participant UploadWorker as upload_worker_loop
participant Publisher as CheckpointPublisher
participant Storage as Checkpoint Storage
participant Base as Base Checkpoint Cache
UploadWorker->>UploadWorker: determine FULL vs DELTA
alt is base-interval or first checkpoint
UploadWorker->>Publisher: upload_from_staging(full=True)
Publisher->>Storage: upload checkpoint files
Storage-->>Publisher: ✓
Publisher-->>UploadWorker: UploadResult (FULL)
UploadWorker->>Base: cache state as delta base
Note over Base: base_window, base_state stored
else intermediate checkpoint
UploadWorker->>Publisher: upload_delta(base_state)
Publisher->>Publisher: compute_sparse_delta()
Publisher->>Publisher: compress delta (zstd)
Publisher->>Storage: upload delta artifacts
Storage-->>Publisher: ✓
Publisher-->>UploadWorker: UploadResult (DELTA)
Note over UploadWorker: Sparse ratio & metrics in result
else delta fails
UploadWorker->>Publisher: upload_from_staging(full=True)
Note over UploadWorker: Fallback to FULL
end
UploadWorker->>UploadWorker: log metrics to monitor
```
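The branch in this diagram boils down to a small decision function. The sketch below is a hypothetical simplification; names such as `delta_enabled` and `base_interval` are assumptions, not the upload worker's actual variables.

```python
def choose_upload_type(
    window_index: int,
    have_cached_base: bool,
    delta_enabled: bool,
    base_interval: int,
) -> str:
    """Return "FULL" or "DELTA" for the current checkpoint window."""
    if not delta_enabled or not have_cached_base:
        return "FULL"                      # nothing to diff against yet
    if window_index % base_interval == 0:
        return "FULL"                      # periodic full upload bounds the delta chain
    return "DELTA"                         # otherwise upload a sparse delta vs. the cached base

# A failed DELTA upload falls back to FULL, and every FULL upload refreshes
# the cached base state used for subsequent deltas.
```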
```mermaid
sequenceDiagram
participant CheckpointManager as CheckpointManager
participant MetadataFetch as Fetch Metadata
participant DeltaHandler as _handle_delta_checkpoint
participant BaseCheckpoint as Base Checkpoint
participant Reconstruct as _reconstruct_from_delta
participant Storage as Local Storage
CheckpointManager->>MetadataFetch: _fetch_metadata(window)
MetadataFetch-->>CheckpointManager: CheckpointMetadata
alt checkpoint_type == "DELTA"
CheckpointManager->>DeltaHandler: _handle_delta_checkpoint(metadata)
DeltaHandler->>BaseCheckpoint: get_checkpoint(base_window)
alt base is DELTA
BaseCheckpoint->>DeltaHandler: recursively download base
else base is FULL
BaseCheckpoint-->>DeltaHandler: base path
end
DeltaHandler->>DeltaHandler: download delta files to temp
DeltaHandler->>DeltaHandler: verify delta integrity
DeltaHandler->>Reconstruct: _reconstruct_from_delta(base, delta)
Reconstruct->>Reconstruct: load base weights
Reconstruct->>Reconstruct: apply_sparse_delta()
Reconstruct->>Reconstruct: verify weights_hash
Reconstruct->>Storage: save reconstructed model
Reconstruct->>Storage: copy supplementary files
Reconstruct->>Storage: write metadata (marked FULL)
Reconstruct-->>DeltaHandler: output path
DeltaHandler-->>CheckpointManager: checkpoint path
else checkpoint_type == "FULL"
CheckpointManager->>Storage: retrieve checkpoint directly
Storage-->>CheckpointManager: checkpoint path
end
```
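The delta format these diagrams rely on can be illustrated end to end with plain tensors. This is a conceptual sketch only; the project's `compute_sparse_delta` / `apply_sparse_delta` in `grail/infrastructure/delta_checkpoint.py` additionally handle dtypes, hashing, and serialization that are omitted here.

```python
import torch

def sparse_delta(base: torch.Tensor, current: torch.Tensor, threshold: float = 0.0):
    """Keep only entries whose change exceeds the threshold (flat indices + values)."""
    diff = current.float().flatten() - base.float().flatten()
    mask = diff.abs() > threshold
    indices = mask.nonzero(as_tuple=False).flatten().to(torch.int32)
    return indices, diff[mask]

def apply_delta(base: torch.Tensor, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Reconstruct the updated tensor from the base plus the sparse delta."""
    flat = base.float().flatten().clone()
    flat[indices.long()] += values
    return flat.view_as(base).to(base.dtype)

base = torch.randn(4, 4)
current = base.clone()
current[0, 0] += 0.5                       # only one weight actually changes
idx, vals = sparse_delta(base, current, threshold=1e-6)
assert torch.allclose(apply_delta(base, idx, vals), current)
print(f"stored {idx.numel()} of {base.numel()} entries")   # 1 of 16
```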
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 4
🧹 Nitpick comments (13)
grail/monitoring/backends/wandb_backend.py (1)
40-41: Minor redundancy: the `param_change/sparse/` prefix is already covered by `param_change/`.
The `"param_change/sparse/"` entry is redundant since `"param_change/"` already matches and both map to `"optimizer_step"`. The `_get_step_metric_for_name` method returns on the first prefix match, so `"param_change/sparse/foo"` would match `"param_change/"` first. Consider removing the redundant entry for clarity.

```diff
 STEP_METRIC_PREFIXES = {
     "training/epoch/": "epoch",
     "training/batch/": "batch_step",
     "training/block/": "block_number",
     "training/prefilter/": "block_number",
     "param_change/": "optimizer_step",
-    "param_change/sparse/": "optimizer_step",
     "eval/": "block_number",
```

research/data/kernelbook_analysis.py (1)
27-42: Script executes on import due to top-level code.
The script loads datasets and prints output immediately when imported. This prevents importing the helper functions (`test_module_from_code`, `compare_pytorch_vs_triton`) without side effects. Consider wrapping the executable portions in an `if __name__ == "__main__":` guard.

```diff
+if __name__ == "__main__":
     print("Loading KernelBook dataset...")
     ds = load_dataset("GPUMODE/KernelBook", split="train")
     # ... rest of the script execution ...
```

grail/trainer/algorithms/grpo.py (1)
2073-2102: Consider extracting duplicated param tracking logic into a helper method.
The parameter change tracking and sparse quality analysis logic (lines 2073-2102) is nearly identical to the code in the main loop (lines 2002-2031). Extracting this into a private helper method would improve maintainability.

```python
async def _log_param_tracking_metrics(
    self,
    model: Any,
    monitor: Any | None,
    input_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
) -> None:
    """Log parameter change and sparse quality metrics if due."""
    if not self.param_tracker.has_snapshot():
        return

    try:
        param_metrics = self.param_tracker.compute_metrics(model)
        await log_param_change_metrics(param_metrics, monitor, self.optimizer_step_count)
    except Exception as exc:
        logger.warning("Failed to compute param change metrics: %s", exc)

    if self.sparse_analyzer.enabled and input_ids is not None:
        try:
            sparse_metrics = self.sparse_analyzer.analyze(model, input_ids, attention_mask)
            await log_sparse_quality_metrics(sparse_metrics, monitor, self.optimizer_step_count)
        except Exception as exc:
            logger.warning("Failed to compute sparse quality metrics: %s", exc)
```

tests/integration/infrastructure/test_delta_checkpoint_flow.py (1)
347-349: Hardcoded interval check may become fragile.
The test assumes `DELTA_BASE_INTERVAL == 100` with a conditional assertion. If the constant changes, the test silently skips the assertions rather than failing. Consider making the test more robust by computing expected base windows dynamically:

```diff
-    if DELTA_BASE_INTERVAL == 100:
-        assert 200 * WINDOW_LENGTH in keep, "Base window 200 should be kept"
-        assert 100 * WINDOW_LENGTH in keep, "Base window 100 should be kept"
+    # Verify base windows aligned to DELTA_BASE_INTERVAL are kept
+    base_window_200 = 200 * WINDOW_LENGTH
+    base_window_100 = 100 * WINDOW_LENGTH
+    if base_window_200 % (DELTA_BASE_INTERVAL * WINDOW_LENGTH) == 0:
+        assert base_window_200 in keep, f"Base window {base_window_200} should be kept"
+    if base_window_100 % (DELTA_BASE_INTERVAL * WINDOW_LENGTH) == 0:
+        assert base_window_100 in keep, f"Base window {base_window_100} should be kept"
```

grail/shared/safetensors_utils.py (1)
60-62: Consider logging a warning for empty or invalid weight_map entries.
The validation correctly raises `ValueError` for a missing or invalid `weight_map`, but if `weight_map` contains non-string values, they are silently cast via `str(v)` on line 64. This is likely fine for valid safetensors but could mask unexpected formats.

```diff
     weight_map = index.get("weight_map") if isinstance(index, dict) else None
     if not isinstance(weight_map, dict) or not weight_map:
         raise ValueError("Invalid safetensors index: missing/invalid 'weight_map'")
+
+    # Validate weight_map values are strings
+    non_string_values = [k for k, v in weight_map.items() if not isinstance(v, str)]
+    if non_string_values:
+        logger.warning(
+            "weight_map contains non-string shard references for keys: %s",
+            non_string_values[:5],
+        )
```

grail/monitoring/manager.py (1)
454-455: Consider moving imports to module level.
The `os` and `time` imports inside the function are standard library modules that could be imported at the module level for consistency with the rest of the file. This is a minor style suggestion.

```diff
+import os
+import time
 from collections.abc import Awaitable, Callable
 ...
 async def initialize_subprocess_monitoring(
     ...
 ) -> MonitoringManager | None:
-    import os
-    import time
-
     if not monitor_config:
```

grail/trainer/training_process.py (1)
61-62: Add validation for environment variable parsing.
If these environment variables contain invalid float values, `float()` will raise `ValueError`. Consider adding validation or wrapping in try/except with a warning log.

```diff
-WARMUP_FRACTION = float(os.getenv("GRAIL_WARMUP_FRACTION", "0.05"))
-SCHEDULER_ETA_MIN = float(os.getenv("GRAIL_SCHEDULER_ETA_MIN", "1e-7"))
+def _parse_float_env(name: str, default: str) -> float:
+    val = os.getenv(name, default)
+    try:
+        return float(val)
+    except ValueError:
+        logger.warning("Invalid %s='%s', using default %s", name, val, default)
+        return float(default)
+
+WARMUP_FRACTION = _parse_float_env("GRAIL_WARMUP_FRACTION", "0.05")
+SCHEDULER_ETA_MIN = _parse_float_env("GRAIL_SCHEDULER_ETA_MIN", "1e-7")
```

grail/trainer/checkpoint_publisher.py (1)
762-784: Handle edge case where the delta contains no changes.
If `sparse_tensors` is empty (no parameter changes), the safetensors file will be empty. While handled correctly, the compression ratio calculation on lines 782-784 could divide by zero if both sizes are 0 (though unlikely in practice).

```diff
 compression_ratio = (
-    delta_raw_size / delta_compressed_size if delta_compressed_size > 0 else 1.0
+    delta_raw_size / delta_compressed_size if delta_compressed_size > 0 and delta_raw_size > 0 else 1.0
 )
```

grail/trainer/upload_worker.py (1)
136-138: Consider memory implications of caching base state.
`base_state` holds the entire model state dict in memory for delta computation. For large models (e.g., 7B+ parameters), this could be significant. Consider:

- Documenting the memory overhead
- Adding a mechanism to clear `base_state` when memory is constrained

This is acceptable for now given the bandwidth savings, but worth monitoring. A rough way to quantify the overhead is sketched below.
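A hypothetical helper (not part of this PR) for putting a number on that overhead before deciding whether a clearing mechanism is needed:

```python
import torch

def state_dict_nbytes(state: dict[str, torch.Tensor]) -> int:
    """Approximate CPU memory held by a cached base state, in bytes."""
    return sum(t.numel() * t.element_size() for t in state.values())

# e.g. report state_dict_nbytes(base_state) / 1e9 as a gauge metric after caching
```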
grail/infrastructure/delta_checkpoint.py (3)
55-58: Consider handling parameters missing from current_state.
Parameters present in `base_state` but absent from `current_state` are silently ignored. If this represents a model architecture mismatch, it might be worth logging a warning similar to the one for the reverse case (line 57).
126-143: Orphaned sparse delta entries are silently ignored.
If `sparse_tensors` contains `{name}.indices` for a parameter that doesn't exist in `base_state`, those deltas will never be applied. This could indicate a mismatch between the delta and base checkpoint. Consider adding a validation check or warning.

```diff
+    # Check for orphaned delta entries
+    delta_param_names = {k.rsplit('.', 1)[0] for k in sparse_tensors if k.endswith('.indices')}
+    missing_in_base = delta_param_names - set(base_state.keys())
+    if missing_in_base:
+        logger.warning("Delta contains parameters not in base state: %s", missing_in_base)
+
     for name, base_tensor in base_state.items():
```
219-220: Minor: consider caching element sizes to avoid tensor allocation.
Creating empty tensors on each call is slightly wasteful. Since the typical dtypes are fixed, you could use constants or a simple lookup.

```diff
+# Element sizes for common dtypes (bytes)
+_DTYPE_SIZES = {
+    torch.int32: 4,
+    torch.int64: 8,
+    torch.float32: 4,
+    torch.float16: 2,
+    torch.bfloat16: 2,
+}
+
+
 def estimate_sparse_size(
     nonzero_params: int,
     index_dtype: torch.dtype = torch.int32,
     value_dtype: torch.dtype = torch.float32,
 ) -> int:
-    index_size = nonzero_params * torch.tensor([], dtype=index_dtype).element_size()
-    value_size = nonzero_params * torch.tensor([], dtype=value_dtype).element_size()
+    index_size = nonzero_params * _DTYPE_SIZES.get(
+        index_dtype, torch.tensor([], dtype=index_dtype).element_size()
+    )
+    value_size = nonzero_params * _DTYPE_SIZES.get(
+        value_dtype, torch.tensor([], dtype=value_dtype).element_size()
+    )
     return index_size + value_size
```

grail/trainer/sparse_quality.py (1)
497-510: Consider adding a fallback if the 1e-6 threshold is not present.
The summary log only fires if a threshold of exactly `1e-6` exists in the metrics. If thresholds are configured differently, no summary will be logged. This is likely fine for the current use case, but consider logging the first threshold as a fallback, e.g. as sketched below.
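A minimal fallback along those lines, assuming the per-threshold results are keyed by their float threshold (names here are hypothetical, not the module's actual variables):

```python
# Prefer the 1e-6 threshold for the summary line, otherwise use the smallest one present.
summary_threshold = 1e-6 if 1e-6 in per_threshold_metrics else min(per_threshold_metrics)
summary = per_threshold_metrics[summary_threshold]
```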
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (22)
- grail/infrastructure/checkpoint_consumer.py (7 hunks)
- grail/infrastructure/delta_checkpoint.py (1 hunks)
- grail/monitoring/__init__.py (2 hunks)
- grail/monitoring/backends/wandb_backend.py (4 hunks)
- grail/monitoring/manager.py (2 hunks)
- grail/neurons/trainer.py (7 hunks)
- grail/shared/constants.py (1 hunks)
- grail/shared/safetensors_utils.py (1 hunks)
- grail/trainer/algorithms/grpo.py (4 hunks)
- grail/trainer/checkpoint_publisher.py (7 hunks)
- grail/trainer/config.py (1 hunks)
- grail/trainer/evaluator.py (1 hunks)
- grail/trainer/param_tracker.py (1 hunks)
- grail/trainer/sparse_quality.py (1 hunks)
- grail/trainer/training_process.py (8 hunks)
- grail/trainer/upload_worker.py (11 hunks)
- pyproject.toml (2 hunks)
- research/data/kernelbook_analysis.py (1 hunks)
- tests/integration/infrastructure/test_delta_checkpoint_flow.py (1 hunks)
- tests/integration/test_train_cli_smoke.py (1 hunks)
- tests/unit/infrastructure/__init__.py (1 hunks)
- tests/unit/infrastructure/test_delta_checkpoint.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (11)
tests/integration/infrastructure/test_delta_checkpoint_flow.py (3)
  - grail/infrastructure/checkpoint_consumer.py (3): CheckpointMetadata (69-90), is_delta (88-90), _compute_keep_windows (717-747)
  - grail/infrastructure/delta_checkpoint.py (3): apply_sparse_delta (105-143), compute_sparse_delta (30-102), compute_weights_hash (146-175)
  - grail/trainer/checkpoint_publisher.py (2): is_delta (148-150), _compute_keep_windows (176-212)
grail/monitoring/__init__.py (1)
  - grail/monitoring/manager.py (1): initialize_subprocess_monitoring (416-560)
grail/monitoring/manager.py (4)
  - grail/trainer/training_process.py (1): get_block_context (316-320)
  - grail/monitoring/backends/wandb_backend.py (1): start_run (719-758)
  - grail/monitoring/backends/null_backend.py (1): start_run (73-83)
  - grail/monitoring/base.py (1): start_run (128-138)
grail/trainer/param_tracker.py (1)
  - grail/monitoring/manager.py (1): log_gauge (158-186)
grail/infrastructure/checkpoint_consumer.py (2)
  - grail/shared/safetensors_utils.py (1): load_model_state_dict (25-72)
  - grail/infrastructure/delta_checkpoint.py (2): apply_sparse_delta (105-143), compute_weights_hash (146-175)
tests/unit/infrastructure/test_delta_checkpoint.py (1)
  - grail/infrastructure/delta_checkpoint.py (5): apply_sparse_delta (105-143), compute_sparse_delta (30-102), compute_weights_hash (146-175), estimate_sparse_size (204-221), verify_weights_hash (178-201)
grail/trainer/sparse_quality.py (2)
  - grail/trainer/param_tracker.py (4): ParamChangeTracker (153-474), from_config (215-231), has_snapshot (254-260), get_snapshot (262-273)
  - grail/monitoring/manager.py (1): log_gauge (158-186)
grail/neurons/trainer.py (4)
  - grail/trainer/upload_worker.py (1): run_upload_worker (436-488)
  - grail/monitoring/manager.py (1): set_block_context (105-113)
  - grail/trainer/evaluator.py (1): run_cycle (181-271)
  - grail/neurons/base.py (1): heartbeat (123-125)
grail/trainer/upload_worker.py (2)
  - grail/shared/safetensors_utils.py (1): load_model_state_dict (25-72)
  - grail/trainer/checkpoint_publisher.py (7): UploadError (74-77), UploadResult (126-173), is_delta (148-150), to_dict (112-122), to_dict (152-173), upload_delta (698-983), upload_from_staging (514-696)
grail/trainer/checkpoint_publisher.py (3)
  - grail/infrastructure/checkpoint_consumer.py (3): CheckpointMetadata (69-90), is_delta (88-90), remote_prefix (85-86)
  - grail/infrastructure/delta_checkpoint.py (1): compute_sparse_delta (30-102)
  - grail/shared/safetensors_utils.py (1): load_model_state_dict (25-72)
grail/trainer/training_process.py (3)
  - grail/monitoring/manager.py (1): initialize_subprocess_monitoring (416-560)
  - grail/trainer/snapshot_manager.py (2): save_snapshot_atomic (59-127), get_latest_snapshot_path (195-202)
  - grail/trainer/ipc.py (2): queue_snapshot (122-142), confirm_pause (95-97)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Run Tests (3.10)
- GitHub Check: Run Tests (3.11)
- GitHub Check: Run Tests (3.10)
- GitHub Check: Run Tests (3.11)
- GitHub Check: Run Tests (3.11)
🔇 Additional comments (63)
tests/unit/infrastructure/__init__.py (1)
1-1: LGTM! Standard package initialization for the infrastructure unit tests module.
tests/unit/infrastructure/test_delta_checkpoint.py (8)
1-19: LGTM! Clean imports with all necessary dependencies for testing the delta checkpoint functionality.
21-135: LGTM! Comprehensive test coverage for `compute_sparse_delta`, including edge cases (no changes, full changes), threshold filtering, multidimensional tensors, and dtype handling. The sparsity ratio calculations and index assertions are correct.
137-196: LGTM! Good coverage for `apply_sparse_delta`, including empty delta handling, multidimensional reconstruction, and dtype conversion verification.
198-253: LGTM! Round-trip tests properly validate exact reconstruction with `threshold=0.0` and appropriately account for bf16 precision loss with `atol=1e-2`.
255-314: LGTM! Thorough hash function testing including determinism, bf16 support, collision resistance, key-order independence, and format validation.
316-332: LGTM! Verification tests cover both positive and negative cases appropriately.
334-348: LGTM! Size estimation tests correctly validate the byte calculations for different dtype combinations.
350-382: LGTM! Good integration test for the safetensors round-trip. Note that this test passes `shapes` directly from `compute_sparse_delta` to `apply_sparse_delta` without serializing it. In production, shapes would need to be persisted alongside the delta (e.g., in JSON metadata as shown in the integration test reference). This is acceptable for a unit test scope.

pyproject.toml (2)
227-227: LGTM! Excluding the `research` directory from Ruff linting is reasonable for experimental/research code that may not adhere to strict project conventions.
56-56: Add zstandard dependency for delta checkpoint compression. The `zstandard` library is correctly added as a core dependency to support Zstd compression for delta checkpoints. The package has no known security vulnerabilities and is deemed safe to use.
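For reference, the `zstandard` usage pattern for a delta payload looks roughly like this. This is a generic sketch, not the publisher's exact code:

```python
import zstandard as zstd

def compress_delta(raw: bytes, level: int = 3) -> bytes:
    """Compress serialized delta bytes; level 3 trades speed for a solid ratio."""
    return zstd.ZstdCompressor(level=level).compress(raw)

def decompress_delta(blob: bytes) -> bytes:
    return zstd.ZstdDecompressor().decompress(blob)

payload = b"\x00" * 1_000_000           # stand-in for serialized sparse tensors
packed = compress_delta(payload)
assert decompress_delta(packed) == payload
print(len(payload), "->", len(packed))
```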
grail/monitoring/backends/wandb_backend.py (4)
25-32: LGTM! Adding `optimizer_step` to `RESERVED_TAGS` correctly ensures it's treated as an x-axis field rather than being appended to metric names, consistent with its use as a step metric for parameter change tracking.
241-256: LGTM! Good improvement to use per-label directories (`wandb_<safe_label>`) instead of a static `wandb_training` path. Sanitizing the label to alphanumeric/dash/underscore characters prevents invalid directory names and avoids file conflicts between multiple worker processes.
311-320: LGTM! Adding `optimizer_step` as a base metric is required for param_change metrics to use it as their x-axis, consistent with the other changes in this file.
331-333: LGTM! Properly defines the `param_change/*` metric family with `optimizer_step` as the step metric, enabling dedicated visualization for parameter change tracking.

grail/trainer/config.py (1)
78-88: LGTM! Clean addition of parameter change tracking and sparse quality configuration fields. The fields follow the established pattern of sourcing defaults from `constants`, enabling environment-driven configurability. The naming is consistent and the inline comments provide helpful context.

research/data/kernelbook_analysis.py (2)
83-85: Security: `exec()` executes arbitrary code from the dataset.
Using `exec()` on code loaded from an external dataset introduces security risks. This is acceptable for local research scripts, but ensure this script is never used in production or with untrusted data sources. If this script could be invoked on untrusted inputs, consider sandboxing or limiting execution scope.
1-57: Verify this file belongs in the delta checkpointing PR.
This file appears to be a KernelBook dataset analysis script unrelated to the delta checkpointing feature described in the PR objectives. It may have been accidentally included.
grail/shared/constants.py (1)
247-284: LGTM! Well-structured configuration constants. The new constants for parameter change tracking, sparse quality analysis, and delta checkpoint configuration are well-documented with clear comments, consistent naming conventions (GRAIL_ prefix), and sensible defaults. The environment variable fallbacks enable runtime configurability.
grail/trainer/param_tracker.py (4)
275-296: LGTM! Memory-safe snapshot implementation. The snapshot implementation correctly clones before moving to CPU (line 291), which prevents issues with in-place modifications. The warning on overwrite (line 282) is helpful for debugging.
345-351: Good: float32 conversion before delta computation. Converting bfloat16 tensors to float32 before subtraction (lines 347-348) is critical for detecting small deltas that bfloat16 cannot represent. The inline comment clearly explains the rationale.
356-371: Consider making diagnostic logging conditional on verbosity. The first-parameter diagnostic logging runs unconditionally at DEBUG level. This is fine for development but could add overhead in production with DEBUG enabled. The current approach is acceptable since DEBUG should be off in production.
308-331: Memory consideration: the full model copy resides on CPU during tracking. The snapshot stores a CPU copy of all trainable parameters. For large models, this approximately doubles CPU memory requirements during the measurement interval. This is documented in the class docstring but worth keeping in mind for very large models.
tests/integration/test_train_cli_smoke.py (1)
104-122: LGTM! Test helper updated to match new function signature.The
_fake_run_upload_workerhelper correctly adds themonitor_configparameter to match the updated production signature, maintaining test compatibility with the new monitoring flow.grail/monitoring/__init__.py (1)
10-25: LGTM! Clean public API extension.The new
initialize_subprocess_monitoringexport is properly added to both the import statement and__all__, following the existing pattern. This provides a clean interface for subprocess monitoring initialization.grail/trainer/evaluator.py (1)
188-204: LGTM!The
window_numberparameter is cleanly added with proper documentation. The internal storage inself._current_window_numberenables window-aware evaluation context tracking.grail/trainer/algorithms/grpo.py (1)
1067-1086: LGTM!The parameter tracking and sparse quality analysis initialization follows a clean pattern with factory methods and conditional logging.
tests/integration/infrastructure/test_delta_checkpoint_flow.py (3)
32-37: LGTM!The
temp_cachefixture correctly usesyieldfor cleanup andshutil.rmtreewithignore_errors=Truefor robust teardown.
70-98: LGTM!The roundtrip test properly validates the compute + apply delta flow with appropriate tolerance checks.
355-398: Good edge case coverage!The tests for 0% sparsity (all weights change), 100% sparsity (no weights change), and single parameter change are valuable boundary tests.
grail/shared/safetensors_utils.py (1)
25-71: LGTM!Clean implementation supporting both sharded and unsharded safetensors layouts with proper error handling for missing shards and malformed indices.
grail/monitoring/manager.py (2)
416-560: LGTM!The
initialize_subprocess_monitoringfunction is well-designed with robust error handling, detailed logging at each stage, and graceful degradation (returningNoneon failure). The step-by-step approach with timing information will aid debugging.
461-463: Verify environment variable side effects.Setting
WANDB_DISABLE_SERVICEandWANDB_SERVICEenvironment variables affects the entire process. This is intentional for shared-mode W&B, but worth noting that these settings persist after the function returns.If this function is called multiple times in the same process (e.g., during testing or retry scenarios), the environment modifications are idempotent. However, if there are scenarios where different subprocess types need different W&B configurations, consider documenting this behavior.
grail/neurons/trainer.py (5)
185-185: LGTM: Subprocess label now passed explicitly. The keyword-only argument pattern ensures callers explicitly identify the subprocess, improving code clarity and preventing accidental misuse.
207-219: LGTM: Upload worker receives proper subprocess label and monitor config. The changes correctly pass `subprocess_label="upload_worker"` and propagate `monitor_config` to the upload worker process, enabling per-subprocess W&B logging.
336-336: LGTM: Block context now includes window number. Passing `current_window` instead of `None` provides more precise monitoring context for metrics correlation.
807-809: LGTM: Window number forwarded to evaluation cycle. This enables window-aware evaluation context as noted in the relevant code snippet for `evaluator.run_cycle`.
827-891: LGTM: Subprocess monitoring configuration with label propagation. The method correctly:
- Uses a keyword-only argument for `subprocess_label`
- Propagates the label to `wandb_x_label` for shared-mode workers
- Includes comprehensive debug logging for troubleshooting
- Validates critical parameters (entity, project)
grail/trainer/training_process.py (4)
306-327: LGTM: Monitoring initialization properly delegated. The refactored `_initialize_monitoring` correctly:
- Defines a local `get_block_context` async helper
- Uses the shared `initialize_subprocess_monitoring` helper for consistent behavior across subprocesses
- Passes `test_connection=True` to verify W&B connectivity
368-410: LGTM: Initial checkpoint upload with enriched metadata and IPC queuing. The updated flow:
- Prepares structured `snapshot_metadata` with training config and parent window
- Queues the snapshot for the upload worker via IPC
- Includes an appropriate fallback warning when IPC or the path is unavailable
458-468: LGTM: GRPO group loading optimized to trigger only on window change. This reduces redundant trusted miner lookups by checking `target_data_window != self.last_loaded_window` before querying. Good performance optimization.
592-612: LGTM: Pause confirmation signaled before snapshot save. The comment explains the rationale well: the GPU is freed, so evaluation can start immediately while the snapshot saves in parallel (30-90 s). This prevents the orchestrator from timing out while waiting for confirmation.
grail/trainer/checkpoint_publisher.py (4)
74-77: LGTM: Custom exception for upload failures. Clean separation of upload-related errors from general exceptions.
80-173: LGTM: Well-designed result dataclasses. Good use of:
- `frozen=True, slots=True` for immutability and memory efficiency
- A computed `total_s` property that sums timing components
- An `is_delta` property for type discrimination
- `to_dict()` methods for metrics logging
176-212: LGTM: Retention policy extended for delta base windows. The logic correctly:
- Computes `base_stride_blocks` from the delta interval and window length
- Retains base windows that delta checkpoints depend on
- Ensures delta reconstruction remains possible

A simplified sketch of this retention idea appears below.
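The sketch uses hypothetical names (`keep_last`, `window_length`, `delta_base_interval`); the real `_compute_keep_windows` differs in detail:

```python
def compute_keep_windows(
    latest_window: int,
    window_length: int,
    delta_base_interval: int,
    keep_last: int,
) -> set[int]:
    """Keep the most recent windows plus the FULL base each of them depends on."""
    base_stride = delta_base_interval * window_length
    recent = {latest_window - i * window_length for i in range(keep_last)}
    bases = {w - (w % base_stride) for w in recent}   # align down to the base boundary
    return {w for w in recent | bases if w >= 0}
```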
- Ensures delta reconstruction remains possible
698-983: LGTM: Comprehensive delta upload implementation.The
upload_deltamethod:
- Loads current weights and computes sparse delta against base
- Applies zstd compression (level 3 for good balance)
- Includes SHA256 hash for integrity verification
- Provides detailed timing breakdown
- Properly cleans up temp directory in finally block
- Returns structured
UploadResultwith delta-specific metricsgrail/trainer/upload_worker.py (4)
97-106: LGTM: Subprocess monitoring initialization.Correctly initializes monitoring with
test_connection=Falsesince subtensor isn't ready yet, then sets block context later in the loop when subtensor is available.
175-237: LGTM: Robust FULL vs DELTA upload decision with fallback.The logic correctly:
- Uploads FULL if delta disabled, no base cached, or at base interval boundary
- Falls back from DELTA to FULL if delta upload fails
- Tracks
did_fallback_to_fullto properly cache base state
244-267: LGTM: Base state caching after FULL upload.Properly caches the checkpoint state after FULL uploads (including fallbacks) with informative logging of tensor count and parameter count.
350-433: LGTM: Upload retry with delta support.The refactored
_upload_with_retry:
- Accepts
is_delta,base_window,base_stateparameters- Routes to appropriate publisher method
- Validates delta prerequisites before attempting delta upload
- Returns
UploadResultfor structured metrics- Properly re-raises
UploadErrorafter exhausting retriesgrail/infrastructure/checkpoint_consumer.py (6)
80-90: LGTM: CheckpointMetadata extended for delta support.Clean additions:
checkpoint_typedefaults to "FULL" for backward compatibilitybase_windowtracks the base checkpoint for deltasweights_hashenables integrity verificationis_delta()helper for type checking
203-216: LGTM: Delta checkpoint handling integrated into get_checkpoint.The flow correctly:
- Checks
metadata.is_delta()before attempting reconstruction- Logs base window for debugging
- Raises
CheckpointDownloadErroron reconstruction failure
662-679: LGTM: Comprehensive file copying logic.The exclusion list for copying non-weight files from base is thorough:
- Excludes model weight files (safetensors, index files)
- Excludes metadata and signature files
- Excludes sharded weight files (model-*.safetensors)
- Excludes READY markers
This ensures only tokenizer, config, and other auxiliary files are copied.
506-559: LGTM: Delta checkpoint handling with recursive base fetching.The
_handle_delta_checkpointmethod:
- Validates
base_windowis present- Recursively fetches base checkpoint (handles chained deltas)
- Downloads delta files to separate temp directory
- Verifies delta integrity before reconstruction
- Cleans up temp directory in finally block
821-840: LGTM: Cache directory now honors GRAIL_CACHE_DIR.Good addition for flexibility:
- Supports
~expansion for home directory paths- Documents RAM-disk usage recommendation (
/dev/shm/grail)- Falls back to
~/.cache/grailwhen not set
717-747: LGTM: Retention logic extended for delta base windows. Consistent with the publisher's `_compute_keep_windows`, ensuring both producer and consumer retain the same base windows needed for delta reconstruction.

grail/infrastructure/delta_checkpoint.py (2)
146-175: LGTM! The hashing implementation is robust: sorted keys ensure determinism, and including dtype and shape in the hash prevents false matches across different tensor configurations. The `view(torch.uint8)` approach correctly handles bfloat16 tensors, which can't be converted to numpy directly.
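The same idea in a self-contained form (an illustrative re-implementation, not the module's exact code):

```python
import hashlib
import torch

def state_dict_sha256(state: dict[str, torch.Tensor]) -> str:
    """Deterministic hash over names, dtypes, shapes, and raw bytes (sorted keys)."""
    digest = hashlib.sha256()
    for name in sorted(state):
        tensor = state[name].detach().cpu().contiguous()
        digest.update(name.encode())
        digest.update(str(tensor.dtype).encode())
        digest.update(str(tuple(tensor.shape)).encode())
        # Byte-level view works for bfloat16, which has no numpy equivalent.
        digest.update(tensor.view(torch.uint8).numpy().tobytes())
    return digest.hexdigest()
```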
178-201: LGTM! Clean verification function with appropriate error logging for debugging hash mismatches.
grail/trainer/sparse_quality.py (5)
32-86: LGTM! Well-structured dataclasses with clear separation between per-threshold metrics and aggregated results. The `to_log_dict` method provides convenient serialization for monitoring systems.
104-111: LGTM - correct KL divergence computation. The use of `log_target=True` with both inputs as log-probabilities correctly computes KL(P_a || P_b). The inline comment on lines 109-110 could be updated to reflect the actual formula with `log_target=True`, but the code is correct.
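A quick self-contained check of that convention (illustrative only, not the module's code): with both arguments given as log-probabilities and `log_target=True`, `F.kl_div(log_q, log_p, ...)` returns KL(P || Q).

```python
import torch
import torch.nn.functional as F

log_p = F.log_softmax(torch.randn(4, 10), dim=-1)   # "target" distribution P
log_q = F.log_softmax(torch.randn(4, 10), dim=-1)   # "input" distribution Q

kl = F.kl_div(log_q, log_p, reduction="batchmean", log_target=True)   # KL(P || Q)
kl_ref = (log_p.exp() * (log_p - log_q)).sum(dim=-1).mean()           # by definition
assert torch.allclose(kl, kl_ref, atol=1e-6)
```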
121-222: LGTM! The metric computation functions follow a consistent, correct pattern with proper masking and numerical stability considerations (e.g., `eps=1e-8` in normalization, `clamp(min=1.0)` for valid counts).
237-272: LGTM! Clean initialization with sensible defaults. The `from_config` factory method safely handles missing configuration attributes.
445-466: LGTM!Clean implementation of random masking with uniform sampling.
```python
# Apply sparse delta (float32 computation, bf16 output)
reconstructed = apply_sparse_delta(
    base_state,
    sparse_tensors,
    shapes,
    target_dtype=torch.bfloat16,
)
```
Infer target dtype from base_state tensors instead of hardcoding bfloat16.
The `apply_sparse_delta` function accepts a configurable `target_dtype` parameter, but the callsite hardcodes `torch.bfloat16`. Since `base_state` is `dict[str, torch.Tensor]`, extract the dtype from the tensors to preserve the original checkpoint's precision. This ensures compatibility with models trained in other dtypes and with hardware that lacks native bfloat16 support.
🤖 Prompt for AI Agents
In grail/infrastructure/checkpoint_consumer.py around lines 636 to 642, the call
to apply_sparse_delta hardcodes target_dtype=torch.bfloat16; instead infer the
target dtype from base_state tensors so we preserve the checkpoint precision.
Inspect base_state (dict[str, torch.Tensor]) to determine a representative dtype
(e.g., pick the dtype of the first tensor or ensure all tensors share the same
dtype), and pass that dtype as target_dtype to apply_sparse_delta; include a
fallback (e.g., torch.float32) if base_state is empty or heterogeneous, and keep
the rest of the call unchanged.
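A sketch of that inference under the stated assumptions (a hypothetical helper, not the file's actual code):

```python
import torch

def infer_target_dtype(base_state: dict[str, torch.Tensor]) -> torch.dtype:
    """Use the base checkpoint's own dtype; fall back to float32 if empty or mixed."""
    dtypes = {t.dtype for t in base_state.values()}
    return dtypes.pop() if len(dtypes) == 1 else torch.float32

# reconstructed = apply_sparse_delta(base_state, sparse_tensors, shapes,
#                                    target_dtype=infer_target_dtype(base_state))
```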
```python
# Sparse quality analysis (uses same snapshot)
if self.sparse_analyzer.enabled:
    try:
        sparse_metrics = self.sparse_analyzer.analyze(
            model, input_ids_tensor, attention_mask_tensor
        )
        await log_sparse_quality_metrics(
            sparse_metrics, monitor, self.optimizer_step_count
        )
```
Potential issue: input_ids_tensor and attention_mask_tensor may be undefined or stale at epoch-end.
In the epoch-end accumulation path (lines 2089-2097), the sparse_analyzer.analyze() call references input_ids_tensor and attention_mask_tensor, but these variables are only defined inside the micro-batch loop. If the loop never executed or if the last iteration was skipped due to a continue statement (e.g., non-finite values), these tensors may be undefined, causing a NameError.
Consider capturing the tensors explicitly before calling analyze:
```diff
 if self.sparse_analyzer.enabled:
     try:
+        # Note: Uses input tensors from last valid micro-batch
+        if 'input_ids_tensor' not in dir():
+            logger.warning("No input tensors available for sparse analysis at epoch end")
+        else:
             sparse_metrics = self.sparse_analyzer.analyze(
                 model, input_ids_tensor, attention_mask_tensor
             )
```

Alternatively, consider extracting the param tracking and sparse analysis logic into a helper method to reduce duplication between the main loop (lines 2002-2031) and epoch-end (lines 2076-2102).
Committable suggestion skipped: line range outside the PR's diff.
```python
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
```
Handle edge case where model has no parameters.
next(model.parameters()) will raise StopIteration if the model has no parameters. While unlikely in practice, consider adding a guard or using a default.
```diff
-    device = next(model.parameters()).device
-    dtype = next(model.parameters()).dtype
+    params = list(model.parameters())
+    if not params:
+        raise ValueError("Model has no parameters to analyze")
+    device = params[0].device
+    dtype = params[0].dtype
```

📝 Committable suggestion
```diff
-device = next(model.parameters()).device
-dtype = next(model.parameters()).dtype
+params = list(model.parameters())
+if not params:
+    raise ValueError("Model has no parameters to analyze")
+device = params[0].device
+dtype = params[0].dtype
```
🤖 Prompt for AI Agents
In grail/trainer/sparse_quality.py around lines 294-295, the current calls to
next(model.parameters()) will raise StopIteration if the model has no
parameters; change this to safely obtain device and dtype by first trying
next(model.parameters(), None), if that is None try next(model.buffers(), None),
and if still None fall back to explicit defaults (e.g., torch.device('cpu') and
torch.float32); assign device and dtype from the first non-None tensor found and
otherwise use the defaults so the code does not raise on parameterless models.
```python
for name, param in model.named_parameters():
    if name not in snapshot:
        continue

    modified_params.add(name)

    # Compute sparse delta (float32 on CPU)
    sparse_delta = deltas[name] * mask_dict[name]

    # Apply: W = W_old + sparse_delta (in float32, then convert)
    snapshot_fp32 = snapshot[name].float()
    new_weight = snapshot_fp32 + sparse_delta
    param.data.copy_(new_weight.to(device=device, dtype=dtype))

try:
    # Forward pass
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    logits_cpu = logits.cpu().float()
    del logits  # Free GPU memory immediately
finally:
    # Restore original weights using snapshot + full_delta (no GPU clone needed!)
    for name, param in model.named_parameters():
        if name in modified_params:
            # W_current = W_old + delta, so restore by computing on CPU and moving to GPU
            snapshot_fp32 = snapshot[name].float()
            full_delta = deltas[name]  # Full delta (not sparse)
            original_weight = snapshot_fp32 + full_delta
            param.data.copy_(original_weight.to(device=device, dtype=dtype))

return logits_cpu
```
Move weight patching inside try block for exception safety.
The weight modification loop (lines 413-425) is outside the try block. If an exception occurs during patching (e.g., device mismatch, OOM), the model is left with partially modified weights and the finally block won't restore them properly since modified_params tracks what was modified, not what needs restoration.
```diff
 # Track which params we modified (for restoration)
 modified_params: set[str] = set()
-for name, param in model.named_parameters():
-    if name not in snapshot:
-        continue
-
-    modified_params.add(name)
-
-    # Compute sparse delta (float32 on CPU)
-    sparse_delta = deltas[name] * mask_dict[name]
-
-    # Apply: W = W_old + sparse_delta (in float32, then convert)
-    snapshot_fp32 = snapshot[name].float()
-    new_weight = snapshot_fp32 + sparse_delta
-    param.data.copy_(new_weight.to(device=device, dtype=dtype))
-
 try:
+    for name, param in model.named_parameters():
+        if name not in snapshot:
+            continue
+
+        modified_params.add(name)
+
+        # Compute sparse delta (float32 on CPU)
+        sparse_delta = deltas[name] * mask_dict[name]
+
+        # Apply: W = W_old + sparse_delta (in float32, then convert)
+        snapshot_fp32 = snapshot[name].float()
+        new_weight = snapshot_fp32 + sparse_delta
+        param.data.copy_(new_weight.to(device=device, dtype=dtype))
+
     # Forward pass
     with torch.no_grad():
```

🤖 Prompt for AI Agents
In grail/trainer/sparse_quality.py around lines 413 to 443, the loop that
applies sparse deltas to model parameters must be moved inside the try block (or
the try should be expanded to include that loop) so that any exception during
patching is caught and the finally block can reliably restore only the actually
modified parameters; initialize modified_params before the try if not already,
perform the patching inside the try (updating modified_params as you apply each
param), then run the forward under with torch.no_grad() and compute logits_cpu,
and keep the existing finally block to iterate modified_params and restore
original weights from snapshot + full_delta.
Delta Checkpointing: ~90% Bandwidth Reduction for Weight Communication
Implements sparse delta checkpointing to dramatically reduce upload/download bandwidth between trainer and miners/validators.
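As a rough back-of-the-envelope illustration (assumed numbers, not measurements from this PR): a 1.5B-parameter bf16 model is about 3 GB per full checkpoint, while a delta that touches roughly 5% of parameters stores one int32 index plus one bf16 value per changed entry.

```python
params = 1_500_000_000
full_bytes = params * 2                        # bf16 full checkpoint (~3.0 GB)
delta_bytes = int(params * 0.05) * (4 + 2)     # int32 index + bf16 value per changed entry
print(f"full ~{full_bytes / 1e9:.1f} GB, delta ~{delta_bytes / 1e9:.2f} GB")
print(f"reduction ~{1 - delta_bytes / full_bytes:.0%} before zstd compression")
```

Zstd compression of the delta payload can close much of the remaining gap toward the headline ~90% figure.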