feat: add async trainer infrastructure #52
Conversation
…er configuration in EvalConfig
…ng CheckpointPublisher for publishing and managing checkpoints, and CheckpointManager for consuming checkpoints
Add snapshot-based IPC for training/upload coordination. Include TrainingService for continuous training with pause/resume and SnapshotManager for atomic checkpoint ops. Add upload worker for async checkpoint publishing to R2.
Replace window-based training with continuous loop. Split training and upload into separate processes using SnapshotManager. Add heartbeat monitoring and pause/resume for evaluation coordination.
Change READY marker format to READY-{window}, encoding the completion
window. Add upload_from_staging method for worker process. Simplify
finalize_checkpoint_ready to accept explicit checkpoint and ready
windows.
Update checkpoint discovery to use READY-{window} format. Add
get_checkpoint_ready_window method to determine availability. Reorder
CheckpointMetadata fields for consistency.
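To make the READY-{window} convention concrete, here is a minimal sketch of how a consumer could recover the ready window from object keys; the key layout shown is hypothetical, and only the READY-{window} marker name itself comes from this change:

```python
import re

READY_MARKER_RE = re.compile(r"READY-(\d+)$")


def ready_window_from_keys(keys: list[str]) -> int | None:
    """Return the ready window encoded by a checkpoint's READY-{window} marker, if any."""
    for key in keys:
        match = READY_MARKER_RE.search(key)
        if match:
            return int(match.group(1))
    return None


# Hypothetical key layout, for illustration only:
assert ready_window_from_keys(["checkpoints/checkpoint-1000/READY-1002"]) == 1002
assert ready_window_from_keys(["checkpoints/checkpoint-1000/model.safetensors"]) is None
```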
Remove synchronous model loading from CLI. Pass ModelLoadSpec to TrainerContext instead of loaded models. Load models in training process to avoid CUDA fork issues.
Add tests for SnapshotManager atomic operations and IPC. Add smoke tests for train CLI with lightweight fakes to verify end-to-end orchestration.
Add SNAPSHOT_POLL_INTERVAL_SECONDS, TRAINING_HEARTBEAT_TIMEOUT_SECONDS, and upload retry config. Reduce TRAINER_MAX_LENGTH default from 2048 to 1024 for memory efficiency.
Update environment loop for async trainer compatibility. Improve math environment parsing and reward computation.
Update validation logic for async checkpoint publishing. Improve miner validation error handling and logging.
Update W&B backend to support logging from multiple processes (main, training, upload worker) using run ID resumption.
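The usual W&B mechanism for this is to reuse a single run ID across processes. A minimal sketch, assuming the run ID is handed to each worker (the project name and wiring are illustrative, not this repo's actual code):

```python
import wandb


def attach_to_run(run_id: str, process_label: str):
    """Reattach a worker process to an existing W&B run by its ID."""
    run = wandb.init(
        project="grail",   # assumed project name
        id=run_id,         # same ID used by the main, training, and upload processes
        resume="allow",    # reattach to the existing run instead of creating a new one
    )
    run.log({f"{process_label}/attached": 1})
    return run
```

Concurrent writers may additionally need W&B's shared mode (discussed later in this PR); plain ID resumption is just the simplest form of the idea.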
Remove reasoning_start parameter from build_qwen_chat_template. Model must generate reasoning tokens without prompt injection. Clean up related comments.
Update miner to use get_latest_ready_checkpoint for discovery. Add checkpoint_window to rollout packaging. Remove window parameter from checkpoint publishing calls. Minor evaluator vLLM config cleanup.
Fix whitespace on blank lines, sort imports, remove unused imports, and remove quotes from type annotations.
…resilient version
…mproved monitoring initialization
…eout configuration - Added support for WandB shared mode to allow multiple processes to log to a single run. - Introduced `init_timeout` configuration to manage connection delays more effectively. - Updated logging to provide better insights into the configuration and initialization process.
…, we're not passing extra args
… for missing parameters
…process doesn't reuse wandb service used by the main process
…a cleaner interface
…onfigurable sampling
…tor for non-blocking read because it resulted in some issues
…to avoid S3 pagination issues
…ad duration to enhance monitoring
…ncy; set proper context for wandb logging!
…d update GRPO max completion tokens to be 1024
…ional entropy calculation for memory efficiency
…w. Previously it was done after every micro batch.
Replace infinite filesystem polling with multiprocessing.Event for pause/resume coordination between the orchestrator and the training process. Add a 5-minute timeout, a stop_event check, and a process liveness check to prevent indefinite hangs. Filesystem markers are kept as a crash-recovery backup.
- Add IPCChannels dataclass with stop, pause, heartbeat, and snapshot_queue
- Add helper methods for heartbeat, pause coordination, and snapshot queuing
- Fix bug in upload_worker using undefined variables
- Remove dead code: filesystem fallback mode and unused ipc_coordinator.py
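A rough sketch of the IPCChannels idea from the list above, built on standard multiprocessing primitives; the field names follow the PR description, while the helper methods and message shape are illustrative assumptions:

```python
import multiprocessing as mp
import time
from dataclasses import dataclass, field
from typing import Any


@dataclass
class IPCChannels:
    """Process-shared coordination channels (field names from the summary above)."""

    stop: Any = field(default_factory=mp.Event)
    pause: Any = field(default_factory=mp.Event)
    heartbeat: Any = field(default_factory=lambda: mp.Value("d", 0.0))
    snapshot_queue: Any = field(default_factory=mp.Queue)

    def beat(self) -> None:
        # Training loop calls this periodically so the orchestrator can detect hangs.
        with self.heartbeat.get_lock():
            self.heartbeat.value = time.time()

    def heartbeat_age(self) -> float:
        # Seconds since the last heartbeat; large values suggest a stuck trainer.
        with self.heartbeat.get_lock():
            last = self.heartbeat.value
        return (time.time() - last) if last else float("inf")

    def queue_snapshot(self, path: str, window: int) -> None:
        # Hand a finished snapshot to the upload worker.
        self.snapshot_queue.put({"path": path, "window": window})
```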
…rol and add verbosity context
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Implements async multi-process training orchestration (training + upload workers) with IPC, splits checkpoint producer/consumer roles (CheckpointPublisher and consumer helpers), adds SnapshotManager and ReplayBuffer, updates CLI/trainer wiring and monitoring (WandB) integration, and introduces config/env flags including gradient checkpointing.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Orchestrator as Orchestrator
participant Trainer as Training Process
participant SnapMgr as SnapshotManager
participant IPC as IPCChannels
participant Upload as Upload Worker
participant Publisher as CheckpointPublisher
participant Storage as Remote Storage
Orchestrator->>Trainer: spawn training_process(ipc, specs, config)
Orchestrator->>Upload: spawn upload_worker(ipc, publisher)
loop training epoch
Trainer->>Trainer: train epoch / update heartbeat
Trainer->>SnapMgr: save_snapshot_atomic(model, tokenizer, metadata)
SnapMgr-->>Trainer: snapshot saved
Trainer->>IPC: queue_snapshot(path, metadata, window)
end
loop upload loop
Upload->>IPC: wait for snapshot message
IPC-->>Upload: snapshot metadata
Upload->>SnapMgr: copy_snapshot_to_staging()
SnapMgr-->>Upload: staging_path
Upload->>Publisher: upload_from_staging(staging_path, metadata, window)
Publisher->>Storage: upload files + metadata
Storage-->>Publisher: uploaded
Publisher->>Storage: write READY-{ready_window} marker
Storage-->>Publisher: marker OK
Upload->>SnapMgr: cleanup_staging()
    end
```

```mermaid
sequenceDiagram
participant Miner as Miner
participant CheckMgr as CheckpointManager (consumer)
participant Storage as Remote Storage
Miner->>CheckMgr: get_latest_ready_checkpoint(before_window)
CheckMgr->>Storage: list/read READY markers
Storage-->>CheckMgr: ready_window (or none)
CheckMgr-->>Miner: checkpoint_window (or None)
alt checkpoint found
Miner->>CheckMgr: get_checkpoint(checkpoint_window)
CheckMgr->>Storage: download files
Storage-->>CheckMgr: model/tokenizer
CheckMgr-->>Miner: artifacts
else none
Miner-->>Miner: warn / retry later
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
grail/shared/chat_templates.py (1)
10-24: Remove the deprecated parameter from the docstring Args section.

The docstring documents reasoning_start in the Args section (line 15), but this parameter no longer exists in the function signature. Since it's been fully removed, delete it from the Args section rather than marking it as deprecated.

Apply this diff to fix the docstring:

```diff
 Args:
     system_prompt: The system prompt to inject
-    reasoning_start: DEPRECATED - No longer used. The model generates reasoning tokens.

 Returns:
```

grail/neurons/miner.py (1)
252-264: checkpoint_window can be None but generate_rollouts_for_window expects int.

checkpoint_manager.get_latest_ready_checkpoint() returns int | None. When no checkpoint is discovered (line 173), checkpoint_window remains None. If a model was already loaded from a previous iteration, the code bypasses the checkpoint loading block (line 178) and proceeds directly to the function call at line 252, passing None to generate_rollouts_for_window. However, that function has a type annotation of checkpoint_window: int (not Optional) and uses it directly without None checks at line 739 in assemble_rollout_payload. Either handle the None case before the function call or change the type annotation to int | None and add defensive checks within the function.

grail/infrastructure/checkpoint_consumer.py (1)
426-454: Semaphore created but never used in _download_all_in_prefix.

The semaphore on line 439 is created but not applied to the _dl coroutine, allowing unbounded concurrent downloads that could overwhelm the network or R2.

```diff
 async def _download_all_in_prefix(self, window: int, tmp_dir: Path) -> None:
     ...
     if not keys:
         raise CheckpointDownloadError(f"No files found at prefix {prefix_dir}")

-    asyncio.Semaphore(6)
+    semaphore = asyncio.Semaphore(6)

     async def _dl(key: str) -> None:
-        if not key or not key.startswith(prefix_dir) or key.endswith("/"):
-            return
-        rel = key[len(prefix_dir) :]
-        data = await comms.download_file_chunked(
-            key, credentials=self.credentials, use_write=False
-        )
-        if data is None:
-            raise CheckpointDownloadError(f"Missing file {key}")
-        target_path = tmp_dir / rel
-        target_path.parent.mkdir(parents=True, exist_ok=True)
-        target_path.write_bytes(data)
+        async with semaphore:
+            if not key or not key.startswith(prefix_dir) or key.endswith("/"):
+                return
+            rel = key[len(prefix_dir) :]
+            data = await comms.download_file_chunked(
+                key, credentials=self.credentials, use_write=False
+            )
+            if data is None:
+                raise CheckpointDownloadError(f"Missing file {key}")
+            target_path = tmp_dir / rel
+            target_path.parent.mkdir(parents=True, exist_ok=True)
+            target_path.write_bytes(data)

     await asyncio.gather(*(_dl(k) for k in keys))
```
🧹 Nitpick comments (33)
grail/model/provider.py (1)
129-148: Model metadata logging looks good; consider single-pass param counting as a micro-optimization.

The new logging block is safe and useful, and the try/except ensures it won’t affect model loading behavior. As a small optional improvement, you could avoid iterating model.parameters() twice by computing total_params and trainable_params in one pass:

```python
total_params = 0
trainable_params = 0
for p in model.parameters():
    n = p.numel()
    total_params += n
    if p.requires_grad:
        trainable_params += n
```

This keeps the same behavior with slightly less overhead on very large models.
grail/trainer/algorithms/grpo.py (1)
1672-1711: Consider reducing key duplication between initialization and reset.

The tracker keys are duplicated between the initial dict (lines 1672-1689) and reset_actual_batch_tracker (lines 1694-1711). This creates a maintenance burden if keys change.

Consider extracting the default tracker state:
+ def _make_empty_tracker() -> dict[str, float]: + return { + "micro_batches": 0, + "sequence_count": 0.0, + "token_count": 0.0, + "loss_total_sum": 0.0, + "loss_pg_sum": 0.0, + "loss_kl_sum": 0.0, + "loss_entropy_sum": 0.0, + "adv_sum": 0.0, + "adv_count": 0.0, + "reward_sum": 0.0, + "reward_count": 0.0, + "entropy_sum": 0.0, + "entropy_count": 0.0, + "kl_sum": 0.0, + "ratio_clip_sum": 0.0, + "ratio_ceiling_sum": 0.0, + } + + actual_batch_tracker = _make_empty_tracker() + + def reset_actual_batch_tracker() -> None: + actual_batch_tracker.update(_make_empty_tracker())grail/environments/math_hendrycks_env.py (1)
32-61: _extract_boxed_answer logic is solid; consider a small robustness tweak.

The implementation correctly:

- Detects \boxed{ occurrences with a regex.
- Uses a depth counter to handle nested braces.
- Returns None when no \boxed{} is present.

One edge case: if the last \boxed{ in the string has unbalanced braces (e.g., truncated or malformed solution), the function returns None and ignores any earlier, valid \boxed{...} blocks. If you expect occasionally noisy inputs, it may be safer to iterate over boxed_indices in reverse and return the first block that parses successfully, while also stripping whitespace from the extracted content.

For example:
- # Use the last one - start = boxed_indices[-1] - # Skip \boxed{ (7 chars) - content_start = start + 7 - - depth = 1 - for i, char in enumerate(text[content_start:], start=content_start): - if char == "{": - depth += 1 - elif char == "}": - depth -= 1 - if depth == 0: - return text[content_start:i] + # Try boxed occurrences from last to first; return the first that parses. + for start in reversed(boxed_indices): + # Skip \boxed{ (7 chars) + content_start = start + 7 + + depth = 1 + for i, char in enumerate(text[content_start:], start=content_start): + if char == "{": + depth += 1 + elif char == "}": + depth -= 1 + if depth == 0: + return text[content_start:i].strip()You still fall through to
You still fall through to return None at the end if no occurrence parses cleanly.

grail/infrastructure/network.py (1)
117-127: Make subtensor.close() handling more robust and avoid duplication.

The new close logic in _restart_subtensor and close() is good for avoiding leaks, but the pattern:

```python
if hasattr(subtensor, "close"):
    if asyncio.iscoroutinefunction(subtensor.close):
        await subtensor.close()
    else:
        subtensor.close()
```

relies on close being defined as a coroutine function. It will not await cases where close() returns an awaitable but isn’t marked as a coroutine function (wrappers, partials, etc.), and the same logic is duplicated in two places.

You can make this both more robust and DRY by:

- Calling close() once, capturing the result, and awaiting it if it’s awaitable.
- Reusing the same helper pattern in both _restart_subtensor and the public close().
- try: - if hasattr(subtensor, "close"): - if asyncio.iscoroutinefunction(subtensor.close): - await subtensor.close() - else: - subtensor.close() - logger.debug("Closed old subtensor connection") + try: + if hasattr(subtensor, "close"): + close_result = subtensor.close() + if inspect.isawaitable(close_result): + await close_result + logger.debug("Closed old subtensor connection") @@ - subtensor = object.__getattribute__(self, "_subtensor") - if hasattr(subtensor, "close"): - if asyncio.iscoroutinefunction(subtensor.close): - await subtensor.close() - else: - subtensor.close() + subtensor = object.__getattribute__(self, "_subtensor") + if hasattr(subtensor, "close"): + close_result = subtensor.close() + if inspect.isawaitable(close_result): + await close_resultAnd add at the top of the file:
```python
import inspect  # new import
```

Optionally, you could also extract this into a small
_maybe_await_close(subtensor)helper to avoid repeating the pattern a third time if you ever need it elsewhere.Also applies to: 138-146
grail/trainer/replay_buffer.py (2)
163-170: Consider using a set for faster membership checking.

The
g not in sampledcheck on Line 167 is O(n) per group. For larger buffers, this could become a bottleneck. Consider tracking sampled groups in a set using group IDs.# Fill remaining quota from most recent window if undersampled if len(sampled) < max_groups: remaining_quota = max_groups - len(sampled) most_recent_window = sorted_windows[-1] available = self._windows[most_recent_window] - unused = [g for g in available if g not in sampled] + sampled_ids = {id(g) for g in sampled} + unused = [g for g in available if id(g) not in sampled_ids] k = min(remaining_quota, len(unused)) if k > 0: sampled.extend(rng.sample(unused, k))
223-224: Document the memory estimation assumption.

The hardcoded
10_000bytes per group is a rough approximation. Consider adding a comment explaining this estimate or making it configurable.total_groups = sum(len(groups) for groups in self._windows.values()) + # Rough estimate: ~10KB per group (tokens + logprobs + metadata) memory_mb = (total_groups * 10_000) / (1024 * 1024)scripts/test_wandb_shared.py (1)
36-44: Consider extracting the hardcoded project name.

The
"grail"project name is hardcoded here but could benefit from being a CLI argument or matching a constant, especially if this script is reused for different projects.+ parser.add_argument( + "--project", + type=str, + default="grail", + help="WandB project name (default: grail)", + )grail/cli/train.py (1)
76-86: Consider making the WandB sync delay configurable.

The hardcoded 5-second sleep is a pragmatic workaround for WandB shared mode sync requirements. Consider making this configurable or implementing a retry/polling mechanism for more robust synchronization.
+from grail.monitoring.config import WANDB_SYNC_DELAY_SECONDS # Add to config + if training_config.get("wandb_shared_mode"): - import asyncio - logger.info( - "Waiting 5s for WandB run to sync to cloud (shared mode requirement)..." + f"Waiting {WANDB_SYNC_DELAY_SECONDS}s for WandB run to sync..." ) - await asyncio.sleep(5) + await asyncio.sleep(WANDB_SYNC_DELAY_SECONDS)grail/neurons/miner.py (1)
225-234: Verify the dual model/tokenizer checks are intentional.

There are two separate checks for
model is None or tokenizer is None:
- Lines 225-228: When no checkpoint is available (sleeps 60s)
- Lines 231-234: Catch-all before mining (sleeps 30s)
The second check at lines 231-234 appears to be a fallback for edge cases where checkpoint loading might silently fail. Consider adding a comment clarifying this is a defensive guard.
grail/logging_utils.py (2)
294-296: Consider adding error handling for reconfigure calls.
sys.stdout.reconfigure()andsys.stderr.reconfigure()can raiseAttributeErrorin non-standard Python environments or when stdout/stderr are replaced with objects that don't support reconfiguration (e.g., certain test fixtures, custom streams).# Force unbuffered output for immediate visibility - sys.stdout.reconfigure(line_buffering=True) - sys.stderr.reconfigure(line_buffering=True) + try: + sys.stdout.reconfigure(line_buffering=True) + sys.stderr.reconfigure(line_buffering=True) + except (AttributeError, io.UnsupportedOperation): + pass # Stream doesn't support reconfiguration
341-366: Add CancelledError handling for clean task cancellation.

The
while Trueloop runs indefinitely without handlingasyncio.CancelledError. When this task is cancelled during shutdown, it will propagate the exception instead of exiting cleanly.async def monitor_event_loop_lag(interval: float = 5.0, threshold: float = 1.0) -> None: """Background task to detect event loop blocking. Args: interval: Check interval in seconds threshold: Lag threshold to trigger warning (seconds) """ logger = logging.getLogger(__name__) last_check = time.time() - while True: - await asyncio.sleep(interval) - now = time.time() - actual_elapsed = now - last_check - expected_elapsed = interval - lag = actual_elapsed - expected_elapsed - - if lag > threshold: - logger.warning( - "⚠️ Event loop lag detected: expected %.1fs, actual %.1fs (lag: %.1fs)", - expected_elapsed, - actual_elapsed, - lag, - ) - - last_check = now + try: + while True: + await asyncio.sleep(interval) + now = time.time() + actual_elapsed = now - last_check + expected_elapsed = interval + lag = actual_elapsed - expected_elapsed + + if lag > threshold: + logger.warning( + "⚠️ Event loop lag detected: expected %.1fs, actual %.1fs (lag: %.1fs)", + expected_elapsed, + actual_elapsed, + lag, + ) + + last_check = now + except asyncio.CancelledError: + logger.debug("Event loop lag monitor cancelled") + raisegrail/validation/miner_validator.py (1)
177-187: The checkpoint window extraction is fragile and relies on a specific path format.

The parsing logic assumes the model path contains "checkpoint-{window}" and splits on that pattern. This could fail silently or extract incorrect values if:

- The path uses a different separator (e.g., checkpoint_1000 vs checkpoint-1000)
- Multiple "checkpoint-" segments exist in the path
- The model was loaded from a different source (e.g., HuggingFace Hub)

Consider adding validation or logging when extraction fails.
# Extract validator's checkpoint window from model path (if available) validator_checkpoint_window = None if hasattr(model, "name_or_path"): model_path = str(model.name_or_path) # Parse checkpoint-{window} from path like "/cache/checkpoints/checkpoint-1000" if "checkpoint-" in model_path: try: checkpoint_segment = model_path.split("checkpoint-")[-1].split("/")[0] validator_checkpoint_window = int(checkpoint_segment) + logger.debug( + "Extracted validator checkpoint window %s from path %s", + validator_checkpoint_window, + model_path, + ) except (ValueError, IndexError): - pass + logger.warning( + "Failed to parse checkpoint window from model path: %s", + model_path, + )scripts/test_monitoring_shared.py (1)
173-217: Consider consolidating multiple asyncio.run() calls into a single async workflow.

The worker process makes multiple separate asyncio.run() calls (lines 173, 199, 205, 206, 207, 216, 217). Each call creates and destroys a new event loop, which is inefficient and can cause issues with async resources that expect a persistent loop.

For a test script this is acceptable, but consider refactoring to a single async workflow:
async def _worker_async(manager, run_id, process_label, config): actual_run_id = await manager.start_run(f"{process_label}_process", config) await manager.log_gauge(f"testing/{process_label}_startup_time", elapsed) await manager.log_gauge(f"testing/{process_label}_timestamp", time.time()) await manager.log_gauge(f"testing/{process_label}_connected", 1.0) await manager.flush_metrics() await asyncio.sleep(2) await manager.finish_run(actual_run_id) await manager.shutdown() # Then call with single asyncio.run() asyncio.run(_worker_async(manager, run_id, process_label, config))grail/trainer/ipc.py (1)
122-142: Consider documenting the unbounded queue behavior.

The
snapshot_queueis an unboundedmultiprocessing.Queue(). If the upload worker falls behind, this could theoretically grow indefinitely. For the async trainer use case with infrequent snapshots, this is likely acceptable, but it may be worth documenting in the docstring.grail/infrastructure/checkpoint_consumer.py (1)
600-624: Deprecated function references a non-existent method.

The deprecated function _find_latest_ready_checkpoint_window_DEPRECATED calls checkpoint_manager._is_checkpoint_ready(window) on line 619, but this method doesn't exist in the current CheckpointManager class. If someone tries to use this preserved reference, it will fail.

Consider removing the deprecated function entirely or updating it to use
_get_checkpoint_ready_window:- for window in sorted(windows, reverse=True): - if await checkpoint_manager._is_checkpoint_ready(window): + for window in sorted(windows, reverse=True): + if await checkpoint_manager._get_checkpoint_ready_window(window) is not None: logger.info(f"Found latest ready checkpoint at window {window}")grail/trainer/upload_worker.py (2)
37-65: Potential issue with asyncio.get_event_loop() deprecation.

Using asyncio.get_event_loop() is deprecated in Python 3.10+ and may raise DeprecationWarning or fail in some contexts. Consider using asyncio.get_running_loop() instead, since this is called within an async function.

```diff
- loop = asyncio.get_event_loop()
+ loop = asyncio.get_running_loop()
```
67-100: Import statement inside function reduces readability.

The
import oson line 88 should be moved to the top of the file with other imports for consistency and to avoid repeated import overhead if this function is called multiple times.grail/trainer/config.py (1)
129-133: Consider removing or deprecating legacy sglang-specific fields.

Lines 130-133 still contain sglang_*-prefixed fields (sglang_mem_fraction_static, sglang_context_length, etc.) while the main server config was renamed to server_*. This inconsistency could cause confusion about which fields apply to which backend.
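If the legacy names must stick around for a transition period, one option is deprecated alias properties on the config object. A sketch under assumed names (the class and server_* field names here are illustrative, not the actual config class):

```python
import warnings
from dataclasses import dataclass


@dataclass
class InferenceServerConfig:
    # Assumed renamed fields following the server_* convention mentioned above
    server_mem_fraction_static: float = 0.85
    server_context_length: int = 4096

    @property
    def sglang_mem_fraction_static(self) -> float:
        # Deprecated alias kept only so old call sites keep working during migration
        warnings.warn(
            "sglang_mem_fraction_static is deprecated; use server_mem_fraction_static",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.server_mem_fraction_static

    @property
    def sglang_context_length(self) -> int:
        warnings.warn(
            "sglang_context_length is deprecated; use server_context_length",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.server_context_length
```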
grail/trainer/checkpoint_publisher.py (3)

269-299: Consider adding a total upload timeout in addition to per-file timeouts.

The current implementation has per-file timeouts via upload_timeout, but if there are many files, the total upload time could be unbounded. Consider adding a total-operation timeout to prevent indefinite hangs.
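A minimal sketch of such a total-operation timeout, assuming asyncio-based uploads; the function names and the timeout value are placeholders, not the publisher's actual code:

```python
import asyncio
from collections.abc import Awaitable, Callable, Iterable

TOTAL_UPLOAD_TIMEOUT_SECONDS = 1800.0  # assumed overall budget, not a value from this PR


async def upload_all_with_deadline(
    files: Iterable[str],
    upload_file: Callable[[str], Awaitable[None]],
) -> None:
    """Bound the whole batch of uploads, not just each individual file."""
    try:
        await asyncio.wait_for(
            asyncio.gather(*(upload_file(f) for f in files)),
            timeout=TOTAL_UPLOAD_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        # Per-file timeouts alone cannot bound total duration when there are many files.
        raise
```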
386-528: upload_from_staging duplicates upload logic from publish_checkpoint.

The upload_file closure and upload logic in upload_from_staging (lines 467-496) is nearly identical to the one in publish_checkpoint (lines 269-299). To reduce duplication and maintenance burden, consider extracting the common upload logic into a shared helper:
async def _upload_files_from_directory( self, source_dir: Path, remote_prefix: str, exclude_files: set[str] | None = None, ) -> tuple[bool, float, float]: """Upload all files from a directory with retry and timing. Returns: Tuple of (success, total_mb, duration_seconds) """ semaphore = asyncio.Semaphore(4) # ... shared upload logic ...
431-433: Fallback for parent_window could be off-by-one if windows aren't contiguous.

The fallback parent_window = max(0, target_window - WINDOW_LENGTH) assumes sequential windows. If windows can be skipped (e.g., due to downtime), this could produce incorrect lineage. Consider documenting this assumption.

grail/trainer/training_process.py (5)
23-25: Module-level side effect may cause issues in multi-process environments.

Setting PYTORCH_CUDA_ALLOC_CONF at import time could affect other processes that import this module but don't want this configuration. Consider moving this to the process entry point (run_training_process).

```diff
-if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
-    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
```

Move to run_training_process:

```python
def run_training_process(...):
    if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    # ... rest of function
```
210-210: WAIT_WINDOW_LENGTH = 0.0 makes the wait loop a no-op.

With WAIT_WINDOW_LENGTH = 0.0, the condition current_block >= target_block is immediately true since target_block = current_block + 0.0. The TODO comment suggests this is intentional for now, but the loop structure is misleading.

Consider simplifying if no wait is needed:

```python
# TODO: Add waiting when using extra process/GPUs
# Currently no wait - miners should already have the checkpoint
logger.info("Skipping initial checkpoint wait (WAIT_WINDOW_LENGTH=0)")
```
306-408: Monitoring initialization is overly complex with duplicate logging.

The _initialize_monitoring method has multiple layers of try/catch with verbose logging. While thorough, this makes the code harder to follow. Consider simplifying the success path and consolidating error handling.
317-320: Setting environment variables during initialization has side effects.

Lines 319-320 modify os.environ, which affects the entire process and any libraries that have already read these values. Consider setting these before the process starts.
770-784: Fallback window calculation may hide subtensor issues.

When get_current_block fails (line 782), the fallback self.epoch_counter * WINDOW_LENGTH provides a pseudo-window that's decoupled from actual chain state. This could cause training to proceed with incorrect window context. Consider logging at WARNING level instead of just returning the fallback.

```diff
 except Exception as exc:
-    logger.warning("Failed to get current block: %s", exc)
+    logger.warning(
+        "Failed to get current block (using fallback epoch-based window): %s", exc
+    )
     return self.epoch_counter * WINDOW_LENGTH
```

grail/neurons/trainer.py (4)
116-118: Potential issue: multiprocessing.set_start_method called in the constructor.

Calling set_start_method("spawn", force=True) in the constructor can cause issues if:

- Other parts of the codebase have already set a different method
- Multiple
TrainerNeuroninstances are created (though unlikely)Consider moving this to module-level or the CLI entrypoint where it's called once at startup.
- def __init__(self, context: TrainerContext) -> None: - """Initialize trainer neuron. - - Args: - context: Resources required for training - """ - super().__init__() - self._context = context - - # Set multiprocessing start method to 'spawn' for CUDA compatibility - multiprocessing.set_start_method("spawn", force=True) - logger.info("Set multiprocessing start method to 'spawn' for CUDA compatibility") + def __init__(self, context: TrainerContext) -> None: + """Initialize trainer neuron. + + Args: + context: Resources required for training + """ + super().__init__() + self._context = context + + # Ensure 'spawn' method for CUDA compatibility (safe to call multiple times with force=True) + try: + multiprocessing.set_start_method("spawn", force=True) + logger.info("Set multiprocessing start method to 'spawn' for CUDA compatibility") + except RuntimeError: + # Already set - this is fine + logger.debug("Multiprocessing start method already set")
419-420: Minor: Use asyncio.get_running_loop() instead of the deprecated get_event_loop().

asyncio.get_event_loop() is deprecated in Python 3.10+ when called from a coroutine. Use asyncio.get_running_loop() instead.

```diff
- # Use Event.wait() with timeout in thread pool to avoid blocking event loop
- loop = asyncio.get_event_loop()
+ # Use Event.wait() with timeout in thread pool to avoid blocking event loop
+ loop = asyncio.get_running_loop()
```
831-835: Fragile: Backend type inference from class name.

Inferring
backend_typeby checking if "WandB" or "Null" is in the class name is fragile. Consider adding abackend_typeproperty to the backend interface.- # Add backend_type from backend class name (needed by subprocess) - backend_class_name = self._context.monitor.backend.__class__.__name__ - if "WandB" in backend_class_name: - monitor_config["backend_type"] = "wandb" - elif "Null" in backend_class_name: - monitor_config["backend_type"] = "null" + # Add backend_type from backend (needed by subprocess) + if hasattr(self._context.monitor.backend, "backend_type"): + monitor_config["backend_type"] = self._context.monitor.backend.backend_type + else: + # Fallback: infer from class name + backend_class_name = self._context.monitor.backend.__class__.__name__ + if "WandB" in backend_class_name: + monitor_config["backend_type"] = "wandb" + elif "Null" in backend_class_name: + monitor_config["backend_type"] = "null"
842-844: Minor: Accessing the private _config attribute.

The fallback accesses self._context.monitor._config, a private attribute. Consider exposing a public method on the monitor manager for this use case.

grail/monitoring/backends/wandb_backend.py (3)
95-97: Use asyncio.get_running_loop() instead of the deprecated get_event_loop().

asyncio.get_event_loop() is deprecated when called from a coroutine. Since _run_executor is called from async methods, use asyncio.get_running_loop().

```diff
 def _run_executor(self, func: Any, *args: Any) -> Any:
     """Run synchronous function in executor."""
-    return asyncio.get_event_loop().run_in_executor(None, func, *args)
+    return asyncio.get_running_loop().run_in_executor(None, func, *args)
```
323-325: Consider: 256 define_metric calls at initialization.

Defining metrics for all 256 possible UIDs upfront may add initialization overhead. An alternative is lazy definition on first log via _maybe_define_step_for_name. However, since this runs once per run, the current approach is acceptable if init time is not critical.
476-485: Edge case: String coercion doesn't handle negative numbers or scientific notation.
isdigit()returnsFalsefor "-5" or "1e10". The fallback tofloat()handles some cases, but negative integers will become floats.def _coerce_tag_value(self, raw: Any) -> Any: """Coerce tag value from string to appropriate type.""" if isinstance(raw, str): - if raw.isdigit(): - return int(raw) try: - return float(raw) + # Try int first (handles negative), then float + float_val = float(raw) + if float_val.is_integer(): + return int(float_val) + return float_val except Exception: return raw return raw
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (39)
- .env.example (1 hunks)
- grail/cli/mine.py (6 hunks)
- grail/cli/train.py (4 hunks)
- grail/environments/loop.py (3 hunks)
- grail/environments/math_hendrycks_env.py (3 hunks)
- grail/infrastructure/checkpoint_consumer.py (7 hunks)
- grail/infrastructure/network.py (4 hunks)
- grail/logging_utils.py (3 hunks)
- grail/model/provider.py (1 hunks)
- grail/model/train_loading.py (2 hunks)
- grail/monitoring/backends/wandb_backend.py (13 hunks)
- grail/monitoring/config.py (4 hunks)
- grail/neurons/miner.py (4 hunks)
- grail/neurons/trainer.py (2 hunks)
- grail/neurons/validator.py (1 hunks)
- grail/shared/chat_templates.py (3 hunks)
- grail/shared/constants.py (4 hunks)
- grail/trainer/__init__.py (1 hunks)
- grail/trainer/algorithms/grpo.py (12 hunks)
- grail/trainer/checkpoint_publisher.py (1 hunks)
- grail/trainer/checkpointing.py (0 hunks)
- grail/trainer/config.py (3 hunks)
- grail/trainer/evaluator.py (1 hunks)
- grail/trainer/inference_server.py (1 hunks)
- grail/trainer/ipc.py (1 hunks)
- grail/trainer/replay_buffer.py (1 hunks)
- grail/trainer/service.py (4 hunks)
- grail/trainer/snapshot_manager.py (1 hunks)
- grail/trainer/training_process.py (1 hunks)
- grail/trainer/upload_worker.py (1 hunks)
- grail/validation/miner_validator.py (4 hunks)
- grail/validation/service.py (4 hunks)
- research/offline_trainer/tests/test_gpu_integration.py (2 hunks)
- scripts/test_monitoring_shared.py (1 hunks)
- scripts/test_wandb_shared.py (1 hunks)
- tests/integration/test_async_trainer.py (1 hunks)
- tests/integration/test_train_cli_smoke.py (1 hunks)
- tests/integration/test_validation_service.py (1 hunks)
- tests/unit/trainer/test_replay_buffer.py (1 hunks)
💤 Files with no reviewable changes (1)
- grail/trainer/checkpointing.py
🧰 Additional context used
🧬 Code graph analysis (17)
grail/model/provider.py (1)
tests/integration/protocol/test_proof_cross_framework.py (1)
model(39-43)
tests/unit/trainer/test_replay_buffer.py (2)
grail/trainer/algorithms/grpo.py (1)
GRPOGroup(243-256)grail/trainer/replay_buffer.py (2)
RecencyWeightedBuffer(66-250)create_replay_buffer(253-280)
grail/trainer/ipc.py (2)
grail/infrastructure/chain.py (1)
stop(415-431)grail/neurons/base.py (1)
heartbeat(123-125)
tests/integration/test_validation_service.py (1)
grail/infrastructure/checkpoint_consumer.py (1)
CheckpointManager(87-562)
grail/validation/miner_validator.py (1)
tests/integration/protocol/test_proof_cross_framework.py (1)
model(39-43)
scripts/test_wandb_shared.py (1)
scripts/test_monitoring_shared.py (1)
main(299-331)
grail/infrastructure/network.py (1)
tests/integration/infrastructure/test_block_timestamp.py (1)
subtensor(19-28)
grail/cli/train.py (3)
grail/trainer/checkpoint_publisher.py (1)
CheckpointPublisher(87-528)grail/neurons/trainer.py (1)
TrainerContext(67-88)grail/model/train_loading.py (3)
ModelLoadSpec(30-41)parse_train_env(56-81)parse_ref_env(84-109)
grail/model/train_loading.py (1)
grail/model/provider.py (1)
get_model(55-150)
tests/integration/test_async_trainer.py (2)
grail/trainer/snapshot_manager.py (10)
save_snapshot_atomic(61-129)check_snapshot_ready(131-137)copy_snapshot_to_staging(139-169)cleanup_staging(171-176)check_pause_flag(178-184)set_pause_flag(186-189)clear_pause_flag(191-195)get_training_heartbeat_age(201-214)set_training_heartbeat(197-199)get_latest_snapshot_path(216-223)tests/integration/test_train_cli_smoke.py (1)
save_pretrained(60-63)
grail/validation/service.py (1)
grail/infrastructure/checkpoint_consumer.py (1)
get_latest_ready_checkpoint(498-562)
grail/monitoring/backends/wandb_backend.py (2)
grail/monitoring/base.py (3)
MetricData(29-43)log_metric(73-83)MetricType(19-25)grail/monitoring/backends/null_backend.py (1)
log_metric(34-40)
grail/trainer/config.py (1)
grail/trainer/inference_server.py (1)
start_server(70-72)
grail/trainer/replay_buffer.py (1)
grail/trainer/algorithms/grpo.py (1)
GRPOGroup(243-256)
grail/neurons/miner.py (1)
grail/infrastructure/checkpoint_consumer.py (2)
default_checkpoint_cache_root(570-574)get_latest_ready_checkpoint(498-562)
grail/trainer/snapshot_manager.py (2)
tests/integration/protocol/test_proof_cross_framework.py (2)
model(39-43)tokenizer(47-52)tests/integration/test_train_cli_smoke.py (1)
save_pretrained(60-63)
tests/integration/test_train_cli_smoke.py (4)
grail/cli/train.py (1)
train(40-124)grail/trainer/snapshot_manager.py (7)
save_snapshot_atomic(61-129)set_training_heartbeat(197-199)check_snapshot_ready(131-137)copy_snapshot_to_staging(139-169)cleanup_staging(171-176)get_training_heartbeat_age(201-214)get_latest_snapshot_path(216-223)grail/infrastructure/chain.py (1)
stop(415-431)tests/integration/infrastructure/test_block_timestamp.py (1)
chain_manager(32-65)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Run Tests (3.11)
- GitHub Check: Run Tests (3.10)
- GitHub Check: Cursor Bugbot
- GitHub Check: Run Tests (3.11)
- GitHub Check: Run Tests (3.10)
git_commit: str = "unknown"
created_at: float = 0.0
model_name: str = "no_name"
parent_window: int | None = None
parent_window added but not populated during metadata fetch.
The parent_window field is added to CheckpointMetadata, but in _fetch_metadata() (lines 382-391), it's not read from the payload, leaving it as None even when present in the remote metadata. This inconsistency could cause issues when consumers need to track checkpoint lineage.
metadata = CheckpointMetadata(
window=payload.get("window", window),
file_manifest=payload.get("file_manifest", {}),
training_config=payload.get("training_config", {}),
git_commit=payload.get("git_commit", "unknown"),
created_at=float(payload.get("created_at", 0.0)),
model_name=payload.get("model_name", "no_name"),
+ parent_window=payload.get("parent_window"),
)

📝 Committable suggestion
parent_window: int | None = None

metadata = CheckpointMetadata(
    window=payload.get("window", window),
    file_manifest=payload.get("file_manifest", {}),
    training_config=payload.get("training_config", {}),
    git_commit=payload.get("git_commit", "unknown"),
    created_at=float(payload.get("created_at", 0.0)),
    model_name=payload.get("model_name", "no_name"),
    parent_window=payload.get("parent_window"),
)
🤖 Prompt for AI Agents
In grail/infrastructure/checkpoint_consumer.py around line 72 (where
parent_window: int | None = None is declared) and in the metadata fetch logic at
lines ~382-391, populate the new parent_window field from the fetched payload:
read payload.get("parent_window") (or the equivalent key used by the remote
metadata), validate/convert it to int if present, and pass that value into the
CheckpointMetadata constructor so it is not left as None when present remotely;
also handle missing/null values by leaving parent_window as None and add a brief
defensive check to avoid ValueError on conversion.
# Track last successful call for idle detection
object.__setattr__(self, "_last_call_timestamp", time.time())
Clarify idle tracking semantics and align comments/logs with behavior
The idle‑tracking itself (_last_call_timestamp on success, checked in _call_with_retry) looks reasonable, but the surrounding comments and log message are now inconsistent with the implementation:
- The comment at Line 230 talks about “Double timeout if connection has been idle for 20+ seconds”, but:
  - The code restarts the subtensor when idle_duration > 60.0.
  - The code does not modify timeout at all in this block (only metagraph calls double timeout earlier).
- The log message still says “restarting subtensor and doubling timeout for %s”, which is misleading since the timeout isn’t changed here.
I’d recommend either (a) actually doubling the timeout here, or (b) updating the comment and log text to match the current “restart after 60s idle” behavior. If the intent is just to restart after 60s and keep timeouts unchanged, something like this would keep the implementation and docs in sync:
- # Double timeout if connection has been idle for 20+ seconds
- # Research-based threshold:
- # - Bittensor WebSocket auto-closes after 10s inactivity
- # - Substrate layer closes after ~60s inactivity
- # - 20s catches stale connections early without false positives
- # - Critical for upload worker (40-300s idle during R2 uploads)
+ # Restart connection if it has been idle for an extended period.
+ # Research-based threshold:
+ # - Bittensor WebSocket auto-closes after ~10s inactivity
+ # - Substrate layer closes after ~60s inactivity
+ # - We restart after 60s of no successful calls to avoid stale connections
+ # while avoiding excessive churn during normal gaps (e.g. long R2 uploads).
@@
- if idle_duration > 60.0:
- logger.warning(
- "⏰ Connection idle for %.1fs, restarting subtensor and doubling timeout for %s",
- idle_duration,
- method_name,
- )
+ if idle_duration > 60.0:
+ logger.warning(
+ "⏰ Connection idle for %.1fs, restarting subtensor before calling %s",
+ idle_duration,
+ method_name,
+ )If you do intend to also relax timeouts for long‑idle connections, you could additionally add timeout *= 2 inside this block (or some capped multiplier) and adjust the wording accordingly.
Also applies to: 168-169, 230-245
🤖 Prompt for AI Agents
In grail/infrastructure/network.py around lines 62-63 (and also review lines
168-169 and 230-245), the comments and log messages claim the code "doubles
timeout" for long-idle connections but the implementation only restarts the
subtensor when idle_duration > 60.0 and does not change timeout; update the code
to make behavior and docs consistent by either (A) actually doubling/relaxing
the timeout inside the idle branch (e.g., timeout *= 2 with an optional cap) and
update comments/logs to mention both restart and timeout adjustment, or (B)
leave timeouts unchanged and revise the comment and log messages to remove any
mention of doubling timeout and instead state that the subtensor is being
restarted after 60s of idle; apply the chosen fix consistently to all affected
comment/log locations (lines mentioned above).
def _serialize_wallet(self) -> dict[str, str]:
    """Serialize wallet for pickling to child process.

    Returns:
        Dictionary of wallet arguments
    """
    logger.info(
        "Reloading training models: train_model=%s ref_model=%s",
        self._context.train_model is not None,
        self._context.ref_model is not None,
    )
    return {
        "name": self._context.wallet.name,
        "hotkey": self._context.wallet.hotkey_str,
        "path": self._context.wallet.path,
    }
🧩 Analysis chain
🏁 Script executed:
# First, let's find wallet-related imports and class definitions
rg -n "class.*[Ww]allet" --type py | head -20Repository: one-covenant/grail
Length of output: 187
🏁 Script executed:
# Search for hotkey_str usage across the codebase
rg -n "hotkey_str" --type py -B 2 -A 2Repository: one-covenant/grail
Length of output: 815
🏁 Script executed:
# Look for _FakeWallet or test wallet implementations
rg -n "_FakeWallet|FakeWallet" --type py -B 2 -A 5Repository: one-covenant/grail
Length of output: 1430
🏁 Script executed:
# Search for bittensor wallet imports to understand the wallet object
rg -n "from bittensor|import.*wallet|bt\.wallet" --type py | head -20Repository: one-covenant/grail
Length of output: 1629
🏁 Script executed:
# Search for wallet attribute access patterns in production code
rg -n "wallet\.(hotkey|name|path)" --type py -B 2 -A 2 | grep -v test | head -40Repository: one-covenant/grail
Length of output: 2696
🏁 Script executed:
# Look at the _reconstruct_wallet function more closely
cat -n grail/trainer/training_process.py | sed -n '1045,1060p'Repository: one-covenant/grail
Length of output: 545
🏁 Script executed:
# Check where wallet_args dict is created and passed to _reconstruct_wallet
rg -n "_reconstruct_wallet|_serialize_wallet" --type py -B 3 -A 3Repository: one-covenant/grail
Length of output: 2478
🏁 Script executed:
# Search for bittensor wallet source or any documentation about its attributes
rg -n "\.hotkey_str|\.hotkey\b" --type py | grep -v test | head -20Repository: one-covenant/grail
Length of output: 2014
🏁 Script executed:
# Search for wallet property definitions or any string representations
rg -n "hotkey" grail/trainer/training_process.py -B 2 -A 2Repository: one-covenant/grail
Length of output: 3126
🏁 Script executed:
# Look at actual wallet initialization to understand what parameters bt.wallet accepts
rg -n "bt\.wallet\(" --type py -B 1 -A 1 | head -40Repository: one-covenant/grail
Length of output: 1374
🏁 Script executed:
# Check if there's a hotkey_str property anywhere in the codebase
rg -n "hotkey_str" --type pyRepository: one-covenant/grail
Length of output: 223
🏁 Script executed:
# Look at the upload_worker that also uses wallet reconstruction
cat -n grail/trainer/upload_worker.py | sed -n '290,310p'Repository: one-covenant/grail
Length of output: 954
🏁 Script executed:
# Check if there's any documentation or comments about wallet hotkey parameter
rg -n "hotkey.*=.*get_conf|BT_WALLET_HOT" --type py -B 2 -A 5Repository: one-covenant/grail
Length of output: 4223
🏁 Script executed:
# Look at test to understand what _FakeWallet expects as hotkey parameter
cat -n tests/integration/test_train_cli_smoke.py | sed -n '29,50p'Repository: one-covenant/grail
Length of output: 886
🏁 Script executed:
# Search for any property or method that gives hotkey string representation
rg -n "def.*hotkey|@property.*hotkey" --type pyRepository: one-covenant/grail
Length of output: 1628
🏁 Script executed:
# Search for any wallet attributes that might store the original hotkey name/string
rg -n "wallet\." grail/neurons/trainer.py | head -20Repository: one-covenant/grail
Length of output: 222
🏁 Script executed:
# Check if wallet has any other attributes we should use
rg -n "self\._context\.wallet\." grail/neurons/trainer.py -A 1Repository: one-covenant/grail
Length of output: 236
🏁 Script executed:
# Look for wallet name attribute to understand the pattern
rg -n "wallet\.name" --type py | grep -v test | head -10Repository: one-covenant/grail
Length of output: 804
🏁 Script executed:
# Search for any comments or documentation about wallet structure
rg -n "wallet.*hotkey|hotkey.*wallet" grail/neurons/trainer.py -B 3 -A 3Repository: one-covenant/grail
Length of output: 278
🏁 Script executed:
# Check the context object to understand where wallet comes from
rg -n "_context.*wallet|wallet.*context" grail/neurons/trainer.py -B 2 -A 2Repository: one-covenant/grail
Length of output: 464
🏁 Script executed:
# Look at the _FakeWallet definition once more to see exactly what it mimics
cat -n tests/integration/test_train_cli_smoke.py | sed -n '29,40p'Repository: one-covenant/grail
Length of output: 508
🏁 Script executed:
# Search for any real wallet initialization or attribute documentation
rg -n "class.*Context|context.*wallet" --type py grail/neurons/trainer.py | head -10Repository: one-covenant/grail
Length of output: 289
🏁 Script executed:
# Find and examine TrainerContext class definition
rg -n "class TrainerContext" --type py -A 20Repository: one-covenant/grail
Length of output: 1426
Critical bug: wallet.hotkey_str attribute does not exist on bittensor wallet objects.
Line 810 accesses self._context.wallet.hotkey_str, which is not a valid attribute on real bt.wallet instances. The test suite only passes because _FakeWallet (in tests) manually defines this attribute, but production code will crash with AttributeError at runtime.
Real bittensor wallets have wallet.hotkey (a Keypair object), wallet.name, and wallet.path, but not hotkey_str. To fix the serialization, determine what value should be passed as the hotkey parameter when reconstructing the wallet with bt.wallet(**wallet_args). The current implementation cannot work in production.
🤖 Prompt for AI Agents
In grail/neurons/trainer.py around lines 802–812, the serializer reads a
non-existent attribute wallet.hotkey_str; change the serialized key to "hotkey"
and produce a string representation that bt.wallet(...) accepts. Replace the
hotkey value with a string derived from the Keypair, e.g. use
wallet.hotkey.ss58_address if present, otherwise fallback to str(wallet.hotkey)
or wallet.hotkey.public_key.hex(); ensure the returned dict keys are "name",
"hotkey", and "path" so bt.wallet(**wallet_args) can reconstruct the wallet
without AttributeError.
grail/shared/chat_templates.py
Outdated
def build_qwen_chat_template(system_prompt: str, reasoning_start: str = "") -> str:
def build_qwen_chat_template(system_prompt: str) -> str:
🧩 Analysis chain
🏁 Script executed:
rg -n --type=py 'build_qwen_chat_template\s*\(' -A 2 -B 2Repository: one-covenant/grail
Length of output: 4202
Update two remaining call sites that still pass the reasoning_start parameter.
The function signature change removes the reasoning_start parameter, but two files still pass this argument and will fail:
- research/trl/train_trl_gsm8k.py:90-91 — Remove reasoning_start=REASONING_START
- research/offline_trainer/src/grail_offline/pipelines/offline_grpo.py:312 — Change to pass only SYSTEM_PROMPT
Seven other call sites have been correctly updated and only pass the system_prompt argument.
🤖 Prompt for AI Agents
In grail/shared/chat_templates.py around line 9, the function
build_qwen_chat_template has removed the reasoning_start parameter; update the
two remaining call sites: in research/trl/train_trl_gsm8k.py at lines ~90-91
remove the named argument reasoning_start=REASONING_START so the call passes
only SYSTEM_PROMPT (or the appropriate system prompt variable), and in
research/offline_trainer/src/grail_offline/pipelines/offline_grpo.py at line
~312 change the call to pass only SYSTEM_PROMPT (drop the reasoning_start
argument) so both calls match the new single-argument signature.
micro_batch_size = max(1, int(config.batch_size))
grad_accum_steps = max(1, self.config.grad_accum_steps)
Inconsistent config source: config vs self.config.
Line 1650 uses the config parameter while line 1651 uses self.config. If the passed config should take precedence (as it's explicitly passed to train_epoch), consider using config.grad_accum_steps for consistency:
micro_batch_size = max(1, int(config.batch_size))
-grad_accum_steps = max(1, self.config.grad_accum_steps)
+grad_accum_steps = max(1, int(config.grad_accum_steps))

Alternatively, if self.config is intentionally used for certain parameters, add a comment explaining the distinction.
📝 Committable suggestion
micro_batch_size = max(1, int(config.batch_size))
grad_accum_steps = max(1, self.config.grad_accum_steps)
micro_batch_size = max(1, int(config.batch_size))
grad_accum_steps = max(1, int(config.grad_accum_steps))
🤖 Prompt for AI Agents
In grail/trainer/algorithms/grpo.py around lines 1650-1651, the code mixes the
local parameter config with self.config (micro_batch_size uses config while
grad_accum_steps uses self.config); make the source consistent by using
config.grad_accum_steps instead of self.config.grad_accum_steps (or, if
self.config is intentionally authoritative, add a clear comment explaining why
and use self.config for both), and ensure a safe fallback (e.g., max(1,
int(config.grad_accum_steps))) so the value is validated like micro_batch_size.
# Atomic rename: remove old target if exists, then rename
if target_dir.exists():
    # Remove old snapshot
    shutil.rmtree(target_dir)

# Rename temp to target (atomic on POSIX)
temp_dir.rename(target_dir)
Race condition: rmtree + rename is not atomic.
If a crash occurs between shutil.rmtree(target_dir) at line 111 and temp_dir.rename(target_dir) at line 114, the snapshot is lost. Consider a safer pattern:
- Rename latest → latest.old
- Rename temp → latest
- Delete latest.old
This ensures at least one valid snapshot always exists.
# Atomic rename: remove old target if exists, then rename
if target_dir.exists():
- # Remove old snapshot
- shutil.rmtree(target_dir)
-
- # Rename temp to target (atomic on POSIX)
- temp_dir.rename(target_dir)
+ # Rename old snapshot to backup first
+ backup_dir = self.snapshot_dir / "latest.old"
+ if backup_dir.exists():
+ shutil.rmtree(backup_dir)
+ target_dir.rename(backup_dir)
+
+ # Rename temp to target (atomic on POSIX)
+ temp_dir.rename(target_dir)
+
+ # Clean up backup after successful rename
+ backup_dir = self.snapshot_dir / "latest.old"
+ if backup_dir.exists():
+ try:
+ shutil.rmtree(backup_dir)
+ except Exception:
+ pass  # Best effort cleanup

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In grail/trainer/snapshot_manager.py around lines 108-114, the current sequence
removes target_dir with shutil.rmtree() then renames temp_dir → target_dir which
can lose the snapshot if the process crashes between those steps; instead
implement the safe three-step swap: rename the existing target_dir to a backup
name (e.g., target_dir + ".old" or a unique suffix) using an atomic
rename/replace, then rename temp_dir to target_dir, and finally remove the
backup; ensure you handle pre-existing backup (remove or choose unique name),
catch and propagate errors so failures during the swap leave either the old or
the new snapshot intact, and clean up temp_dir or backup appropriately in error
paths.
# Set READY-{ready_window} marker
try:
    finalized = await checkpoint_publisher.finalize_checkpoint_ready(
        checkpoint_window,
        ready_window,
    )
    if finalized:
        logger.info(
            "✅ Set READY-%s marker for checkpoint-%s",
            ready_window,
            checkpoint_window,
        )
except Exception as exc:
    logger.error("Failed to finalize checkpoint READY marker: %s", exc)

# Update last uploaded window to prevent duplicates
last_uploaded_window = checkpoint_window

# Cleanup staging directory
snapshot_manager.cleanup_staging()
logger.info("Upload cycle complete for checkpoint-%s", checkpoint_window)
READY marker failure doesn't prevent last_uploaded_window update.
If finalize_checkpoint_ready fails (line 189), the upload is still considered complete and last_uploaded_window is updated (line 192). This could leave orphaned checkpoints without READY markers that consumers won't discover.
Consider only updating last_uploaded_window when the READY marker is successfully set:
# Set READY-{ready_window} marker
try:
finalized = await checkpoint_publisher.finalize_checkpoint_ready(
checkpoint_window,
ready_window,
)
if finalized:
logger.info(
"✅ Set READY-%s marker for checkpoint-%s",
ready_window,
checkpoint_window,
)
+ # Only mark as uploaded if READY marker was set successfully
+ last_uploaded_window = checkpoint_window
+ else:
+ logger.warning("READY marker not finalized, will retry upload next cycle")
except Exception as exc:
logger.error("Failed to finalize checkpoint READY marker: %s", exc)
+ # Don't update last_uploaded_window so retry is possible
- # Update last uploaded window to prevent duplicates
- last_uploaded_window = checkpoint_window🤖 Prompt for AI Agents
In grail/trainer/upload_worker.py around lines 176 to 196, the code updates
last_uploaded_window and treats the upload as complete even if
finalize_checkpoint_ready() fails; change the flow so last_uploaded_window is
only set when the READY marker was successfully created (i.e., finalized is
True). Specifically, move the assignment of last_uploaded_window into the try
block under the condition that finalized is True, and ensure that on exception
you do not update last_uploaded_window (but still perform necessary
cleanup/logging as appropriate) so orphaned checkpoints won't be considered
uploaded.
# Step 1: Start primary process
logger.info("\n[STEP 1] Starting primary process...")
run_id = await primary_process_workflow(entity=entity)

logger.info("\n[STEP 2] Spawning worker process to connect to run...")

# Step 2: Spawn worker process
process = multiprocessing.Process(
    target=worker_process_workflow,
    args=(run_id, entity, "training_worker"),
)
process.start()
process.join(timeout=30)

if process.is_alive():
    logger.error("❌ Worker process timeout!")
    process.terminate()
    process.join()
elif process.exitcode != 0:
    logger.error(f"❌ Worker process failed with exit code {process.exitcode}")
else:
    logger.info("✅ Worker process completed successfully")
The test sequence has a timing issue: primary finishes before worker starts.
In test_shared_mode_new, the primary process workflow at line 238 runs to completion (including finish_run and shutdown on lines 107-108) before the worker is spawned at line 243. This means the worker attempts to connect to a run that has already been finished.
For a valid shared-mode test, the primary should remain active while the worker connects.
Consider restructuring to keep primary alive:
async def test_shared_mode_new(entity: str | None = None) -> None:
...
# Step 1: Start primary process
logger.info("\n[STEP 1] Starting primary process...")
- run_id = await primary_process_workflow(entity=entity)
-
- logger.info("\n[STEP 2] Spawning worker process to connect to run...")
-
- # Step 2: Spawn worker process
- process = multiprocessing.Process(
+ # Start primary in background, get run_id, then spawn worker while primary is alive
+ # This requires restructuring primary_process_workflow to not immediately shutdown
Committable suggestion skipped: line range outside the PR's diff.
# Stage 5: Rank by combined efficiency score and select top-k groups
max_groups: int = int(getattr(config, "grpo_max_groups", 8))
max_groups: int = int(getattr(config, "grpo_max_groups_per_window", 10000))
Bug: Config attribute name mismatch disables group filtering
The code uses getattr(config, "grpo_max_groups_per_window", 10000) but TrainingConfig defines the attribute as grpo_max_groups (not grpo_max_groups_per_window). The attribute name mismatch causes getattr to always return the fallback value of 10000, effectively disabling the configured GRPO_MAX_GROUPS limit (default 32) and allowing far more groups than intended in training.
Additional Locations (1)
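A minimal sketch of the straightforward fix, assuming the field defined on TrainingConfig stays named grpo_max_groups; the dataclass and helper below are illustrative stand-ins, not the repository's actual definitions:

from dataclasses import dataclass


@dataclass
class TrainingConfig:
    # Illustrative field; the real TrainingConfig defines grpo_max_groups
    grpo_max_groups: int = 32


def resolve_max_groups(config: TrainingConfig) -> int:
    # Read the attribute that actually exists on the config so the configured
    # limit is honored; the fallback only applies when the field is truly absent.
    return int(getattr(config, "grpo_max_groups", 32))


config = TrainingConfig(grpo_max_groups=8)
assert resolve_max_groups(config) == 8  # configured limit respected, not 10000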
timers,
monitor,
self.use_drand,
checkpoint_window,
Bug: Miner passes None checkpoint_window to function expecting int
When get_latest_ready_checkpoint() returns None but a model is already loaded from a previous iteration, the code continues past the checkpoint loading block and calls generate_rollouts_for_window passing checkpoint_window which is None. The function signature expects checkpoint_window: int, causing a type mismatch. The code should likely use current_checkpoint_window (which stores the previously loaded checkpoint window) instead of the local checkpoint_window variable.
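A hedged sketch of the guard the comment implies, assuming the miner tracks the previously loaded window in a variable like current_checkpoint_window (names are illustrative, not the miner's actual locals):

def pick_window_for_rollouts(
    checkpoint_window: int | None,
    current_checkpoint_window: int | None,
) -> int:
    """Prefer the freshly discovered window; fall back to the previously loaded one."""
    if checkpoint_window is not None:
        return checkpoint_window
    if current_checkpoint_window is not None:
        return current_checkpoint_window
    raise RuntimeError("No READY checkpoint discovered and no model loaded yet")


# When discovery returns None but a model from window 120 is already loaded,
# rollouts are packaged against window 120 instead of passing None downstream.
assert pick_window_for_rollouts(None, 120) == 120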
…streamline IPC usage for training pause coordination
… pause coordination
Actionable comments posted: 0
♻️ Duplicate comments (3)
grail/trainer/training_process.py (1)
831-832: Critical: Remove hardcoded UID 80 filter from production code.

Lines 831-832 override the trusted miners list with a hardcoded filter to UID 80, which will break training for all other miners in production. The # TEST: comment indicates this is debug code that was not removed before commit.

Remove or conditionally disable this test filter:

- # TEST: filter to uid 80 hotkey and choose between all hotkeys
- trusted_hotkeys = [hk for hk, uid in uid_by_hotkey.items() if uid == 80]
+ # Use all trusted hotkeys for training (test filter removed)

If testing is needed, gate behind an environment variable:

if os.getenv("GRAIL_TEST_SINGLE_MINER"):
    test_uid = int(os.getenv("GRAIL_TEST_UID", "80"))
    trusted_hotkeys = [hk for hk, uid in uid_by_hotkey.items() if uid == test_uid]
    logger.warning("TEST MODE: Filtering to single UID %d", test_uid)

grail/neurons/trainer.py (1)
798-808: Critical: wallet.hotkey_str attribute does not exist on bittensor wallet objects.

Line 806 accesses self._context.wallet.hotkey_str, which is not a valid attribute on real bt.wallet instances. This will cause an AttributeError at runtime in production. The test suite only passes because _FakeWallet manually defines this attribute. Real bittensor wallets have wallet.hotkey (a Keypair object), wallet.name, and wallet.path, but not hotkey_str.

Replace with a valid attribute:

 def _serialize_wallet(self) -> dict[str, str]:
     """Serialize wallet for pickling to child process.

     Returns:
         Dictionary of wallet arguments
     """
     return {
         "name": self._context.wallet.name,
-        "hotkey": self._context.wallet.hotkey_str,
+        "hotkey": self._context.wallet.hotkey.ss58_address,
         "path": self._context.wallet.path,
     }

This extracts the SS58 address from the Keypair object, which bt.wallet(**wallet_args) accepts as the hotkey parameter.

grail/trainer/snapshot_manager.py (1)
106-112: Race condition: rmtree + rename is not atomic and risks snapshot loss.

If a crash occurs between shutil.rmtree(target_dir) at line 109 and temp_dir.rename(target_dir) at line 112, the snapshot is permanently lost. This violates the atomicity guarantee stated in the docstring.

Implement a safer three-step swap pattern:

 # Atomic rename: remove old target if exists, then rename
 if target_dir.exists():
-    # Remove old snapshot
-    shutil.rmtree(target_dir)
-
-    # Rename temp to target (atomic on POSIX)
-    temp_dir.rename(target_dir)
+    # Step 1: Rename old snapshot to backup (atomic)
+    backup_dir = self.snapshot_dir / "latest.old"
+    if backup_dir.exists():
+        shutil.rmtree(backup_dir)
+    target_dir.rename(backup_dir)
+
+# Step 2: Rename temp to target (atomic on POSIX)
+temp_dir.rename(target_dir)
+
+# Step 3: Clean up backup after successful rename
+backup_dir = self.snapshot_dir / "latest.old"
+if backup_dir.exists():
+    try:
+        shutil.rmtree(backup_dir)
+    except Exception:
+        pass  # Best effort cleanup

This ensures at least one valid snapshot (either old or new) always exists on disk, even if the process crashes mid-operation.
🧹 Nitpick comments (1)
grail/trainer/checkpoint_publisher.py (1)
328-354: Verify the compressed file detection logic aligns with upload_file_chunked behavior.

Lines 333-335 infer whether a file was compressed by checking if it's a small JSON file (<10MB). However, this duplicates the compression decision logic from upload_file_chunked in grail/infrastructure/comms.py (lines 429-435). If the compression threshold or logic changes in upload_file_chunked, this verification will break.

Consider one of these approaches:

- Return compression info from upload_file_chunked (preferred):

-async def upload_file_chunked(...) -> bool:
+async def upload_file_chunked(...) -> tuple[bool, bool]:
     # ... existing logic ...
-    return True
+    return True, was_compressed  # Return compression status

Then use the returned value here instead of re-inferring.

- Extract shared compression logic to a utility function:

def should_compress_file(key: str, size: int) -> bool:
    return key.endswith(".json") and size < 10 * 1024 * 1024

Use this in both places to ensure consistency.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- grail/infrastructure/comms.py (1 hunks)
- grail/neurons/trainer.py (2 hunks)
- grail/trainer/checkpoint_publisher.py (1 hunks)
- grail/trainer/snapshot_manager.py (1 hunks)
- grail/trainer/training_process.py (1 hunks)
- tests/integration/test_async_trainer.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/integration/test_async_trainer.py
🧰 Additional context used
🧬 Code graph analysis (2)
grail/trainer/checkpoint_publisher.py (2)
- grail/infrastructure/checkpoint_consumer.py (2): CheckpointMetadata (63-75), remote_prefix (74-75)
- grail/infrastructure/comms.py (4): delete_prefix (1056-1087), get_file_size (1025-1053), upload_file_chunked (402-541), list_bucket_files (984-1022)

grail/trainer/snapshot_manager.py (2)
- tests/integration/protocol/test_proof_cross_framework.py (2): model (39-43), tokenizer (47-52)
- tests/integration/test_train_cli_smoke.py (1): save_pretrained (60-63)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Cursor Bugbot
- GitHub Check: Run Tests (3.11)
- GitHub Check: Run Tests (3.10)
- GitHub Check: Run Tests (3.11)
- GitHub Check: Run Tests (3.10)
🔇 Additional comments (1)
grail/infrastructure/comms.py (1)
984-1022: LGTM! Pagination support correctly implemented.

The pagination logic properly handles S3/R2's continuation token mechanism for listing >1000 objects. The safety check at lines 1014-1015 prevents infinite loops if NextContinuationToken is unexpectedly missing. This enhancement is essential for the checkpoint cleanup operations that may need to list many checkpoint files.
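For reference, a generic sketch of the continuation-token loop this comment describes, using the standard boto3 list_objects_v2 API; the bucket and prefix arguments are placeholders, not the repository's R2 configuration:

import boto3


def list_all_keys(bucket: str, prefix: str) -> list[str]:
    """List every object under a prefix, following S3/R2 continuation tokens."""
    client = boto3.client("s3")
    keys: list[str] = []
    token: str | None = None
    while True:
        kwargs = {"Bucket": bucket, "Prefix": prefix, "MaxKeys": 1000}
        if token:
            kwargs["ContinuationToken"] = token
        resp = client.list_objects_v2(**kwargs)
        keys.extend(obj["Key"] for obj in resp.get("Contents", []))
        # Stop when the listing is complete or the token is unexpectedly missing,
        # mirroring the safety check that prevents infinite loops.
        if not resp.get("IsTruncated") or "NextContinuationToken" not in resp:
            break
        token = resp["NextContinuationToken"]
    return keys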
batch_size = config.batch_size
micro_batch_size = max(1, int(config.batch_size))
grad_accum_steps = max(1, self.config.grad_accum_steps)
Bug: train_epoch mixes config parameter with self.config
The train_epoch function accepts a config parameter but inconsistently uses it only for batch_size while using self.config for grad_accum_steps, entropy_coef, grad_clip, and max_length. If a caller passes a different config to train_epoch, only the batch size would reflect the passed config while other hyperparameters would use the instance's config, leading to unexpected training behavior.
Additional Locations (1)
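A small sketch of the consistency fix, assuming train_epoch keeps its config parameter; the field names mirror those mentioned in the comment, but the dataclass and defaults below are illustrative rather than the actual TrainingConfig:

from dataclasses import dataclass


@dataclass
class EpochConfig:
    batch_size: int = 8
    grad_accum_steps: int = 4
    entropy_coef: float = 0.0
    grad_clip: float = 1.0
    max_length: int = 1024


def resolve_epoch_hparams(config: EpochConfig) -> dict[str, float | int]:
    # Derive every hyperparameter from the config that was passed in,
    # so a caller-supplied config is honored consistently instead of
    # silently mixing it with the instance's self.config.
    return {
        "micro_batch_size": max(1, int(config.batch_size)),
        "grad_accum_steps": max(1, int(config.grad_accum_steps)),
        "entropy_coef": config.entropy_coef,
        "grad_clip": config.grad_clip,
        "max_length": config.max_length,
    }


override = EpochConfig(batch_size=2, grad_accum_steps=8)
assert resolve_epoch_hparams(override)["grad_accum_steps"] == 8  # not the instance default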
)
self._windows_since_last_eval += 1

await asyncio.sleep(ORCHESTRATION_SLEEP_SECONDS)
Bug: Evaluation interval counts loop iterations not training windows
The _windows_since_last_eval counter is now incremented per orchestration loop iteration (every 60 seconds) rather than per actual training window processed. This changes the semantics of window_interval from "evaluate every N training windows" to "evaluate every N minutes." Users configuring window_interval=20 expecting evaluation every 20 training windows will instead get evaluation approximately every 20 minutes regardless of training progress.
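A sketch of the intended semantics, under the assumption that the orchestrator can observe the current window on each iteration; the callables and constants below are illustrative, not the orchestrator's real interface:

import asyncio

ORCHESTRATION_SLEEP_SECONDS = 60
WINDOW_INTERVAL = 20


async def orchestration_loop(get_current_window, run_evaluation) -> None:
    """Increment the eval counter only when a new training window is observed."""
    last_seen_window: int | None = None
    windows_since_last_eval = 0
    while True:
        current_window = await get_current_window()
        if current_window != last_seen_window:
            # One tick per training window processed, not per 60-second loop iteration.
            windows_since_last_eval += 1
            last_seen_window = current_window
        if windows_since_last_eval >= WINDOW_INTERVAL:
            await run_evaluation(current_window)
            windows_since_last_eval = 0
        await asyncio.sleep(ORCHESTRATION_SLEEP_SECONDS)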
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- grail/cli/mine.py (8 hunks)
- grail/shared/chat_templates.py (1 hunks)
- grail/trainer/config.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- grail/cli/mine.py
🧰 Additional context used
🧬 Code graph analysis (1)
grail/trainer/config.py (1)
grail/trainer/inference_server.py (1)
start_server(70-72)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Cursor Bugbot
- GitHub Check: Run Tests (3.10)
- GitHub Check: Run Tests (3.11)
- GitHub Check: Run Tests (3.11)
- GitHub Check: Run Tests (3.10)
🔇 Additional comments (5)
grail/trainer/config.py (2)
53-54: LGTM. The gradient checkpointing configuration follows the established pattern of sourcing defaults from constants, and the TRAINER_USE_GRADIENT_CHECKPOINTING constant is properly defined in grail/shared/constants.py (lines 83-84).

88-88: This change is justified and aligns with standard evaluation practices.

The default split of "val" is the appropriate choice for evaluation, as validation sets are typically more controlled and representative than test sets. All EvalConfig instantiations throughout the codebase work correctly with this default, and the dataset providers (GSM8K, MATH) fully support the "val" split. The TODO comment accurately reflects a future improvement: making the split configurable per environment rather than a global default. No breaking changes result from this modification.

grail/shared/chat_templates.py (3)

22-26: LGTM!

Good use of constants for ChatML tokens (IM_START, IM_END) to centralize token definitions and improve maintainability.

29-51: LGTM on the signature change and default handling.

The optional system_prompt parameter with sensible defaulting to SYSTEM_PROMPT is a clean API design. The past review already flagged the external call sites that need updating.

65-86: LGTM!

The template construction is well-structured with clear inline comments explaining the critical behavior around add_generation_prompt. The conditional system message handling and message iteration are correctly implemented.
# Escape single quotes in system prompt for safe embedding in template
escaped_prompt = system_prompt.replace("'", "\\'")
Incomplete escaping may cause Jinja2 syntax errors.
Only single quotes are escaped, but backslashes are not. If system_prompt contains a backslash followed by a single quote (e.g., \'), the result \\' would be misinterpreted in the Jinja2 template as an escaped backslash followed by a string-terminating quote.
Apply this diff to escape backslashes before single quotes:
- # Escape single quotes in system prompt for safe embedding in template
- escaped_prompt = system_prompt.replace("'", "\\'")
+ # Escape backslashes and single quotes for safe embedding in Jinja2 template
+ escaped_prompt = system_prompt.replace("\\", "\\\\").replace("'", "\\'")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| # Escape single quotes in system prompt for safe embedding in template | |
| escaped_prompt = system_prompt.replace("'", "\\'") | |
| # Escape backslashes and single quotes for safe embedding in Jinja2 template | |
| escaped_prompt = system_prompt.replace("\\", "\\\\").replace("'", "\\'") |
🤖 Prompt for AI Agents
In grail/shared/chat_templates.py around lines 53-54, the code only escapes
single quotes which fails when the system_prompt contains backslashes before
quotes; first escape backslashes (replace "\" with "\\") then escape single
quotes (replace "'" with "\'") so backslashes are doubled before you escape
quotes, ensuring Jinja2 receives a correctly escaped string.
# Replay buffer configuration
replay_buffer_enabled: bool = True
replay_buffer_max_windows: int = 1  # Store last 1 windows (~6 min of data)
replay_buffer_recent_fraction: float = 0.5  # 50% samples from most recent window
replay_buffer_decay_factor: float = 0.7  # Exponential decay for older windows
replay_buffer_max_groups_per_epoch: int = 64  # Max groups to sample per epoch
🛠️ Refactor suggestion | 🟠 Major
Move replay buffer defaults to constants module for consistency.
These fields break the established pattern in TrainingConfig where all other defaults are sourced from grail.shared.constants. Hardcoding values here creates an inconsistent maintainability model.
Move these defaults to grail/shared/constants.py and reference them here:
# Replay buffer configuration
- replay_buffer_enabled: bool = True
- replay_buffer_max_windows: int = 1 # Store last 1 windows (~6 min of data)
- replay_buffer_recent_fraction: float = 0.5 # 50% samples from most recent window
- replay_buffer_decay_factor: float = 0.7 # Exponential decay for older windows
- replay_buffer_max_groups_per_epoch: int = 64 # Max groups to sample per epoch
+ replay_buffer_enabled: bool = constants.REPLAY_BUFFER_ENABLED
+ replay_buffer_max_windows: int = constants.REPLAY_BUFFER_MAX_WINDOWS
+ replay_buffer_recent_fraction: float = constants.REPLAY_BUFFER_RECENT_FRACTION
+ replay_buffer_decay_factor: float = constants.REPLAY_BUFFER_DECAY_FACTOR
+ replay_buffer_max_groups_per_epoch: int = constants.REPLAY_BUFFER_MAX_GROUPS_PER_EPOCHThen add to grail/shared/constants.py:
# Replay buffer configuration
REPLAY_BUFFER_ENABLED: bool = True
REPLAY_BUFFER_MAX_WINDOWS: int = 1 # Store last 1 windows (~6 min of data)
REPLAY_BUFFER_RECENT_FRACTION: float = 0.5 # 50% samples from most recent window
REPLAY_BUFFER_DECAY_FACTOR: float = 0.7 # Exponential decay for older windows
REPLAY_BUFFER_MAX_GROUPS_PER_EPOCH: int = 64  # Max groups to sample per epoch
🤖 Prompt for AI Agents
In grail/trainer/config.py around lines 71 to 76, the replay buffer default
values are hardcoded and must be moved to grail/shared/constants.py for
consistency; add the five constants (REPLAY_BUFFER_ENABLED,
REPLAY_BUFFER_MAX_WINDOWS, REPLAY_BUFFER_RECENT_FRACTION,
REPLAY_BUFFER_DECAY_FACTOR, REPLAY_BUFFER_MAX_GROUPS_PER_EPOCH) to
grail/shared/constants.py with the provided values, then update
grail/trainer/config.py to import these constants and use them as the default
values for the corresponding replay_buffer_* fields in TrainingConfig.
# sgLang server options (used when backend == "sglang")
sglang_host: str = "127.0.0.1"
sglang_port: int = 30000
sglang_start_server: bool = True  # Server runs in subprocess (avoids Gloo socket issues)
sglang_server_timeout_s: float = 120.0
sglang_trust_remote_code: bool = False
server_host: str = "127.0.0.1"
server_port: int = 30000
start_server: bool = True  # Server runs in subprocess (avoids Gloo socket issues)
server_timeout: float = 120.0
server_trust_remote_code: bool = False
Update stale comment to reflect generic server naming.
The comment on line 100 still references "sgLang server options" but the fields below have been renamed to generic server_* naming (not sglang-specific). This is misleading.
Apply this diff to update the comment:
- # sgLang server options (used when backend == "sglang")
+ # Server options (used when backend == "sglang" or other server backends)
server_host: str = "127.0.0.1"
server_port: int = 30000
start_server: bool = True # Server runs in subprocess (avoids Gloo socket issues)
server_timeout: float = 120.0
server_trust_remote_code: bool = False🤖 Prompt for AI Agents
In grail/trainer/config.py around lines 100 to 105, the comment "sgLang server
options (used when backend == \"sglang\")" is stale and misleading given the
fields are now generic server_* names; update the comment to a generic
description like "Server options (used when backend == \"server\" or generic
server backends)" or similar to reflect generic server naming, keeping brief and
accurate.
    idle_duration,
    method_name,
)
await self._restart_subtensor()
Bug: Comment and log message claim timeout doubling that doesn't happen
The comment at line 230 and log message at line 240 claim the timeout is doubled when the connection has been idle, but the code only restarts the subtensor - it never actually modifies the timeout variable. This could lead to unexpected behavior where connections time out faster than intended after idle periods, despite the documentation suggesting otherwise.
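One way to make the behaviour match the documentation, sketched under the assumption that the resilient wrapper holds a per-call timeout it can adjust; the class and attribute names below are illustrative, not the actual ResilientSubtensor internals:

import time


class IdleAwareTimeout:
    """Double the per-call timeout after long idle periods, as the log message claims."""

    def __init__(self, base_timeout: float = 10.0, idle_threshold: float = 300.0) -> None:
        self.base_timeout = base_timeout
        self.idle_threshold = idle_threshold
        self._last_call = time.monotonic()

    def next_timeout(self) -> float:
        idle_duration = time.monotonic() - self._last_call
        self._last_call = time.monotonic()
        if idle_duration > self.idle_threshold:
            # Actually double the timeout instead of only restarting the subtensor.
            return self.base_timeout * 2
        return self.base_timeout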
if self._should_run_evaluation():
    logger.info("Evaluation due, coordinating with training process...")
    await self._coordinate_evaluation(current_window)
    self._windows_since_last_eval = 0
Bug: Evaluation counter resets even when evaluation fails to start
In _orchestration_loop, _windows_since_last_eval is unconditionally reset to 0 after calling _coordinate_evaluation, even when _coordinate_evaluation returns early without running evaluation (e.g., if training process failed to pause within timeout). This causes evaluation to be incorrectly deferred for another full interval instead of being retried sooner.
Additional Locations (1)
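A hedged sketch of a retry-friendly flow, assuming _coordinate_evaluation can be made to report whether evaluation actually ran; the boolean return value is an assumption, not the current signature:

import logging

logger = logging.getLogger(__name__)


async def maybe_run_evaluation(orchestrator, current_window: int) -> None:
    """Reset the eval counter only when evaluation actually ran."""
    if not orchestrator._should_run_evaluation():
        return
    logger.info("Evaluation due, coordinating with training process...")
    evaluated = await orchestrator._coordinate_evaluation(current_window)
    if evaluated:
        orchestrator._windows_since_last_eval = 0
    else:
        logger.warning("Evaluation did not run (e.g., pause not confirmed); retrying next cycle")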
)

logger.info("Server shutdown complete")
return metrics
Bug: Evaluator shutdown missing in server evaluation path
In _run_server_evaluation, the evaluator created at line 690 is never shut down after use. Unlike _run_direct_evaluation which calls _cleanup_evaluation_resources (including evaluator.shutdown()) in a finally block, _run_server_evaluation exits without any cleanup. The EvaluatorService.shutdown() method is documented as "Critical for freeing GPU memory" and "Must be called after evaluation completes."
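A minimal sketch of the cleanup this comment asks for, assuming the server path can mirror the direct path's finally-block pattern; the helper names are taken from the comment and the surrounding evaluation logic is elided:

async def run_server_evaluation(create_evaluator, run_eval) -> dict:
    """Ensure the evaluator is shut down even if evaluation raises."""
    evaluator = create_evaluator()
    try:
        metrics = await run_eval(evaluator)
        return metrics
    finally:
        # Critical for freeing GPU memory, mirroring _run_direct_evaluation's cleanup.
        evaluator.shutdown()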
…filtering previously used for testing
Actionable comments posted: 3
🧹 Nitpick comments (4)
grail/shared/constants.py (1)
50-85: Constants wiring and new async knobs look sound; just verify external expectations.

The new defaults and flags (TRAINER_MAX_LENGTH=1024, TRAINER_USE_GRADIENT_CHECKPOINTING, snapshot/heartbeat/upload retry constants) are wired correctly and typed safely. This will change default sequence length and enable gradient checkpointing + async worker behaviour out of the box; please double-check that any docs/CLI help and deployment configs are updated to reflect these new defaults and env names.

Also applies to: 247-257
grail/trainer/training_process.py (3)
245-277: Consider basing KL/ref-model enablement on TrainingConfig instead of globals.

Both _initialize_resources and _prepare_models_with_accelerator gate ref-model loading and accelerator wrapping on the module-level is_kl_enabled() helper, which reads TRAINER_KL_COEF from env. If callers ever override TrainingConfig.kl_coef at runtime (without changing the env), the env-based check and the config value can diverge.

A slightly more robust pattern would be to make the decision based on the effective config:

- load_ref_model = is_kl_enabled()
+ load_ref_model = self.config.kl_coef > 0.0
  ...
- kl_enabled = is_kl_enabled()
+ kl_enabled = self.config.kl_coef > 0.0

This keeps behaviour consistent with whatever TrainingConfig the orchestrator passes in, while still defaulting from the shared constants.

Also applies to: 583-599
821-839: Optional: reuse metagraph to avoid redundant chain calls in _load_grpo_groups.

_load_grpo_groups always fetches a fresh metagraph via self.subtensor.metagraph(NETUID) even though _get_trusted_miners does the same in the same iteration. For long training runs this is probably fine, but if metagraph fetches are expensive or rate-limited, you could reduce load by reusing a single metagraph per loop iteration (e.g., pass it from _get_trusted_miners or cache it briefly in the service).

Not critical, but worth considering if you see metagraph calls becoming a bottleneck.
Also applies to: 850-873
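A small sketch of the suggested reuse, assuming the service can fetch the metagraph once per loop iteration and hand it to both consumers; apart from subtensor.metagraph, the function names are illustrative:

async def run_iteration(service, netuid: int) -> None:
    """Fetch the metagraph once per iteration and share it across both consumers."""
    metagraph = await service.subtensor.metagraph(netuid)
    trusted_hotkeys = service.get_trusted_miners(metagraph)
    await service.load_grpo_groups(metagraph, trusted_hotkeys)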
448-481: Docstring and stop_event usage around initial checkpoint could be tightened
_upload_initial_checkpoint's docstring says "Upload initial checkpoint and wait for miners to download", but the method now just saves the snapshot (upload handled by the worker) and does not touch stop_event. The actual wait happens in _wait_for_miners, where WAIT_WINDOW_LENGTH is currently hardcoded to 0.0 (no wait) with a TODO.

Not a functional bug, but to avoid confusion you might:

- Update _upload_initial_checkpoint's docstring to describe that it only saves the initial snapshot and lets the upload worker handle publishing.
- Optionally use or remove the unused stop_event parameter there, or wire it into a future early-exit hook.

This will make the intent of the startup sequence clearer to future readers.
Also applies to: 201-233
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- grail/shared/constants.py (3 hunks)
- grail/trainer/__init__.py (1 hunks)
- grail/trainer/service.py (0 hunks)
- grail/trainer/training_process.py (1 hunks)
💤 Files with no reviewable changes (1)
- grail/trainer/service.py
🧰 Additional context used
🧬 Code graph analysis (1)
grail/trainer/training_process.py (13)
- grail/infrastructure/checkpoint_consumer.py (1): CheckpointManager (87-562)
- grail/infrastructure/network.py (1): create_subtensor (330-404)
- grail/model/train_loading.py (2): ModelLoadSpec (30-41), load_training_artifacts (149-190)
- grail/monitoring/manager.py (2): get_monitoring_manager (366-375), initialize_monitoring (378-404)
- grail/shared/constants.py (1): is_kl_enabled (262-268)
- grail/trainer/algorithms/base.py (1): TrainingAlgorithm (14-40)
- grail/trainer/config.py (1): TrainingConfig (15-76)
- grail/trainer/ipc.py (4): IPCChannels (23-142), update_heartbeat (62-68), is_pause_requested (103-105), queue_snapshot (122-142)
- grail/trainer/replay_buffer.py (8): ReplayBuffer (20-63), create_replay_buffer (253-280), get_stats (50-55), get_stats (212-233), add_window (29-35), add_window (110-133), sample_groups (38-47), sample_groups (135-180)
- grail/trainer/snapshot_manager.py (2): SnapshotManager (26-223), save_snapshot_atomic (59-127)
- grail/neurons/miner.py (1): run (52-308)
- grail/monitoring/backends/wandb_backend.py (2): initialize (99-137), start_run (704-743)
- grail/logging_utils.py (1): configure_process_logging (227-296)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Run Tests (3.10)
- GitHub Check: Run Tests (3.11)
- GitHub Check: Cursor Bugbot
- GitHub Check: Run Tests (3.11)
- GitHub Check: Run Tests (3.10)
🔇 Additional comments (1)
grail/trainer/__init__.py (1)
5-11: Updated trainer package surface looks consistent.

Exporting checkpoint_publisher and replay_buffer (and dropping the old service/checkpointing modules) matches the new async trainer architecture; this keeps the public API aligned with how downstream code now imports these components.
async def _initialize_monitoring(self) -> None:
    """Initialize monitoring after heavy resources are loaded.

    This is called AFTER models and chain manager are initialized to avoid
    resource contention during WandB connection.
    """
    if not self.monitor_config:
        logger.info("No monitoring config provided, skipping monitoring setup")
        return

    # FIRST: Test direct WandB connection (bypasses monitoring class)
    import os

    os.environ["WANDB_DISABLE_SERVICE"] = "false"  # Ensure service is enabled
    os.environ["WANDB_SERVICE"] = ""  # Clear any inherited service path
    try:
        backend_type = self.monitor_config.get("backend_type", "wandb")
        init_config = {k: v for k, v in self.monitor_config.items() if k != "backend_type"}

        logger.debug(
            "Initializing monitoring: backend_type=%s run_name=%s run_id=%s entity=%s project=%s mode=%s",
            backend_type,
            init_config.get("run_name"),
            init_config.get("run_id"),
            init_config.get("entity"),
            init_config.get("project"),
            init_config.get("mode"),
        )
        logger.debug(
            "Full init_config keys received in training process: %s", list(init_config.keys())
        )
        # Verify critical parameters are present
        if not init_config.get("entity"):
            logger.warning("⚠️ WandB entity not set in training process - will use default")
        if not init_config.get("project"):
            logger.warning("⚠️ WandB project not set in training process - will use default")

        if "run_id" in init_config:
            logger.info(
                "Resuming W&B run %s in training process for multi-process logging",
                init_config["run_id"],
            )

        initialize_monitoring(backend_type=backend_type, **init_config)
        self.monitor = get_monitoring_manager()
        logger.info("Monitoring initialized in training process")

        # Actually start the WandB run (connects to API, resumes run)
        # This is critical - without this, the run is not actually initialized
        run_name = init_config.get("run_name")
        if run_name:
            logger.info("Starting WandB run: %s", run_name)
            start_time = time.time()
            try:
                actual_run_id = await self.monitor.start_run(run_name, init_config)
                start_duration = time.time() - start_time
                logger.info(
                    "✅ WandB run started successfully in %.1fs (run_id=%s)",
                    start_duration,
                    actual_run_id,
                )
            except Exception as start_exc:
                start_duration = time.time() - start_time
                logger.error(
                    "❌ Failed to start WandB run after %.1fs: %s",
                    start_duration,
                    start_exc,
                )
                logger.info(
                    "Hint: Increase WANDB_INIT_TIMEOUT (current: %s) if shared mode workers timeout",
                    init_config.get("init_timeout", "120"),
                )
                # Continue without monitoring
                return

        # Test WandB metric logging with a simple gauge
        if self.monitor and run_name:
            logger.info("Testing WandB metric logging...")
            test_start = time.time()
            try:
                # Set current block context so metric appears in WandB properly
                current_block = await self.subtensor.get_current_block()
                current_window = (current_block // WINDOW_LENGTH) * WINDOW_LENGTH
                self.monitor.set_block_context(current_block, current_window)
                await self.monitor.log_gauge("training_process/connection_test", 1.0)
                await self.monitor.flush_metrics()
                test_duration = time.time() - test_start
                logger.info(
                    "✅ WandB metric logging successful in %.3fs",
                    test_duration,
                )
            except Exception as test_exc:
                test_duration = time.time() - test_start
                logger.warning(
                    "⚠️ WandB metric logging test failed after %.3fs: %s (training will continue)",
                    test_duration,
                    test_exc,
                )

    except Exception as exc:
        logger.warning("Failed to initialize monitoring in training process: %s", exc)
        self.monitor = None
On WandB start failure, monitoring is not actually disabled
In _initialize_monitoring, if self.monitor.start_run(...) raises, the code logs "Continue without monitoring" and returns, but self.monitor still points to a live MonitoringManager instance. The rest of the training flow checks if self.monitor: and will continue to call set_block_context, log_gauge, etc., which may keep failing or behave inconsistently because the run was never started.
Consider explicitly disabling monitoring in this failure path, e.g.:
- except Exception as start_exc:
+ except Exception as start_exc:
start_duration = time.time() - start_time
logger.error(
"❌ Failed to start WandB run after %.1fs: %s",
start_duration,
start_exc,
)
logger.info(
"Hint: Increase WANDB_INIT_TIMEOUT (current: %s) if shared mode workers timeout",
init_config.get("init_timeout", "120"),
)
- # Continue without monitoring
- return
+ # Continue without monitoring
+ self.monitor = None
+ returnThis keeps later training logic from repeatedly exercising a half‑initialized monitoring backend.
🤖 Prompt for AI Agents
In grail/trainer/training_process.py around lines 306 to 408, when
self.monitor.start_run(...) raises the code currently logs and returns but
leaves self.monitor set to a live MonitoringManager; change the error path so
monitoring is explicitly disabled and cleaned up: call the monitor
shutdown/cleanup method if available (or drop any references) and set
self.monitor = None before returning so subsequent checks like if self.monitor:
won't attempt to use a half‑initialized monitoring backend; include a defensive
try/except around the cleanup call to ignore cleanup errors and log them at
debug level.
def _prepare_models_with_accelerator(
    self,
    accelerator: Accelerator,
) -> tuple[Any, Any | None, torch.optim.Optimizer]:
    """Prepare models and optimizer with Accelerator for distributed training.

    Args:
        accelerator: Accelerator instance

    Returns:
        Tuple of (train_model, ref_model, optimizer)
    """
    kl_enabled = is_kl_enabled()

    if kl_enabled and self.ref_model is not None:
        train_model, ref_model, optimizer = accelerator.prepare(
            self.train_model,
            self.ref_model,
            self.optimizer,
        )
        if hasattr(ref_model, "eval"):
            ref_model.eval()
    else:
        train_model, optimizer = accelerator.prepare(
            self.train_model,
            self.optimizer,
        )
        ref_model = None

    return train_model, ref_model, optimizer
🧩 Analysis chain
🌐 Web query:
accelerate library Accelerator.prepare optimizer in place mutation documentation
💡 Result:
Short answer: Accelerator.prepare (and prepare_optimizer) can replace or wrap the optimizer rather than mutate the original in-place — you should use the object returned by prepare/prepare_optimizer. The official docs show prepare returns the prepared optimizer (and recommend assigning it), and GH issues confirm optimizer state/identity can change in some distributed/offload/zero/DeepSpeed cases. [1][2]
Sources:
- accelerate docs (Accelerator.prepare / prepare_optimizer) — shows prepare returns prepared objects and examples assigning the returned optimizer. [1]
- GitHub issue reporting save/load problems after accelerator.prepare (shows optimizer can be altered by prepare in some setups). [2]
References:
[1] Hugging Face Accelerate — Accelerator.prepare / prepare_optimizer (docs).
[2] Issue: "Unable to save/load optimizer state after accelerator.prepare()" — huggingface/accelerate GitHub.
🏁 Script executed:
# First, let's examine the actual code at the mentioned lines
cat -n grail/trainer/training_process.py | sed -n '571,600p'
Repository: one-covenant/grail
Length of output: 1228
🏁 Script executed:
# Also check the _handle_pause method at lines 885-942
cat -n grail/trainer/training_process.py | sed -n '885,942p'
Repository: one-covenant/grail
Length of output: 2278
🏁 Script executed:
# Let's check if there's any direct reference to how self.optimizer is used after _prepare_models_with_accelerator
rg "self\.optimizer" grail/trainer/training_process.py -n -B2 -A2 | head -100Repository: one-covenant/grail
Length of output: 1485
Use the optimizer returned by accelerator.prepare()—don't rely on in-place mutation
The accelerator.prepare() method can wrap or replace the optimizer rather than mutate it in place. The current code captures the returned optimizer in _prepare_models_with_accelerator but continues to use self.optimizer in _train_epoch (line 933), which may diverge from the prepared optimizer in distributed, DeepSpeed, or offload configurations.
Assign the prepared optimizer back to self.optimizer to ensure the rest of the code uses the correct instance:
def _prepare_models_with_accelerator(
self,
accelerator: Accelerator,
) -> tuple[Any, Any | None, torch.optim.Optimizer]:
...
if kl_enabled and self.ref_model is not None:
train_model, ref_model, optimizer = accelerator.prepare(
self.train_model,
self.ref_model,
self.optimizer,
)
else:
train_model, optimizer = accelerator.prepare(
self.train_model,
self.optimizer,
)
ref_model = None
+
+ self.optimizer = optimizer
return train_model, ref_model, optimizerApply the same fix in _handle_pause where the optimizer is re-prepared after pause events.
🤖 Prompt for AI Agents
In grail/trainer/training_process.py around lines 571-600, the prepared
optimizer returned from accelerator.prepare() is captured locally but not
assigned back to the instance; update the method to assign the returned
optimizer to self.optimizer so subsequent code uses the wrapped/replaced
optimizer (e.g., self.optimizer = optimizer), and do the same in the
pause-handling code path where the optimizer is re-prepared after pauses to
ensure the instance holds the prepared optimizer rather than the original.
# Close subtensor connection before long idle period (prevents 10s timeout)
if self.subtensor:
    try:
        if hasattr(self.subtensor, "_subtensor"):
            # Unwrap ResilientSubtensor to get underlying subtensor
            inner = object.__getattribute__(self.subtensor, "_subtensor")
            if hasattr(inner, "close"):
                await inner.close()
        elif hasattr(self.subtensor, "close"):
            await self.subtensor.close()
        logger.info("Closed subtensor connection during pause to prevent idle timeout")
    except Exception as e:
        logger.warning("Failed to close subtensor during pause: %s", e)
    self.subtensor = None

# Wait for pause flag to be cleared via IPC (primary) or filesystem (fallback)
while not stop_event.is_set():
    # Check if resume signal received
    pause_still_active = (
        self._ipc.is_pause_requested()
        if self._ipc is not None
        else self.snapshot_manager.check_pause_flag()
    )
    if not pause_still_active:
        break
    await asyncio.sleep(PAUSE_CHECK_INTERVAL_SECONDS)

if stop_event.is_set():
    return train_model_cpu, ref_model_cpu, optimizer

# Clear the confirmed event for next cycle
if self._ipc is not None:
    self._ipc.clear_pause_confirmed()

logger.info("🔄 STATE: resume_requested - pause signal cleared")

# Recreate subtensor connection after pause
logger.info("Recreating subtensor connection after pause...")
self.subtensor = await create_subtensor(resilient=True)
logger.info("Subtensor connection recreated successfully")
Recreating subtensor on pause does not refresh GrailChainManager
During _handle_pause, you correctly close the existing subtensor and then recreate a fresh one:
self.subtensor = await create_subtensor(resilient=True)However, self.chain_manager was constructed once in _initialize_resources with the original subtensor and metagraph, and is not reinitialized here. After a pause/resume cycle, any chain interactions inside load_grpo_groups (via self.chain_manager) may still be using the old, closed subtensor, leading to repeated GRPO load failures and effectively stalling training after the first pause.
To keep chain interactions aligned with the new subtensor, you likely want to reinitialize the chain manager in this block, similar to _initialize_resources:
logger.info("Recreating subtensor connection after pause...")
self.subtensor = await create_subtensor(resilient=True)
- logger.info("Subtensor connection recreated successfully")
+ logger.info("Subtensor connection recreated successfully")
+
+ # Reinitialize chain manager with fresh subtensor/metagraph
+ try:
+ logger.info("Reinitializing chain manager after pause...")
+ metagraph = await self.subtensor.metagraph(NETUID)
+ chain_config = SimpleNamespace(netuid=NETUID)
+ self.chain_manager = GrailChainManager(
+ chain_config,
+ self.wallet,
+ metagraph,
+ self.subtensor,
+ self.credentials,
+ )
+ await self.chain_manager.initialize()
+ logger.info("Chain manager reinitialized after pause")
+ except Exception as exc: # noqa: BLE001
+ logger.error("Failed to reinitialize chain manager after pause: %s", exc)Without this, pause/resume may quietly break subsequent GRPO loading and chain‑dependent logic.
Async Trainer Infrastructure
Overview: Major feature introducing asynchronous training architecture with improved checkpoint management and multi-process coordination.
Key Changes:
- training_process.py and IPC coordinator for multiprocessing coordination
- CheckpointPublisher (publishing) and CheckpointManager (consuming) with simplified retention policy

Note
Adds an async multi-process trainer with IPC and snapshot-based uploads, new checkpoint publish/consume flow, a GRPO replay buffer, improved WandB shared-mode logging, network resilience, and miner/validator checkpoint handling.
- training_process.py and upload_worker.py with IPC (IPCChannels) and SnapshotManager for snapshot-based training and uploads.
- checkpoint_consumer.py (read-only, pagination, READY-{window} discovery, latest-ready lookup) and checkpoint_publisher.py (publish, finalize READY markers, remote cleanup).
- checkpoint_window match.
- checkpoint_window; drop groups with completion < CHALLENGE_K.

Written by Cursor Bugbot for commit aca6ed2. This will update automatically on new commits.