
Conversation

@erfanMhi (Collaborator) commented Dec 8, 2025

Async Trainer Infrastructure

Overview: A major feature introducing an asynchronous training architecture with improved checkpoint management and multi-process coordination.

Key Changes:

  • Async Training: Convert the trainer to an async architecture with a new training_process.py and an IPC layer for multi-process coordination
  • Checkpoint Refactor: Split into CheckpointPublisher (publishing) and CheckpointManager (consuming) with simplified retention policy
  • Replay Buffer: Implement a configurable replay buffer for GRPO training with recency-weighted sampling
  • Memory Efficiency: Add optional gradient checkpointing and optimize logprob computation with autocast
  • Network Resilience: Implement subtensor connection management with idle detection and automatic restart
  • Upload Worker: New dedicated process for checkpoint uploads

Note

Adds an async multi-process trainer with IPC and snapshot-based uploads, new checkpoint publish/consume flow, a GRPO replay buffer, improved WandB shared-mode logging, network resilience, and miner/validator checkpoint handling.

  • Trainer (async orchestration):
    • New training_process.py and upload_worker.py with IPC (IPCChannels) and SnapshotManager for snapshot-based training and uploads.
    • Periodic eval with server backends; trainer never uses GPU; pause/resume coordination.
    • Adds GRPO replay buffer with recency-weighted sampling (see the sketch after this list); micro-batch metrics logging; optional gradient checkpointing.
  • Checkpoints:
    • Split roles: checkpoint_consumer.py (read-only, pagination, READY-{window} discovery, latest-ready lookup) and checkpoint_publisher.py (publish, finalize READY markers, remote cleanup).
    • Miner/Validator load latest ready checkpoint before window; validator enforces checkpoint_window match.
  • Monitoring & Infra:
    • WandB backend: shared-mode initialization, minimal worker settings, async/batched logging; CLI waits for run sync.
    • Network: resilient subtensor with idle detection/restart; S3 list pagination; structured logging utilities.
  • Mining/Envs/Model:
    • Miner packaging includes checkpoint_window; drop groups with completion < CHALLENGE_K.
    • MATH env improves gold answer extraction; vLLM backend uses returned token IDs.
  • Config/CLI/Tests:
    • Env/config for gradient checkpointing and eval server host/port; trainer CLI wires publisher and verbosity.
    • New integration/unit tests for async trainer, replay buffer, CLI, validation, and WandB shared mode.
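For reference, here is a minimal sketch of what recency-weighted replay sampling can look like. The class name, decay weighting, eviction policy, and quota handling are illustrative assumptions, not the PR's actual RecencyWeightedBuffer implementation (which, per the review, also backfills any remaining quota from the most recent window):

import random


class RecencyWeightedBufferSketch:
    """Toy model of a replay buffer that favors newer windows when sampling."""

    def __init__(self, max_windows: int = 8, decay: float = 0.5) -> None:
        self._windows: dict[int, list] = {}  # window -> list of GRPO groups
        self._max_windows = max_windows
        self._decay = decay  # assumed per-window-age decay factor

    def add_window(self, window: int, groups: list) -> None:
        self._windows[window] = list(groups)
        # Evict the oldest windows beyond the retention cap
        for old in sorted(self._windows)[: max(0, len(self._windows) - self._max_windows)]:
            del self._windows[old]

    def sample(self, max_groups: int, seed: int = 0) -> list:
        rng = random.Random(seed)  # seeded for deterministic sampling
        windows = sorted(self._windows)  # oldest -> newest
        weights = [self._decay ** (len(windows) - 1 - i) for i in range(len(windows))]
        total = sum(weights) or 1.0
        sampled: list = []
        for window, weight in zip(windows, weights):
            quota = int(round(max_groups * weight / total))
            pool = self._windows[window]
            sampled.extend(rng.sample(pool, min(quota, len(pool))))
        return sampled[:max_groups]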

Written by Cursor Bugbot for commit aca6ed2. This will update automatically on new commits.

Summary by CodeRabbit

  • New Features

    • Async multi-process trainer with snapshot/upload workers and a replay buffer
    • New env var GRAIL_TRAINER_USE_GRADIENT_CHECKPOINTING=1 to enable gradient checkpointing (see the sketch after this summary)
  • Improvements

    • Robust checkpoint publish/consume and snapshot staging for reliable uploads
    • Enhanced monitoring with shared-mode, async logging, tracing context, and safer formatting
    • Improved math-answer extraction and network idle recovery
  • Bug Fixes

    • Miners discover and use the latest ready checkpoints
  • Tests

    • New integration and unit tests for async trainer, snapshotting, replay buffer, and monitoring
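To make the new flag concrete, here is a minimal sketch of how such an env toggle is typically consumed with Hugging Face models; the parsing and the exact place it is wired (TrainingConfig vs. the training process) are assumptions, not the PR's implementation:

import os

from transformers import PreTrainedModel


def maybe_enable_gradient_checkpointing(model: PreTrainedModel) -> None:
    # Hypothetical consumption of GRAIL_TRAINER_USE_GRADIENT_CHECKPOINTING=1
    if os.environ.get("GRAIL_TRAINER_USE_GRADIENT_CHECKPOINTING", "0") == "1":
        model.gradient_checkpointing_enable()  # trade recompute for lower activation memory
        model.config.use_cache = False  # KV cache is incompatible with checkpointed training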


…ng CheckpointPublisher for publishing and managing checkpoints, and CheckpointManager for consuming checkpoints
Add snapshot-based IPC for training/upload coordination. Include
TrainingService for continuous training with pause/resume and
SnapshotManager for atomic checkpoint ops. Add upload worker for
async checkpoint publishing to R2.
Replace window-based training with continuous loop. Split training
and upload into separate processes using SnapshotManager. Add
heartbeat monitoring and pause/resume for evaluation coordination.
Change READY marker format to READY-{window} encoding completion
window. Add upload_from_staging method for worker process. Simplify
finalize_checkpoint_ready to accept explicit checkpoint and ready
windows.
Update checkpoint discovery to use READY-{window} format. Add
get_checkpoint_ready_window method to determine availability. Reorder
CheckpointMetadata fields for consistency.
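As an illustration of the READY-{window} discovery described above, here is a hedged sketch of selecting the latest ready checkpoint from a key listing; the actual key layout and whether the comparison against the target window is inclusive are assumptions:

import re

_READY_MARKER = re.compile(r"READY-(\d+)$")


def latest_ready_window(keys: list[str], before_window: int) -> int | None:
    """Return the newest window with a READY-{window} marker at or before before_window."""
    ready = [int(m.group(1)) for key in keys if (m := _READY_MARKER.search(key))]
    eligible = [w for w in ready if w <= before_window]
    return max(eligible) if eligible else None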
Remove synchronous model loading from CLI. Pass ModelLoadSpec to
TrainerContext instead of loaded models. Load models in training
process to avoid CUDA fork issues.
Add tests for SnapshotManager atomic operations and IPC. Add smoke
tests for train CLI with lightweight fakes to verify end-to-end
orchestration.
Add SNAPSHOT_POLL_INTERVAL_SECONDS, TRAINING_HEARTBEAT_TIMEOUT_SECONDS,
and upload retry config. Reduce TRAINER_MAX_LENGTH default from 2048
to 1024 for memory efficiency.
Update environment loop for async trainer compatibility. Improve math
environment parsing and reward computation.
Update validation logic for async checkpoint publishing. Improve miner
validation error handling and logging.
Update W&B backend to support logging from multiple processes (main,
training, upload worker) using run ID resumption.
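The core idea, illustrated with the public wandb API (the backend's actual shared-mode settings and worker wiring are more involved):

import wandb

# Main process: create the run once and hand its ID to child processes (via IPC or env).
run = wandb.init(project="grail", name="trainer")
run_id = run.id

# Worker process (training or upload worker): attach to the same run by ID.
worker_run = wandb.init(project="grail", id=run_id, resume="allow")
worker_run.log({"upload/last_duration_s": 12.3})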
Remove reasoning_start parameter from build_qwen_chat_template. Model
must generate reasoning tokens without prompt injection. Clean up
related comments.
Update miner to use get_latest_ready_checkpoint for discovery. Add
checkpoint_window to rollout packaging. Remove window parameter from
checkpoint publishing calls. Minor evaluator vLLM config cleanup.
Fix whitespace on blank lines, sort imports, remove unused imports,
and remove quotes from type annotations.
…eout configuration

- Added support for WandB shared mode to allow multiple processes to log to a single run.
- Introduced `init_timeout` configuration to manage connection delays more effectively.
- Updated logging to provide better insights into the configuration and initialization process.
…process doesn't reuse wandb service used by the main process
…tor for non-blocking read because it resulted in some issues
…d update GRPO max completion tokens to be 1024
…ional entropy calculation for memory efficiency
…w. Previously it was done after every micro batch.
Replace infinite filesystem polling with multiprocessing.Event for pause/resume coordination between orchestrator and training process. Adds 5-minute timeout, stop_event check, and process liveness check to prevent indefinite hangs. Filesystem markers kept as crash recovery backup.
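A minimal sketch of the orchestrator-side wait described above, assuming the Event names and the 5-minute constant; the real coordination also keeps filesystem markers as a crash-recovery backup:

import multiprocessing as mp
import time

PAUSE_ACK_TIMEOUT_S = 300  # assumed name for the 5-minute cap


def wait_until_paused(pause_ack, stop, training_proc: mp.Process) -> bool:
    """Wait for the training process to acknowledge a pause request.

    pause_ack and stop are multiprocessing.Event objects shared with the worker.
    """
    deadline = time.monotonic() + PAUSE_ACK_TIMEOUT_S
    while time.monotonic() < deadline:
        if pause_ack.wait(timeout=1.0):  # set by the trainer at a safe point
            return True
        if stop.is_set():  # shutdown requested, stop waiting
            return False
        if not training_proc.is_alive():  # liveness check avoids an indefinite hang
            return False
    return False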
- Add IPCChannels dataclass with stop, pause, heartbeat, and snapshot_queue (sketched below)
- Add helper methods for heartbeat, pause coordination, and snapshot queuing
- Fix bug in upload_worker using undefined variables
- Remove dead code: filesystem fallback mode and unused ipc_coordinator.py
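A rough sketch of what such an IPC bundle can look like; the field types and helper names are assumptions, not the exact IPCChannels API:

import multiprocessing as mp
from dataclasses import dataclass, field
from typing import Any


@dataclass
class IPCChannelsSketch:
    stop: Any = field(default_factory=mp.Event)    # request shutdown of all workers
    pause: Any = field(default_factory=mp.Event)   # ask the trainer to pause for evaluation
    heartbeat: Any = field(default_factory=lambda: mp.Value("d", 0.0))  # last trainer heartbeat time
    snapshot_queue: Any = field(default_factory=mp.Queue)  # snapshot metadata for the upload worker

    def beat(self, now: float) -> None:
        with self.heartbeat.get_lock():
            self.heartbeat.value = now

    def queue_snapshot(self, payload: dict) -> None:
        self.snapshot_queue.put(payload)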
coderabbitai bot commented Dec 8, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Implements async multi-process training orchestration (training + upload workers) with IPC, splits checkpoint producer/consumer roles (CheckpointPublisher and consumer helpers), adds SnapshotManager and ReplayBuffer, updates CLI/trainer wiring and monitoring (WandB) integration, and introduces config/env flags including gradient checkpointing.

Changes

  • Configuration & Env (.env.example, grail/shared/constants.py, grail/trainer/config.py): Added GRAIL_TRAINER_USE_GRADIENT_CHECKPOINTING to .env.example; new TRAINER_USE_GRADIENT_CHECKPOINTING constant and async-training constants; TRAINER_MAX_LENGTH default changed; TrainingConfig adds use_gradient_checkpointing and replay-buffer fields; eval server fields added.
  • CLI wiring (grail/cli/train.py, grail/cli/mine.py): train() now accepts ctx: typer.Context for verbosity; replaced CheckpointManager use with CheckpointPublisher/wallet in trainer context; threaded checkpoint_window through rollout generation/packaging in mine.py and added completion-length gating.
  • Trainer orchestration & IPC (grail/trainer/training_process.py, grail/trainer/upload_worker.py, grail/trainer/ipc.py, grail/trainer/snapshot_manager.py): Added standalone training process and upload worker entrypoints; new IPCChannels for process coordination; SnapshotManager for atomic snapshot, staging, and heartbeat operations; training process uses snapshots and queues IPC messages for uploads.
  • Checkpoint producer & removed legacy (grail/trainer/checkpoint_publisher.py, grail/trainer/checkpointing.py): New CheckpointPublisher implements producer-side upload/finalize/upload-from-staging; legacy checkpointing.py removed (logic migrated).
  • Trainer service/API surface & exports (grail/trainer/service.py, grail/trainer/__init__.py): Trainer orchestration migrated to multi-process design; TrainerService removed; package exports updated to include checkpoint_publisher and replay_buffer and remove old service/checkpointing.
  • Checkpoint consumer & infra (grail/infrastructure/checkpoint_consumer.py, grail/infrastructure/network.py, grail/infrastructure/comms.py): Consumer discovers READY markers (get_latest_ready_checkpoint, _get_checkpoint_ready_window); metadata extended with parent_window; removed remote cleanup in consumer; ResilientSubtensor gains restart()/close() and idle-tracking; S3 listing handles pagination.
  • Miner / Validator / Validation service (grail/neurons/miner.py, grail/neurons/validator.py, grail/validation/service.py, grail/validation/miner_validator.py): Imports switched to consumer module; miner dynamically selects latest ready checkpoint (may retain current model if none/new same); validator threads validator_checkpoint_window through rollout validation and errors on mismatch; validation service uses ready-checkpoint discovery.
  • Algorithms & Replay buffer (grail/trainer/algorithms/grpo.py, grail/trainer/replay_buffer.py): GRPO refactored for micro-batches, grad-accum, aggregated batch metrics and memory guards; added ReplayBuffer interface and RecencyWeightedBuffer with deterministic sampling and factory.
  • Monitoring & logging (grail/monitoring/backends/wandb_backend.py, grail/monitoring/config.py, grail/logging_utils.py, scripts/test_monitoring_shared.py, scripts/test_wandb_shared.py): Major WandB backend refactor (shared-mode primary/worker flow, async/offload, metric families, artifact table); added init_timeout config; new structured logging utilities and shared-mode test scripts.
  • Model loading & provider (grail/model/train_loading.py, grail/model/provider.py): load_training_artifacts(..., *, load_ref_model: bool = True) can skip loading the reference model (may return None); get_model now logs model metadata (param counts, dtype, config) safely.
  • Environments & evaluation (grail/environments/loop.py, grail/environments/math_hendrycks_env.py, grail/trainer/evaluator.py, grail/trainer/inference_server.py, research/offline_trainer/tests/test_gpu_integration.py): Added warn_on_missing_token_ids flag to VLLMServerBackend; MATH env extracts boxed LaTeX answer fallback; evaluator uses chosen-logprobs flag for warn path; inference server reads server_* fields; tests updated accordingly.
  • Exports & tests (tests/integration/*, tests/unit/trainer/test_replay_buffer.py, tests/integration/test_train_cli_smoke.py, tests/integration/test_validation_service.py): Added integration/unit tests for async trainer (snapshot manager, IPC flows), replay buffer unit tests, train CLI smoke tests, validation-service import path update, and WandB shared-mode test scripts.

Sequence Diagram(s)

sequenceDiagram
    participant Orchestrator as Orchestrator
    participant Trainer as Training Process
    participant SnapMgr as SnapshotManager
    participant IPC as IPCChannels
    participant Upload as Upload Worker
    participant Publisher as CheckpointPublisher
    participant Storage as Remote Storage

    Orchestrator->>Trainer: spawn training_process(ipc, specs, config)
    Orchestrator->>Upload: spawn upload_worker(ipc, publisher)

    loop training epoch
        Trainer->>Trainer: train epoch / update heartbeat
        Trainer->>SnapMgr: save_snapshot_atomic(model, tokenizer, metadata)
        SnapMgr-->>Trainer: snapshot saved
        Trainer->>IPC: queue_snapshot(path, metadata, window)
    end

    loop upload loop
        Upload->>IPC: wait for snapshot message
        IPC-->>Upload: snapshot metadata
        Upload->>SnapMgr: copy_snapshot_to_staging()
        SnapMgr-->>Upload: staging_path
        Upload->>Publisher: upload_from_staging(staging_path, metadata, window)
        Publisher->>Storage: upload files + metadata
        Storage-->>Publisher: uploaded
        Publisher->>Storage: write READY-{ready_window} marker
        Storage-->>Publisher: marker OK
        Upload->>SnapMgr: cleanup_staging()
    end
sequenceDiagram
    participant Miner as Miner
    participant CheckMgr as CheckpointManager (consumer)
    participant Storage as Remote Storage

    Miner->>CheckMgr: get_latest_ready_checkpoint(before_window)
    CheckMgr->>Storage: list/read READY markers
    Storage-->>CheckMgr: ready_window (or none)
    CheckMgr-->>Miner: checkpoint_window (or None)

    alt checkpoint found
        Miner->>CheckMgr: get_checkpoint(checkpoint_window)
        CheckMgr->>Storage: download files
        Storage-->>CheckMgr: model/tokenizer
        CheckMgr-->>Miner: artifacts
    else none
        Miner-->>Miner: warn / retry later
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Areas requiring extra attention:

  • grail/trainer/training_process.py — async orchestration, device pause/resume, snapshot and heartbeat semantics.
  • grail/trainer/algorithms/grpo.py — micro-batch/grad-accum correctness, metric aggregation, NaN/Inf guards.
  • grail/monitoring/backends/wandb_backend.py — shared-mode initialization, async/offload, metric definitions.
  • grail/trainer/checkpoint_publisher.py & grail/infrastructure/checkpoint_consumer.py — upload concurrency, metadata signing, READY semantics, and readiness discovery.
  • IPCChannels and SnapshotManager — cross-process primitives, filesystem atomicity, staging behavior and cleanup.

Possibly related PRs

Poem

🐇
I hopped through files with tiny paws,
Split checkpoints, signed them with my claws,
Snapshots tucked in cozy nests,
IPC drums beat through tests,
Trainers, uploaders — hare-approved, applause! 🥕

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped since CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title 'feat: add async trainer infrastructure' directly describes the primary change: the introduction of an asynchronous, multi-process training architecture with supporting infrastructure.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 90.43%, which is sufficient; the required threshold is 80.00%.


coderabbitai bot left a comment

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
grail/shared/chat_templates.py (1)

10-24: Remove the deprecated parameter from the docstring Args section.

The docstring documents reasoning_start in the Args section (line 15), but this parameter no longer exists in the function signature. Since it's been fully removed, delete it from the Args section rather than marking it as deprecated.

Apply this diff to fix the docstring:

     Args:
         system_prompt: The system prompt to inject
-        reasoning_start: DEPRECATED - No longer used. The model generates reasoning tokens.
 
     Returns:
grail/neurons/miner.py (1)

252-264: checkpoint_window can be None but generate_rollouts_for_window expects int.

checkpoint_manager.get_latest_ready_checkpoint() returns int | None. When no checkpoint is discovered (line 173), checkpoint_window remains None. If a model was already loaded from a previous iteration, the code bypasses the checkpoint loading block (line 178) and proceeds directly to the function call at line 252, passing None to generate_rollouts_for_window. However, that function has a type annotation of checkpoint_window: int (not Optional) and uses it directly without None checks at line 739 in assemble_rollout_payload. Either handle the None case before the function call or change the type annotation to int | None and add defensive checks within the function.

grail/infrastructure/checkpoint_consumer.py (1)

426-454: Semaphore created but never used in _download_all_in_prefix.

The semaphore on line 439 is created but not applied to the _dl coroutine, allowing unbounded concurrent downloads that could overwhelm the network or R2.

     async def _download_all_in_prefix(self, window: int, tmp_dir: Path) -> None:
         ...
         if not keys:
             raise CheckpointDownloadError(f"No files found at prefix {prefix_dir}")

-        asyncio.Semaphore(6)
+        semaphore = asyncio.Semaphore(6)

         async def _dl(key: str) -> None:
-            if not key or not key.startswith(prefix_dir) or key.endswith("/"):
-                return
-            rel = key[len(prefix_dir) :]
-            data = await comms.download_file_chunked(
-                key, credentials=self.credentials, use_write=False
-            )
-            if data is None:
-                raise CheckpointDownloadError(f"Missing file {key}")
-            target_path = tmp_dir / rel
-            target_path.parent.mkdir(parents=True, exist_ok=True)
-            target_path.write_bytes(data)
+            async with semaphore:
+                if not key or not key.startswith(prefix_dir) or key.endswith("/"):
+                    return
+                rel = key[len(prefix_dir) :]
+                data = await comms.download_file_chunked(
+                    key, credentials=self.credentials, use_write=False
+                )
+                if data is None:
+                    raise CheckpointDownloadError(f"Missing file {key}")
+                target_path = tmp_dir / rel
+                target_path.parent.mkdir(parents=True, exist_ok=True)
+                target_path.write_bytes(data)

         await asyncio.gather(*(_dl(k) for k in keys))
🧹 Nitpick comments (33)
grail/model/provider.py (1)

129-148: Model metadata logging looks good; consider single-pass param counting as a micro-optimization

The new logging block is safe and useful, and the try/except ensures it won’t affect model loading behavior. As a small optional improvement, you could avoid iterating model.parameters() twice by computing total_params and trainable_params in one pass:

total_params = 0
trainable_params = 0
for p in model.parameters():
    n = p.numel()
    total_params += n
    if p.requires_grad:
        trainable_params += n

This keeps the same behavior with slightly less overhead on very large models.

grail/trainer/algorithms/grpo.py (1)

1672-1711: Consider reducing key duplication between initialization and reset.

The tracker keys are duplicated between the initial dict (lines 1672-1689) and reset_actual_batch_tracker (lines 1694-1711). This creates a maintenance burden if keys change.

Consider extracting the default tracker state:

+        def _make_empty_tracker() -> dict[str, float]:
+            return {
+                "micro_batches": 0,
+                "sequence_count": 0.0,
+                "token_count": 0.0,
+                "loss_total_sum": 0.0,
+                "loss_pg_sum": 0.0,
+                "loss_kl_sum": 0.0,
+                "loss_entropy_sum": 0.0,
+                "adv_sum": 0.0,
+                "adv_count": 0.0,
+                "reward_sum": 0.0,
+                "reward_count": 0.0,
+                "entropy_sum": 0.0,
+                "entropy_count": 0.0,
+                "kl_sum": 0.0,
+                "ratio_clip_sum": 0.0,
+                "ratio_ceiling_sum": 0.0,
+            }
+
+        actual_batch_tracker = _make_empty_tracker()
+
+        def reset_actual_batch_tracker() -> None:
+            actual_batch_tracker.update(_make_empty_tracker())
grail/environments/math_hendrycks_env.py (1)

32-61: _extract_boxed_answer logic is solid; consider a small robustness tweak

The implementation correctly:

  • Detects \boxed{ occurrences with a regex.
  • Uses a depth counter to handle nested braces.
  • Returns None when no \boxed{} is present.

One edge case: if the last \boxed{ in the string has unbalanced braces (e.g., truncated or malformed solution), the function returns None and ignores any earlier, valid \boxed{...} blocks. If you expect occasionally noisy inputs, it may be safer to iterate over boxed_indices in reverse and return the first block that parses successfully, while also stripping whitespace from the extracted content.

For example:

-    # Use the last one
-    start = boxed_indices[-1]
-    # Skip \boxed{ (7 chars)
-    content_start = start + 7
-
-    depth = 1
-    for i, char in enumerate(text[content_start:], start=content_start):
-        if char == "{":
-            depth += 1
-        elif char == "}":
-            depth -= 1
-            if depth == 0:
-                return text[content_start:i]
+    # Try boxed occurrences from last to first; return the first that parses.
+    for start in reversed(boxed_indices):
+        # Skip \boxed{ (7 chars)
+        content_start = start + 7
+
+        depth = 1
+        for i, char in enumerate(text[content_start:], start=content_start):
+            if char == "{":
+                depth += 1
+            elif char == "}":
+                depth -= 1
+                if depth == 0:
+                    return text[content_start:i].strip()

You still fall through to return None at the end if no occurrence parses cleanly.

grail/infrastructure/network.py (1)

117-127: Make subtensor .close() handling more robust and avoid duplication

The new close logic in _restart_subtensor and close() is good for avoiding leaks, but the pattern:

if hasattr(subtensor, "close"):
    if asyncio.iscoroutinefunction(subtensor.close):
        await subtensor.close()
    else:
        subtensor.close()

relies on close being defined as a coroutine function. It will not await cases where close() returns an awaitable but isn’t marked as a coroutine function (wrappers, partials, etc.), and the same logic is duplicated in two places.

You can make this both more robust and DRY by:

  • Calling close() once, capturing the result, and awaiting it if it’s awaitable.
  • Reusing the same helper pattern in both _restart_subtensor and the public close().

For example:

-        try:
-            if hasattr(subtensor, "close"):
-                if asyncio.iscoroutinefunction(subtensor.close):
-                    await subtensor.close()
-                else:
-                    subtensor.close()
-                logger.debug("Closed old subtensor connection")
+        try:
+            if hasattr(subtensor, "close"):
+                close_result = subtensor.close()
+                if inspect.isawaitable(close_result):
+                    await close_result
+                logger.debug("Closed old subtensor connection")
@@
-        subtensor = object.__getattribute__(self, "_subtensor")
-        if hasattr(subtensor, "close"):
-            if asyncio.iscoroutinefunction(subtensor.close):
-                await subtensor.close()
-            else:
-                subtensor.close()
+        subtensor = object.__getattribute__(self, "_subtensor")
+        if hasattr(subtensor, "close"):
+            close_result = subtensor.close()
+            if inspect.isawaitable(close_result):
+                await close_result

And add at the top of the file:

import inspect  # new import

Optionally, you could also extract this into a small _maybe_await_close(subtensor) helper to avoid repeating the pattern a third time if you ever need it elsewhere.

Also applies to: 138-146

grail/trainer/replay_buffer.py (2)

163-170: Consider using a set for faster membership checking.

The g not in sampled check on Line 167 is O(n) per group. For larger buffers, this could become a bottleneck. Consider tracking sampled groups in a set using group IDs.

         # Fill remaining quota from most recent window if undersampled
         if len(sampled) < max_groups:
             remaining_quota = max_groups - len(sampled)
             most_recent_window = sorted_windows[-1]
             available = self._windows[most_recent_window]
-            unused = [g for g in available if g not in sampled]
+            sampled_ids = {id(g) for g in sampled}
+            unused = [g for g in available if id(g) not in sampled_ids]
             k = min(remaining_quota, len(unused))
             if k > 0:
                 sampled.extend(rng.sample(unused, k))

223-224: Document the memory estimation assumption.

The hardcoded 10_000 bytes per group is a rough approximation. Consider adding a comment explaining this estimate or making it configurable.

         total_groups = sum(len(groups) for groups in self._windows.values())
+        # Rough estimate: ~10KB per group (tokens + logprobs + metadata)
         memory_mb = (total_groups * 10_000) / (1024 * 1024)
scripts/test_wandb_shared.py (1)

36-44: Consider extracting hardcoded project name.

The "grail" project name is hardcoded here but could benefit from being a CLI argument or matching a constant, especially if this script is reused for different projects.

+    parser.add_argument(
+        "--project",
+        type=str,
+        default="grail",
+        help="WandB project name (default: grail)",
+    )
grail/cli/train.py (1)

76-86: Consider making the WandB sync delay configurable.

The hardcoded 5-second sleep is a pragmatic workaround for WandB shared mode sync requirements. Consider making this configurable or implementing a retry/polling mechanism for more robust synchronization.

+from grail.monitoring.config import WANDB_SYNC_DELAY_SECONDS  # Add to config
+
             if training_config.get("wandb_shared_mode"):
-                import asyncio
-
                 logger.info(
-                    "Waiting 5s for WandB run to sync to cloud (shared mode requirement)..."
+                    f"Waiting {WANDB_SYNC_DELAY_SECONDS}s for WandB run to sync..."
                 )
-                await asyncio.sleep(5)
+                await asyncio.sleep(WANDB_SYNC_DELAY_SECONDS)
grail/neurons/miner.py (1)

225-234: Verify the dual model/tokenizer checks are intentional.

There are two separate checks for model is None or tokenizer is None:

  1. Lines 225-228: When no checkpoint is available (sleeps 60s)
  2. Lines 231-234: Catch-all before mining (sleeps 30s)

The second check at lines 231-234 appears to be a fallback for edge cases where checkpoint loading might silently fail. Consider adding a comment clarifying this is a defensive guard.

grail/logging_utils.py (2)

294-296: Consider adding error handling for reconfigure calls.

sys.stdout.reconfigure() and sys.stderr.reconfigure() can raise AttributeError in non-standard Python environments or when stdout/stderr are replaced with objects that don't support reconfiguration (e.g., certain test fixtures, custom streams).

     # Force unbuffered output for immediate visibility
-    sys.stdout.reconfigure(line_buffering=True)
-    sys.stderr.reconfigure(line_buffering=True)
+    try:
+        sys.stdout.reconfigure(line_buffering=True)
+        sys.stderr.reconfigure(line_buffering=True)
+    except (AttributeError, io.UnsupportedOperation):
+        pass  # Stream doesn't support reconfiguration

341-366: Add CancelledError handling for clean task cancellation.

The while True loop runs indefinitely without handling asyncio.CancelledError. When this task is cancelled during shutdown, it will propagate the exception instead of exiting cleanly.

 async def monitor_event_loop_lag(interval: float = 5.0, threshold: float = 1.0) -> None:
     """Background task to detect event loop blocking.

     Args:
         interval: Check interval in seconds
         threshold: Lag threshold to trigger warning (seconds)
     """
     logger = logging.getLogger(__name__)
     last_check = time.time()

-    while True:
-        await asyncio.sleep(interval)
-        now = time.time()
-        actual_elapsed = now - last_check
-        expected_elapsed = interval
-        lag = actual_elapsed - expected_elapsed
-
-        if lag > threshold:
-            logger.warning(
-                "⚠️ Event loop lag detected: expected %.1fs, actual %.1fs (lag: %.1fs)",
-                expected_elapsed,
-                actual_elapsed,
-                lag,
-            )
-
-        last_check = now
+    try:
+        while True:
+            await asyncio.sleep(interval)
+            now = time.time()
+            actual_elapsed = now - last_check
+            expected_elapsed = interval
+            lag = actual_elapsed - expected_elapsed
+
+            if lag > threshold:
+                logger.warning(
+                    "⚠️ Event loop lag detected: expected %.1fs, actual %.1fs (lag: %.1fs)",
+                    expected_elapsed,
+                    actual_elapsed,
+                    lag,
+                )
+
+            last_check = now
+    except asyncio.CancelledError:
+        logger.debug("Event loop lag monitor cancelled")
+        raise
grail/validation/miner_validator.py (1)

177-187: The checkpoint window extraction is fragile and relies on specific path format.

The parsing logic assumes the model path contains "checkpoint-{window}" and splits on that pattern. This could fail silently or extract incorrect values if:

  1. The path uses a different separator (e.g., checkpoint_1000 vs checkpoint-1000)
  2. Multiple "checkpoint-" segments exist in the path
  3. The model was loaded from a different source (e.g., HuggingFace Hub)

Consider adding validation or logging when extraction fails.

         # Extract validator's checkpoint window from model path (if available)
         validator_checkpoint_window = None
         if hasattr(model, "name_or_path"):
             model_path = str(model.name_or_path)
             # Parse checkpoint-{window} from path like "/cache/checkpoints/checkpoint-1000"
             if "checkpoint-" in model_path:
                 try:
                     checkpoint_segment = model_path.split("checkpoint-")[-1].split("/")[0]
                     validator_checkpoint_window = int(checkpoint_segment)
+                    logger.debug(
+                        "Extracted validator checkpoint window %s from path %s",
+                        validator_checkpoint_window,
+                        model_path,
+                    )
                 except (ValueError, IndexError):
-                    pass
+                    logger.warning(
+                        "Failed to parse checkpoint window from model path: %s",
+                        model_path,
+                    )
scripts/test_monitoring_shared.py (1)

173-217: Consider consolidating multiple asyncio.run() calls into a single async workflow.

The worker process makes multiple separate asyncio.run() calls (lines 173, 199, 205, 206, 207, 216, 217). Each call creates and destroys a new event loop, which is inefficient and can cause issues with async resources that expect a persistent loop.

For a test script this is acceptable, but consider refactoring to a single async workflow:

async def _worker_async(manager, run_id, process_label, config, elapsed):
    actual_run_id = await manager.start_run(f"{process_label}_process", config)
    await manager.log_gauge(f"testing/{process_label}_startup_time", elapsed)
    await manager.log_gauge(f"testing/{process_label}_timestamp", time.time())
    await manager.log_gauge(f"testing/{process_label}_connected", 1.0)
    await manager.flush_metrics()
    await asyncio.sleep(2)
    await manager.finish_run(actual_run_id)
    await manager.shutdown()

# Then call with a single asyncio.run()
asyncio.run(_worker_async(manager, run_id, process_label, config, elapsed))
grail/trainer/ipc.py (1)

122-142: Consider documenting the unbounded queue behavior.

The snapshot_queue is an unbounded multiprocessing.Queue(). If the upload worker falls behind, this could theoretically grow indefinitely. For the async trainer use case with infrequent snapshots, this is likely acceptable, but it may be worth documenting in the docstring.
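If bounding it ever becomes desirable, a small sketch of the backpressure variant (the cap and payload shape are made up):

import multiprocessing as mp
from queue import Full

MAX_PENDING_SNAPSHOTS = 4  # hypothetical cap
snapshot_queue = mp.Queue(maxsize=MAX_PENDING_SNAPSHOTS)

try:
    # Producer (training process): a full queue surfaces a stalled upload worker
    # instead of letting messages accumulate without bound.
    snapshot_queue.put({"path": "/tmp/snapshot", "window": 123}, timeout=30)
except Full:
    # Drop, retry, or skip this snapshot depending on policy.
    pass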

grail/infrastructure/checkpoint_consumer.py (1)

600-624: Deprecated function references non-existent method.

The deprecated function _find_latest_ready_checkpoint_window_DEPRECATED calls checkpoint_manager._is_checkpoint_ready(window) on line 619, but this method doesn't exist in the current CheckpointManager class. If someone tries to use this preserved reference, it will fail.

Consider removing the deprecated function entirely or updating it to use _get_checkpoint_ready_window:

-    for window in sorted(windows, reverse=True):
-        if await checkpoint_manager._is_checkpoint_ready(window):
+    for window in sorted(windows, reverse=True):
+        if await checkpoint_manager._get_checkpoint_ready_window(window) is not None:
             logger.info(f"Found latest ready checkpoint at window {window}")
grail/trainer/upload_worker.py (2)

37-65: Potential issue with asyncio.get_event_loop() deprecation.

Using asyncio.get_event_loop() is deprecated in Python 3.10+ and may raise DeprecationWarning or fail in some contexts. Consider using asyncio.get_running_loop() instead since this is called within an async function.

-    loop = asyncio.get_event_loop()
+    loop = asyncio.get_running_loop()

67-100: Import statement inside function reduces readability.

The import os on line 88 should be moved to the top of the file with other imports for consistency and to avoid repeated import overhead if this function is called multiple times.

grail/trainer/config.py (1)

129-133: Consider removing or deprecating legacy sglang-specific fields.

Lines 130-133 still contain sglang_* prefixed fields (sglang_mem_fraction_static, sglang_context_length, etc.) while the main server config was renamed to server_*. This inconsistency could cause confusion about which fields apply to which backend.

grail/trainer/checkpoint_publisher.py (3)

269-299: Consider adding total upload timeout in addition to per-file timeout.

The current implementation has per-file timeouts via upload_timeout, but if there are many files, the total upload time could be unbounded. Consider adding a total operation timeout to prevent indefinite hangs.
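One way to cap the whole operation, sketched with asyncio.wait_for (the constant name and value are assumptions, not part of the PR):

import asyncio

TOTAL_UPLOAD_TIMEOUT_S = 1800  # hypothetical cap on the entire checkpoint upload


async def upload_all_with_deadline(upload_coros) -> bool:
    """Bound the whole upload in addition to the existing per-file timeouts."""
    try:
        await asyncio.wait_for(asyncio.gather(*upload_coros), timeout=TOTAL_UPLOAD_TIMEOUT_S)
        return True
    except asyncio.TimeoutError:
        return False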


386-528: upload_from_staging duplicates upload logic from publish_checkpoint.

The upload_file closure and upload logic in upload_from_staging (lines 467-496) is nearly identical to the one in publish_checkpoint (lines 269-299). Consider extracting to a shared helper to reduce duplication and maintenance burden.

Consider extracting the common upload logic:

async def _upload_files_from_directory(
    self,
    source_dir: Path,
    remote_prefix: str,
    exclude_files: set[str] | None = None,
) -> tuple[bool, float, float]:
    """Upload all files from a directory with retry and timing.
    
    Returns:
        Tuple of (success, total_mb, duration_seconds)
    """
    semaphore = asyncio.Semaphore(4)
    # ... shared upload logic ...

431-433: Fallback for parent_window could be off-by-one if windows aren't contiguous.

The fallback parent_window = max(0, target_window - WINDOW_LENGTH) assumes sequential windows. If windows can be skipped (e.g., due to downtime), this could produce incorrect lineage. Consider documenting this assumption.

grail/trainer/training_process.py (5)

23-25: Module-level side effect may cause issues in multi-process environments.

Setting PYTORCH_CUDA_ALLOC_CONF at import time could affect other processes that import this module but don't want this configuration. Consider moving this to the process entry point (run_training_process).

-if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
-    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

Move to run_training_process:

def run_training_process(...):
    if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    # ... rest of function

210-210: WAIT_WINDOW_LENGTH = 0.0 makes the wait loop a no-op.

With WAIT_WINDOW_LENGTH = 0.0, the condition current_block >= target_block is immediately true since target_block = current_block + 0.0. The TODO comment suggests this is intentional for now, but the loop structure is misleading.

Consider simplifying if no wait is needed:

# TODO: Add waiting when using extra process/GPUs
# Currently no wait - miners should already have the checkpoint
logger.info("Skipping initial checkpoint wait (WAIT_WINDOW_LENGTH=0)")

306-408: Monitoring initialization is overly complex with duplicate logging.

The _initialize_monitoring method has multiple layers of try/catch with verbose logging. While thorough, this makes the code harder to follow. Consider simplifying the success path and consolidating error handling.


317-320: Setting environment variables during initialization has side effects.

Lines 319-320 modify os.environ which affects the entire process and any libraries that have already read these values. Consider setting these before the process starts.


770-784: Fallback window calculation may hide subtensor issues.

When get_current_block fails (line 782), the fallback self.epoch_counter * WINDOW_LENGTH provides a pseudo-window that's decoupled from actual chain state. This could cause training to proceed with incorrect window context. Consider logging at WARNING level instead of just returning the fallback.

         except Exception as exc:
-            logger.warning("Failed to get current block: %s", exc)
+            logger.warning(
+                "Failed to get current block (using fallback epoch-based window): %s", exc
+            )
             return self.epoch_counter * WINDOW_LENGTH
grail/neurons/trainer.py (4)

116-118: Potential issue: multiprocessing.set_start_method called in constructor.

Calling set_start_method("spawn", force=True) in the constructor can cause issues if:

  1. Other parts of the codebase have already set a different method
  2. Multiple TrainerNeuron instances are created (though unlikely)

Consider moving this to module-level or the CLI entrypoint where it's called once at startup.

-    def __init__(self, context: TrainerContext) -> None:
-        """Initialize trainer neuron.
-
-        Args:
-            context: Resources required for training
-        """
-        super().__init__()
-        self._context = context
-
-        # Set multiprocessing start method to 'spawn' for CUDA compatibility
-        multiprocessing.set_start_method("spawn", force=True)
-        logger.info("Set multiprocessing start method to 'spawn' for CUDA compatibility")
+    def __init__(self, context: TrainerContext) -> None:
+        """Initialize trainer neuron.
+
+        Args:
+            context: Resources required for training
+        """
+        super().__init__()
+        self._context = context
+
+        # Ensure 'spawn' method for CUDA compatibility (safe to call multiple times with force=True)
+        try:
+            multiprocessing.set_start_method("spawn", force=True)
+            logger.info("Set multiprocessing start method to 'spawn' for CUDA compatibility")
+        except RuntimeError:
+            # Already set - this is fine
+            logger.debug("Multiprocessing start method already set")

419-420: Minor: Use asyncio.get_running_loop() instead of deprecated get_event_loop().

asyncio.get_event_loop() is deprecated in Python 3.10+ when called from a coroutine. Use asyncio.get_running_loop() instead.

-        # Use Event.wait() with timeout in thread pool to avoid blocking event loop
-        loop = asyncio.get_event_loop()
+        # Use Event.wait() with timeout in thread pool to avoid blocking event loop
+        loop = asyncio.get_running_loop()

831-835: Fragile: Backend type inference from class name.

Inferring backend_type by checking if "WandB" or "Null" is in the class name is fragile. Consider adding a backend_type property to the backend interface.

-            # Add backend_type from backend class name (needed by subprocess)
-            backend_class_name = self._context.monitor.backend.__class__.__name__
-            if "WandB" in backend_class_name:
-                monitor_config["backend_type"] = "wandb"
-            elif "Null" in backend_class_name:
-                monitor_config["backend_type"] = "null"
+            # Add backend_type from backend (needed by subprocess)
+            if hasattr(self._context.monitor.backend, "backend_type"):
+                monitor_config["backend_type"] = self._context.monitor.backend.backend_type
+            else:
+                # Fallback: infer from class name
+                backend_class_name = self._context.monitor.backend.__class__.__name__
+                if "WandB" in backend_class_name:
+                    monitor_config["backend_type"] = "wandb"
+                elif "Null" in backend_class_name:
+                    monitor_config["backend_type"] = "null"

842-844: Minor: Accessing private _config attribute.

The fallback accesses self._context.monitor._config, a private attribute. Consider exposing a public method on the monitor manager for this use case.
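For example, a small public accessor on the manager would avoid reaching into privates (the class and property names here are illustrative):

class MonitoringManagerSketch:
    """Illustrative only: expose the backend config for subprocess wiring."""

    def __init__(self, config: dict) -> None:
        self._config = config

    @property
    def backend_config(self) -> dict:
        # Return a copy so callers cannot mutate internal state.
        return dict(self._config)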

grail/monitoring/backends/wandb_backend.py (3)

95-97: Use asyncio.get_running_loop() instead of deprecated get_event_loop().

asyncio.get_event_loop() is deprecated when called from a coroutine. Since _run_executor is called from async methods, use asyncio.get_running_loop().

     def _run_executor(self, func: Any, *args: Any) -> Any:
         """Run synchronous function in executor."""
-        return asyncio.get_event_loop().run_in_executor(None, func, *args)
+        return asyncio.get_running_loop().run_in_executor(None, func, *args)

323-325: Consider: 256 define_metric calls at initialization.

Defining metrics for all 256 possible UIDs upfront may add initialization overhead. An alternative is lazy definition on first log via _maybe_define_step_for_name. However, since this runs once per run, the current approach is acceptable if init time is not critical.
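If init time ever matters, lazy definition is easy to cache; a sketch using the public wandb Run API (the helper name mirrors the one mentioned above, while the set-based caching is an assumption):

_defined_metrics: set[str] = set()


def _maybe_define_step_for_name(run, name: str, step_metric: str = "trainer/step") -> None:
    """Define the step metric for a series only the first time it is logged."""
    if name in _defined_metrics:
        return
    run.define_metric(name, step_metric=step_metric)
    _defined_metrics.add(name)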


476-485: Edge case: String coercion doesn't handle negative numbers or scientific notation.

isdigit() returns False for "-5" or "1e10". The fallback to float() handles some cases, but negative integers will become floats.

     def _coerce_tag_value(self, raw: Any) -> Any:
         """Coerce tag value from string to appropriate type."""
         if isinstance(raw, str):
-            if raw.isdigit():
-                return int(raw)
             try:
-                return float(raw)
+                # Try int first (handles negative), then float
+                float_val = float(raw)
+                if float_val.is_integer():
+                    return int(float_val)
+                return float_val
             except Exception:
                 return raw
         return raw
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1efa2ba and bd212b5.

📒 Files selected for processing (39)
  • .env.example (1 hunks)
  • grail/cli/mine.py (6 hunks)
  • grail/cli/train.py (4 hunks)
  • grail/environments/loop.py (3 hunks)
  • grail/environments/math_hendrycks_env.py (3 hunks)
  • grail/infrastructure/checkpoint_consumer.py (7 hunks)
  • grail/infrastructure/network.py (4 hunks)
  • grail/logging_utils.py (3 hunks)
  • grail/model/provider.py (1 hunks)
  • grail/model/train_loading.py (2 hunks)
  • grail/monitoring/backends/wandb_backend.py (13 hunks)
  • grail/monitoring/config.py (4 hunks)
  • grail/neurons/miner.py (4 hunks)
  • grail/neurons/trainer.py (2 hunks)
  • grail/neurons/validator.py (1 hunks)
  • grail/shared/chat_templates.py (3 hunks)
  • grail/shared/constants.py (4 hunks)
  • grail/trainer/__init__.py (1 hunks)
  • grail/trainer/algorithms/grpo.py (12 hunks)
  • grail/trainer/checkpoint_publisher.py (1 hunks)
  • grail/trainer/checkpointing.py (0 hunks)
  • grail/trainer/config.py (3 hunks)
  • grail/trainer/evaluator.py (1 hunks)
  • grail/trainer/inference_server.py (1 hunks)
  • grail/trainer/ipc.py (1 hunks)
  • grail/trainer/replay_buffer.py (1 hunks)
  • grail/trainer/service.py (4 hunks)
  • grail/trainer/snapshot_manager.py (1 hunks)
  • grail/trainer/training_process.py (1 hunks)
  • grail/trainer/upload_worker.py (1 hunks)
  • grail/validation/miner_validator.py (4 hunks)
  • grail/validation/service.py (4 hunks)
  • research/offline_trainer/tests/test_gpu_integration.py (2 hunks)
  • scripts/test_monitoring_shared.py (1 hunks)
  • scripts/test_wandb_shared.py (1 hunks)
  • tests/integration/test_async_trainer.py (1 hunks)
  • tests/integration/test_train_cli_smoke.py (1 hunks)
  • tests/integration/test_validation_service.py (1 hunks)
  • tests/unit/trainer/test_replay_buffer.py (1 hunks)
💤 Files with no reviewable changes (1)
  • grail/trainer/checkpointing.py
🧰 Additional context used
🧬 Code graph analysis (17)
grail/model/provider.py (1)
tests/integration/protocol/test_proof_cross_framework.py (1)
  • model (39-43)
tests/unit/trainer/test_replay_buffer.py (2)
grail/trainer/algorithms/grpo.py (1)
  • GRPOGroup (243-256)
grail/trainer/replay_buffer.py (2)
  • RecencyWeightedBuffer (66-250)
  • create_replay_buffer (253-280)
grail/trainer/ipc.py (2)
grail/infrastructure/chain.py (1)
  • stop (415-431)
grail/neurons/base.py (1)
  • heartbeat (123-125)
tests/integration/test_validation_service.py (1)
grail/infrastructure/checkpoint_consumer.py (1)
  • CheckpointManager (87-562)
grail/validation/miner_validator.py (1)
tests/integration/protocol/test_proof_cross_framework.py (1)
  • model (39-43)
scripts/test_wandb_shared.py (1)
scripts/test_monitoring_shared.py (1)
  • main (299-331)
grail/infrastructure/network.py (1)
tests/integration/infrastructure/test_block_timestamp.py (1)
  • subtensor (19-28)
grail/cli/train.py (3)
grail/trainer/checkpoint_publisher.py (1)
  • CheckpointPublisher (87-528)
grail/neurons/trainer.py (1)
  • TrainerContext (67-88)
grail/model/train_loading.py (3)
  • ModelLoadSpec (30-41)
  • parse_train_env (56-81)
  • parse_ref_env (84-109)
grail/model/train_loading.py (1)
grail/model/provider.py (1)
  • get_model (55-150)
tests/integration/test_async_trainer.py (2)
grail/trainer/snapshot_manager.py (10)
  • save_snapshot_atomic (61-129)
  • check_snapshot_ready (131-137)
  • copy_snapshot_to_staging (139-169)
  • cleanup_staging (171-176)
  • check_pause_flag (178-184)
  • set_pause_flag (186-189)
  • clear_pause_flag (191-195)
  • get_training_heartbeat_age (201-214)
  • set_training_heartbeat (197-199)
  • get_latest_snapshot_path (216-223)
tests/integration/test_train_cli_smoke.py (1)
  • save_pretrained (60-63)
grail/validation/service.py (1)
grail/infrastructure/checkpoint_consumer.py (1)
  • get_latest_ready_checkpoint (498-562)
grail/monitoring/backends/wandb_backend.py (2)
grail/monitoring/base.py (3)
  • MetricData (29-43)
  • log_metric (73-83)
  • MetricType (19-25)
grail/monitoring/backends/null_backend.py (1)
  • log_metric (34-40)
grail/trainer/config.py (1)
grail/trainer/inference_server.py (1)
  • start_server (70-72)
grail/trainer/replay_buffer.py (1)
grail/trainer/algorithms/grpo.py (1)
  • GRPOGroup (243-256)
grail/neurons/miner.py (1)
grail/infrastructure/checkpoint_consumer.py (2)
  • default_checkpoint_cache_root (570-574)
  • get_latest_ready_checkpoint (498-562)
grail/trainer/snapshot_manager.py (2)
tests/integration/protocol/test_proof_cross_framework.py (2)
  • model (39-43)
  • tokenizer (47-52)
tests/integration/test_train_cli_smoke.py (1)
  • save_pretrained (60-63)
tests/integration/test_train_cli_smoke.py (4)
grail/cli/train.py (1)
  • train (40-124)
grail/trainer/snapshot_manager.py (7)
  • save_snapshot_atomic (61-129)
  • set_training_heartbeat (197-199)
  • check_snapshot_ready (131-137)
  • copy_snapshot_to_staging (139-169)
  • cleanup_staging (171-176)
  • get_training_heartbeat_age (201-214)
  • get_latest_snapshot_path (216-223)
grail/infrastructure/chain.py (1)
  • stop (415-431)
tests/integration/infrastructure/test_block_timestamp.py (1)
  • chain_manager (32-65)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Run Tests (3.10)
  • GitHub Check: Cursor Bugbot
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Run Tests (3.10)

git_commit: str = "unknown"
created_at: float = 0.0
model_name: str = "no_name"
parent_window: int | None = None

⚠️ Potential issue | 🟡 Minor

parent_window added but not populated during metadata fetch.

The parent_window field is added to CheckpointMetadata, but in _fetch_metadata() (lines 382-391), it's not read from the payload, leaving it as None even when present in the remote metadata. This inconsistency could cause issues when consumers need to track checkpoint lineage.

         metadata = CheckpointMetadata(
             window=payload.get("window", window),
             file_manifest=payload.get("file_manifest", {}),
             training_config=payload.get("training_config", {}),
             git_commit=payload.get("git_commit", "unknown"),
             created_at=float(payload.get("created_at", 0.0)),
             model_name=payload.get("model_name", "no_name"),
+            parent_window=payload.get("parent_window"),
         )
🤖 Prompt for AI Agents
In grail/infrastructure/checkpoint_consumer.py around line 72 (where
parent_window: int | None = None is declared) and in the metadata fetch logic at
lines ~382-391, populate the new parent_window field from the fetched payload:
read payload.get("parent_window") (or the equivalent key used by the remote
metadata), validate/convert it to int if present, and pass that value into the
CheckpointMetadata constructor so it is not left as None when present remotely;
also handle missing/null values by leaving parent_window as None and add a brief
defensive check to avoid ValueError on conversion.

Comment on lines +62 to +63
# Track last successful call for idle detection
object.__setattr__(self, "_last_call_timestamp", time.time())

⚠️ Potential issue | 🟡 Minor

Clarify idle tracking semantics and align comments/logs with behavior

The idle‑tracking itself (_last_call_timestamp on success, checked in _call_with_retry) looks reasonable, but the surrounding comments and log message are now inconsistent with the implementation:

  • The comment at Line 230 talks about “Double timeout if connection has been idle for 20+ seconds”, but:
    • The code restarts the subtensor when idle_duration > 60.0.
    • The code does not modify timeout at all in this block (only metagraph calls double timeout earlier).
  • The log message still says “restarting subtensor and doubling timeout for %s”, which is misleading since the timeout isn’t changed here.

I’d recommend either (a) actually doubling the timeout here, or (b) updating the comment and log text to match the current “restart after 60s idle” behavior. If the intent is just to restart after 60s and keep timeouts unchanged, something like this would keep the implementation and docs in sync:

-        # Double timeout if connection has been idle for 20+ seconds
-        # Research-based threshold:
-        # - Bittensor WebSocket auto-closes after 10s inactivity
-        # - Substrate layer closes after ~60s inactivity
-        # - 20s catches stale connections early without false positives
-        # - Critical for upload worker (40-300s idle during R2 uploads)
+        # Restart connection if it has been idle for an extended period.
+        # Research-based threshold:
+        # - Bittensor WebSocket auto-closes after ~10s inactivity
+        # - Substrate layer closes after ~60s inactivity
+        # - We restart after 60s of no successful calls to avoid stale connections
+        #   while avoiding excessive churn during normal gaps (e.g. long R2 uploads).
@@
-        if idle_duration > 60.0:
-            logger.warning(
-                "⏰ Connection idle for %.1fs, restarting subtensor and doubling timeout for %s",
-                idle_duration,
-                method_name,
-            )
+        if idle_duration > 60.0:
+            logger.warning(
+                "⏰ Connection idle for %.1fs, restarting subtensor before calling %s",
+                idle_duration,
+                method_name,
+            )

If you do intend to also relax timeouts for long‑idle connections, you could additionally add timeout *= 2 inside this block (or some capped multiplier) and adjust the wording accordingly.

Also applies to: 168-169, 230-245

🤖 Prompt for AI Agents
In grail/infrastructure/network.py around lines 62-63 (and also review lines
168-169 and 230-245), the comments and log messages claim the code "doubles
timeout" for long-idle connections but the implementation only restarts the
subtensor when idle_duration > 60.0 and does not change timeout; update the code
to make behavior and docs consistent by either (A) actually doubling/relaxing
the timeout inside the idle branch (e.g., timeout *= 2 with an optional cap) and
update comments/logs to mention both restart and timeout adjustment, or (B)
leave timeouts unchanged and revise the comment and log messages to remove any
mention of doubling timeout and instead state that the subtensor is being
restarted after 60s of idle; apply the chosen fix consistently to all affected
comment/log locations (lines mentioned above).

Comment on lines +802 to +812
def _serialize_wallet(self) -> dict[str, str]:
"""Serialize wallet for pickling to child process.
Returns:
Dictionary of wallet arguments
"""
logger.info(
"Reloading training models: train_model=%s ref_model=%s",
self._context.train_model is not None,
self._context.ref_model is not None,
)
return {
"name": self._context.wallet.name,
"hotkey": self._context.wallet.hotkey_str,
"path": self._context.wallet.path,
}

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's find wallet-related imports and class definitions
rg -n "class.*[Ww]allet" --type py | head -20

Repository: one-covenant/grail

Length of output: 187


🏁 Script executed:

# Search for hotkey_str usage across the codebase
rg -n "hotkey_str" --type py -B 2 -A 2

Repository: one-covenant/grail

Length of output: 815


🏁 Script executed:

# Look for _FakeWallet or test wallet implementations
rg -n "_FakeWallet|FakeWallet" --type py -B 2 -A 5

Repository: one-covenant/grail

Length of output: 1430


🏁 Script executed:

# Search for bittensor wallet imports to understand the wallet object
rg -n "from bittensor|import.*wallet|bt\.wallet" --type py | head -20

Repository: one-covenant/grail

Length of output: 1629


🏁 Script executed:

# Search for wallet attribute access patterns in production code
rg -n "wallet\.(hotkey|name|path)" --type py -B 2 -A 2 | grep -v test | head -40

Repository: one-covenant/grail

Length of output: 2696


🏁 Script executed:

# Look at the _reconstruct_wallet function more closely
cat -n grail/trainer/training_process.py | sed -n '1045,1060p'

Repository: one-covenant/grail

Length of output: 545


🏁 Script executed:

# Check where wallet_args dict is created and passed to _reconstruct_wallet
rg -n "_reconstruct_wallet|_serialize_wallet" --type py -B 3 -A 3

Repository: one-covenant/grail

Length of output: 2478


🏁 Script executed:

# Search for bittensor wallet source or any documentation about its attributes
rg -n "\.hotkey_str|\.hotkey\b" --type py | grep -v test | head -20

Repository: one-covenant/grail

Length of output: 2014


🏁 Script executed:

# Search for wallet property definitions or any string representations
rg -n "hotkey" grail/trainer/training_process.py -B 2 -A 2

Repository: one-covenant/grail

Length of output: 3126


🏁 Script executed:

# Look at actual wallet initialization to understand what parameters bt.wallet accepts
rg -n "bt\.wallet\(" --type py -B 1 -A 1 | head -40

Repository: one-covenant/grail

Length of output: 1374


🏁 Script executed:

# Check if there's a hotkey_str property anywhere in the codebase
rg -n "hotkey_str" --type py

Repository: one-covenant/grail

Length of output: 223


🏁 Script executed:

# Look at the upload_worker that also uses wallet reconstruction
cat -n grail/trainer/upload_worker.py | sed -n '290,310p'

Repository: one-covenant/grail

Length of output: 954


🏁 Script executed:

# Check if there's any documentation or comments about wallet hotkey parameter
rg -n "hotkey.*=.*get_conf|BT_WALLET_HOT" --type py -B 2 -A 5

Repository: one-covenant/grail

Length of output: 4223


🏁 Script executed:

# Look at test to understand what _FakeWallet expects as hotkey parameter
cat -n tests/integration/test_train_cli_smoke.py | sed -n '29,50p'

Repository: one-covenant/grail

Length of output: 886


🏁 Script executed:

# Search for any property or method that gives hotkey string representation
rg -n "def.*hotkey|@property.*hotkey" --type py

Repository: one-covenant/grail

Length of output: 1628


🏁 Script executed:

# Search for any wallet attributes that might store the original hotkey name/string
rg -n "wallet\." grail/neurons/trainer.py | head -20

Repository: one-covenant/grail

Length of output: 222


🏁 Script executed:

# Check if wallet has any other attributes we should use
rg -n "self\._context\.wallet\." grail/neurons/trainer.py -A 1

Repository: one-covenant/grail

Length of output: 236


🏁 Script executed:

# Look for wallet name attribute to understand the pattern
rg -n "wallet\.name" --type py | grep -v test | head -10

Repository: one-covenant/grail

Length of output: 804


🏁 Script executed:

# Search for any comments or documentation about wallet structure
rg -n "wallet.*hotkey|hotkey.*wallet" grail/neurons/trainer.py -B 3 -A 3

Repository: one-covenant/grail

Length of output: 278


🏁 Script executed:

# Check the context object to understand where wallet comes from
rg -n "_context.*wallet|wallet.*context" grail/neurons/trainer.py -B 2 -A 2

Repository: one-covenant/grail

Length of output: 464


🏁 Script executed:

# Look at the _FakeWallet definition once more to see exactly what it mimics
cat -n tests/integration/test_train_cli_smoke.py | sed -n '29,40p'

Repository: one-covenant/grail

Length of output: 508


🏁 Script executed:

# Search for any real wallet initialization or attribute documentation
rg -n "class.*Context|context.*wallet" --type py grail/neurons/trainer.py | head -10

Repository: one-covenant/grail

Length of output: 289


🏁 Script executed:

# Find and examine TrainerContext class definition
rg -n "class TrainerContext" --type py -A 20

Repository: one-covenant/grail

Length of output: 1426


Critical bug: wallet.hotkey_str attribute does not exist on bittensor wallet objects.

Line 810 accesses self._context.wallet.hotkey_str, which is not a valid attribute on real bt.wallet instances. The test suite only passes because _FakeWallet (in tests) manually defines this attribute, but production code will crash with AttributeError at runtime.

Real bittensor wallets have wallet.hotkey (a Keypair object), wallet.name, and wallet.path, but not hotkey_str. To fix the serialization, determine what value should be passed as the hotkey parameter when reconstructing the wallet with bt.wallet(**wallet_args). The current implementation cannot work in production.

🤖 Prompt for AI Agents
In grail/neurons/trainer.py around lines 802–812, the serializer reads a
non-existent attribute wallet.hotkey_str; change the serialized key to "hotkey"
and produce a string representation that bt.wallet(...) accepts. Replace the
hotkey value with a string derived from the Keypair, e.g. use
wallet.hotkey.ss58_address if present, otherwise fall back to str(wallet.hotkey)
or wallet.hotkey.public_key.hex(); ensure the returned dict keys are "name",
"hotkey", and "path" so bt.wallet(**wallet_args) can reconstruct the wallet
without AttributeError.



def build_qwen_chat_template(system_prompt: str, reasoning_start: str = "") -> str:
def build_qwen_chat_template(system_prompt: str) -> str:

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

rg -n --type=py 'build_qwen_chat_template\s*\(' -A 2 -B 2

Repository: one-covenant/grail

Length of output: 4202


Update two remaining call sites that still pass the reasoning_start parameter.

The function signature change removes the reasoning_start parameter, but two files still pass this argument and will fail:

  • research/trl/train_trl_gsm8k.py:90-91 — Remove reasoning_start=REASONING_START
  • research/offline_trainer/src/grail_offline/pipelines/offline_grpo.py:312 — Change to pass only SYSTEM_PROMPT

Seven other call sites have been correctly updated and only pass the system_prompt argument.

🤖 Prompt for AI Agents
In grail/shared/chat_templates.py around line 9, the function
build_qwen_chat_template has removed the reasoning_start parameter; update the
two remaining call sites: in research/trl/train_trl_gsm8k.py at lines ~90-91
remove the named argument reasoning_start=REASONING_START so the call passes
only SYSTEM_PROMPT (or the appropriate system prompt variable), and in
research/offline_trainer/src/grail_offline/pipelines/offline_grpo.py at line
~312 change the call to pass only SYSTEM_PROMPT (drop the reasoning_start
argument) so both calls match the new single-argument signature.

Comment on lines +1650 to +1651
micro_batch_size = max(1, int(config.batch_size))
grad_accum_steps = max(1, self.config.grad_accum_steps)

⚠️ Potential issue | 🟡 Minor

Inconsistent config source: config vs self.config.

Line 1650 uses the config parameter while line 1651 uses self.config. If the passed config should take precedence (as it's explicitly passed to train_epoch), consider using config.grad_accum_steps for consistency:

 micro_batch_size = max(1, int(config.batch_size))
-grad_accum_steps = max(1, self.config.grad_accum_steps)
+grad_accum_steps = max(1, int(config.grad_accum_steps))

Alternatively, if self.config is intentionally used for certain parameters, add a comment explaining the distinction.

🤖 Prompt for AI Agents
In grail/trainer/algorithms/grpo.py around lines 1650-1651, the code mixes the
local parameter config with self.config (micro_batch_size uses config while
grad_accum_steps uses self.config); make the source consistent by using
config.grad_accum_steps instead of self.config.grad_accum_steps (or, if
self.config is intentionally authoritative, add a clear comment explaining why
and use self.config for both), and ensure a safe fallback (e.g., max(1,
int(config.grad_accum_steps))) so the value is validated like micro_batch_size.

Comment on lines +108 to +114
# Atomic rename: remove old target if exists, then rename
if target_dir.exists():
# Remove old snapshot
shutil.rmtree(target_dir)

# Rename temp to target (atomic on POSIX)
temp_dir.rename(target_dir)

⚠️ Potential issue | 🟠 Major

Race condition: rmtree + rename is not atomic.

If a crash occurs between shutil.rmtree(target_dir) at line 111 and temp_dir.rename(target_dir) at line 114, the snapshot is lost. Consider a safer pattern:

  1. Rename latest → latest.old
  2. Rename temp → latest
  3. Delete latest.old

This ensures at least one valid snapshot always exists.

             # Atomic rename: remove old target if exists, then rename
             if target_dir.exists():
-                # Remove old snapshot
-                shutil.rmtree(target_dir)
-
-            # Rename temp to target (atomic on POSIX)
-            temp_dir.rename(target_dir)
+                # Rename old snapshot to backup first
+                backup_dir = self.snapshot_dir / "latest.old"
+                if backup_dir.exists():
+                    shutil.rmtree(backup_dir)
+                target_dir.rename(backup_dir)
+
+            # Rename temp to target (atomic on POSIX)
+            temp_dir.rename(target_dir)
+
+            # Clean up backup after successful rename
+            backup_dir = self.snapshot_dir / "latest.old"
+            if backup_dir.exists():
+                try:
+                    shutil.rmtree(backup_dir)
+                except Exception:
+                    pass  # Best effort cleanup

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In grail/trainer/snapshot_manager.py around lines 108-114, the current sequence
removes target_dir with shutil.rmtree() then renames temp_dir → target_dir which
can lose the snapshot if the process crashes between those steps; instead
implement the safe three-step swap: rename the existing target_dir to a backup
name (e.g., target_dir + ".old" or a unique suffix) using an atomic
rename/replace, then rename temp_dir to target_dir, and finally remove the
backup; ensure you handle pre-existing backup (remove or choose unique name),
catch and propagate errors so failures during the swap leave either the old or
the new snapshot intact, and clean up temp_dir or backup appropriately in error
paths.
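
A self-contained sketch of the three-step swap described above, using only pathlib and shutil (the backup naming is illustrative; the real SnapshotManager would hold these paths as attributes):

import shutil
from pathlib import Path


def swap_snapshot(temp_dir: Path, target_dir: Path) -> None:
    """Swap a freshly written snapshot into place so that either the old or the
    new snapshot survives a crash at any point during the swap."""
    backup_dir = target_dir.with_name(target_dir.name + ".old")
    if target_dir.exists():
        if backup_dir.exists():
            shutil.rmtree(backup_dir)      # clear a stale backup from a previous crash
        target_dir.rename(backup_dir)      # step 1: current -> backup (atomic on POSIX)
    temp_dir.rename(target_dir)            # step 2: temp -> current (atomic on POSIX)
    if backup_dir.exists():
        shutil.rmtree(backup_dir, ignore_errors=True)  # step 3: best-effort cleanup

A startup pass that prefers latest over latest.old, and deletes whichever is stale, completes the recovery story.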

Comment on lines +176 to +196
# Set READY-{ready_window} marker
try:
finalized = await checkpoint_publisher.finalize_checkpoint_ready(
checkpoint_window,
ready_window,
)
if finalized:
logger.info(
"✅ Set READY-%s marker for checkpoint-%s",
ready_window,
checkpoint_window,
)
except Exception as exc:
logger.error("Failed to finalize checkpoint READY marker: %s", exc)

# Update last uploaded window to prevent duplicates
last_uploaded_window = checkpoint_window

# Cleanup staging directory
snapshot_manager.cleanup_staging()
logger.info("Upload cycle complete for checkpoint-%s", checkpoint_window)

⚠️ Potential issue | 🟡 Minor

READY marker failure doesn't prevent last_uploaded_window update.

If finalize_checkpoint_ready fails (line 189), the upload is still considered complete and last_uploaded_window is updated (line 192). This could leave orphaned checkpoints without READY markers that consumers won't discover.

Consider only updating last_uploaded_window when the READY marker is successfully set:

             # Set READY-{ready_window} marker
             try:
                 finalized = await checkpoint_publisher.finalize_checkpoint_ready(
                     checkpoint_window,
                     ready_window,
                 )
                 if finalized:
                     logger.info(
                         "✅ Set READY-%s marker for checkpoint-%s",
                         ready_window,
                         checkpoint_window,
                     )
+                    # Only mark as uploaded if READY marker was set successfully
+                    last_uploaded_window = checkpoint_window
+                else:
+                    logger.warning("READY marker not finalized, will retry upload next cycle")
             except Exception as exc:
                 logger.error("Failed to finalize checkpoint READY marker: %s", exc)
+                # Don't update last_uploaded_window so retry is possible

-            # Update last uploaded window to prevent duplicates
-            last_uploaded_window = checkpoint_window
🤖 Prompt for AI Agents
In grail/trainer/upload_worker.py around lines 176 to 196, the code updates
last_uploaded_window and treats the upload as complete even if
finalize_checkpoint_ready() fails; change the flow so last_uploaded_window is
only set when the READY marker was successfully created (i.e., finalized is
True). Specifically, move the assignment of last_uploaded_window into the try
block under the condition that finalized is True, and ensure that on exception
you do not update last_uploaded_window (but still perform necessary
cleanup/logging as appropriate) so orphaned checkpoints won't be considered
uploaded.

Comment on lines +236 to +257
# Step 1: Start primary process
logger.info("\n[STEP 1] Starting primary process...")
run_id = await primary_process_workflow(entity=entity)

logger.info("\n[STEP 2] Spawning worker process to connect to run...")

# Step 2: Spawn worker process
process = multiprocessing.Process(
target=worker_process_workflow,
args=(run_id, entity, "training_worker"),
)
process.start()
process.join(timeout=30)

if process.is_alive():
logger.error("❌ Worker process timeout!")
process.terminate()
process.join()
elif process.exitcode != 0:
logger.error(f"❌ Worker process failed with exit code {process.exitcode}")
else:
logger.info("✅ Worker process completed successfully")

⚠️ Potential issue | 🟠 Major

The test sequence has a timing issue: primary finishes before worker starts.

In test_shared_mode_new, the primary process workflow at line 238 runs to completion (including finish_run and shutdown on lines 107-108) before the worker is spawned at line 243. This means the worker attempts to connect to a run that has already been finished.

For a valid shared-mode test, the primary should remain active while the worker connects.

Consider restructuring to keep primary alive:

 async def test_shared_mode_new(entity: str | None = None) -> None:
     ...
     # Step 1: Start primary process
     logger.info("\n[STEP 1] Starting primary process...")
-    run_id = await primary_process_workflow(entity=entity)
-
-    logger.info("\n[STEP 2] Spawning worker process to connect to run...")
-
-    # Step 2: Spawn worker process
-    process = multiprocessing.Process(
+    # Start primary in background, get run_id, then spawn worker while primary is alive
+    # This requires restructuring primary_process_workflow to not immediately shutdown

Committable suggestion skipped: line range outside the PR's diff.

@cursor cursor bot left a comment

This PR is being reviewed by Cursor Bugbot


# Stage 5: Rank by combined efficiency score and select top-k groups
max_groups: int = int(getattr(config, "grpo_max_groups", 8))
max_groups: int = int(getattr(config, "grpo_max_groups_per_window", 10000))

Bug: Config attribute name mismatch disables group filtering

The code uses getattr(config, "grpo_max_groups_per_window", 10000) but TrainingConfig defines the attribute as grpo_max_groups (not grpo_max_groups_per_window). The attribute name mismatch causes getattr to always return the fallback value of 10000, effectively disabling the configured GRPO_MAX_GROUPS limit (default 32) and allowing far more groups than intended in training.
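
Illustrative one-line fix, assuming the TrainingConfig field is named grpo_max_groups as Bugbot states (the dataclass below is a stand-in, not the real config):

from dataclasses import dataclass


@dataclass
class _Config:  # stand-in for TrainingConfig
    grpo_max_groups: int = 32


config = _Config()

# Reading the attribute that actually exists keeps the configured cap effective
max_groups: int = int(getattr(config, "grpo_max_groups", 8))
assert max_groups == 32  # not the 10000 fallback produced by the misspelled name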


timers,
monitor,
self.use_drand,
checkpoint_window,

Bug: Miner passes None checkpoint_window to function expecting int

When get_latest_ready_checkpoint() returns None but a model is already loaded from a previous iteration, the code continues past the checkpoint loading block and calls generate_rollouts_for_window passing checkpoint_window which is None. The function signature expects checkpoint_window: int, causing a type mismatch. The code should likely use current_checkpoint_window (which stores the previously loaded checkpoint window) instead of the local checkpoint_window variable.
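
A hedged sketch of the fallback Bugbot suggests, with illustrative names (checkpoint_window comes from the current discovery attempt, current_checkpoint_window from the previously loaded checkpoint):

def resolve_rollout_window(
    checkpoint_window: int | None,
    current_checkpoint_window: int | None,
) -> int | None:
    """Return the window to pass to generate_rollouts_for_window, or None when no
    checkpoint has ever been loaded, in which case the miner should skip the
    iteration rather than pass None into an int-typed parameter."""
    if checkpoint_window is not None:
        return checkpoint_window
    return current_checkpoint_window

The miner loop would then call generate_rollouts_for_window only when this returns an int, logging and sleeping otherwise.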


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (3)
grail/trainer/training_process.py (1)

831-832: Critical: Remove hardcoded UID 80 filter from production code.

Lines 831-832 override the trusted miners list with a hardcoded filter to UID 80, which will break training for all other miners in production. The # TEST: comment indicates this is debug code that was not removed before commit.

Remove or conditionally disable this test filter:

-            # TEST: filter to uid 80 hotkey and choose between all hotkeys
-            trusted_hotkeys = [hk for hk, uid in uid_by_hotkey.items() if uid == 80]
+            # Use all trusted hotkeys for training (test filter removed)

If testing is needed, gate behind an environment variable:

if os.getenv("GRAIL_TEST_SINGLE_MINER"):
    test_uid = int(os.getenv("GRAIL_TEST_UID", "80"))
    trusted_hotkeys = [hk for hk, uid in uid_by_hotkey.items() if uid == test_uid]
    logger.warning("TEST MODE: Filtering to single UID %d", test_uid)
grail/neurons/trainer.py (1)

798-808: Critical: wallet.hotkey_str attribute does not exist on bittensor wallet objects.

Line 806 accesses self._context.wallet.hotkey_str, which is not a valid attribute on real bt.wallet instances. This will cause an AttributeError at runtime in production. The test suite only passes because _FakeWallet manually defines this attribute.

Real bittensor wallets have wallet.hotkey (a Keypair object), wallet.name, and wallet.path, but not hotkey_str.

Replace with a valid attribute:

     def _serialize_wallet(self) -> dict[str, str]:
         """Serialize wallet for pickling to child process.
 
         Returns:
             Dictionary of wallet arguments
         """
         return {
             "name": self._context.wallet.name,
-            "hotkey": self._context.wallet.hotkey_str,
+            "hotkey": self._context.wallet.hotkey.ss58_address,
             "path": self._context.wallet.path,
         }

This extracts the SS58 address from the Keypair object, which bt.wallet(**wallet_args) accepts as the hotkey parameter.

grail/trainer/snapshot_manager.py (1)

106-112: Race condition: rmtree + rename is not atomic and risks snapshot loss.

If a crash occurs between shutil.rmtree(target_dir) at line 109 and temp_dir.rename(target_dir) at line 112, the snapshot is permanently lost. This violates the atomicity guarantee stated in the docstring.

Implement a safer three-step swap pattern:

             # Atomic rename: remove old target if exists, then rename
             if target_dir.exists():
-                # Remove old snapshot
-                shutil.rmtree(target_dir)
-
-            # Rename temp to target (atomic on POSIX)
-            temp_dir.rename(target_dir)
+                # Step 1: Rename old snapshot to backup (atomic)
+                backup_dir = self.snapshot_dir / "latest.old"
+                if backup_dir.exists():
+                    shutil.rmtree(backup_dir)
+                target_dir.rename(backup_dir)
+
+            # Step 2: Rename temp to target (atomic on POSIX)
+            temp_dir.rename(target_dir)
+
+            # Step 3: Clean up backup after successful rename
+            backup_dir = self.snapshot_dir / "latest.old"
+            if backup_dir.exists():
+                try:
+                    shutil.rmtree(backup_dir)
+                except Exception:
+                    pass  # Best effort cleanup

This ensures at least one valid snapshot (either old or new) always exists on disk, even if the process crashes mid-operation.

🧹 Nitpick comments (1)
grail/trainer/checkpoint_publisher.py (1)

328-354: Verify the compressed file detection logic aligns with upload_file_chunked behavior.

Lines 333-335 infer whether a file was compressed by checking if it's a small JSON file (<10MB). However, this duplicates the compression decision logic from upload_file_chunked in grail/infrastructure/comms.py (lines 429-435). If the compression threshold or logic changes in upload_file_chunked, this verification will break.

Consider one of these approaches:

  1. Return compression info from upload_file_chunked (preferred):
-async def upload_file_chunked(...) -> bool:
+async def upload_file_chunked(...) -> tuple[bool, bool]:
     # ... existing logic ...
-    return True
+    return True, was_compressed  # Return compression status

Then use the returned value here instead of re-inferring.

  2. Extract shared compression logic to a utility function:
def should_compress_file(key: str, size: int) -> bool:
    return key.endswith(".json") and size < 10 * 1024 * 1024

Use this in both places to ensure consistency.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd212b5 and e9feb20.

📒 Files selected for processing (6)
  • grail/infrastructure/comms.py (1 hunks)
  • grail/neurons/trainer.py (2 hunks)
  • grail/trainer/checkpoint_publisher.py (1 hunks)
  • grail/trainer/snapshot_manager.py (1 hunks)
  • grail/trainer/training_process.py (1 hunks)
  • tests/integration/test_async_trainer.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/integration/test_async_trainer.py
🧰 Additional context used
🧬 Code graph analysis (2)
grail/trainer/checkpoint_publisher.py (2)
grail/infrastructure/checkpoint_consumer.py (2)
  • CheckpointMetadata (63-75)
  • remote_prefix (74-75)
grail/infrastructure/comms.py (4)
  • delete_prefix (1056-1087)
  • get_file_size (1025-1053)
  • upload_file_chunked (402-541)
  • list_bucket_files (984-1022)
grail/trainer/snapshot_manager.py (2)
tests/integration/protocol/test_proof_cross_framework.py (2)
  • model (39-43)
  • tokenizer (47-52)
tests/integration/test_train_cli_smoke.py (1)
  • save_pretrained (60-63)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Cursor Bugbot
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Run Tests (3.10)
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Run Tests (3.10)
🔇 Additional comments (1)
grail/infrastructure/comms.py (1)

984-1022: LGTM! Pagination support correctly implemented.

The pagination logic properly handles S3/R2's continuation token mechanism for listing >1000 objects. The safety check at lines 1014-1015 prevents infinite loops if NextContinuationToken is unexpectedly missing. This enhancement is essential for the checkpoint cleanup operations that may need to list many checkpoint files.
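
As an illustration of the continuation-token pattern being praised here (a generic boto3 sketch, not the repo's async client or its actual list_bucket_files signature):

import boto3


def list_all_keys(bucket: str, prefix: str) -> list[str]:
    """List every object under a prefix, following NextContinuationToken so
    results are not silently truncated at 1000 keys."""
    client = boto3.client("s3")
    keys: list[str] = []
    token = None
    while True:
        kwargs = {"Bucket": bucket, "Prefix": prefix}
        if token:
            kwargs["ContinuationToken"] = token
        response = client.list_objects_v2(**kwargs)
        keys.extend(obj["Key"] for obj in response.get("Contents", []))
        if not response.get("IsTruncated"):
            break
        token = response.get("NextContinuationToken")
        if not token:  # safety check mirroring the one noted at lines 1014-1015
            break
    return keys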


batch_size = config.batch_size
micro_batch_size = max(1, int(config.batch_size))
grad_accum_steps = max(1, self.config.grad_accum_steps)

Bug: train_epoch mixes config parameter with self.config

The train_epoch function accepts a config parameter but inconsistently uses it only for batch_size while using self.config for grad_accum_steps, entropy_coef, grad_clip, and max_length. If a caller passes a different config to train_epoch, only the batch size would reflect the passed config while other hyperparameters would use the instance's config, leading to unexpected training behavior.
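
A minimal sketch of making the passed config authoritative for all hyperparameters (field names follow those Bugbot lists; the dataclass is a stand-in for the real TrainingConfig):

from dataclasses import dataclass


@dataclass
class _EpochConfig:  # stand-in for the config object passed to train_epoch
    batch_size: int = 4
    grad_accum_steps: int = 8
    entropy_coef: float = 0.0
    grad_clip: float = 1.0
    max_length: int = 1024


def resolve_epoch_hparams(config: _EpochConfig) -> dict[str, float | int]:
    """Derive every per-epoch hyperparameter from the explicitly passed config,
    so a caller-supplied config fully determines training behaviour."""
    return {
        "micro_batch_size": max(1, int(config.batch_size)),
        "grad_accum_steps": max(1, int(config.grad_accum_steps)),
        "entropy_coef": float(config.entropy_coef),
        "grad_clip": float(config.grad_clip),
        "max_length": int(config.max_length),
    }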


)
self._windows_since_last_eval += 1

await asyncio.sleep(ORCHESTRATION_SLEEP_SECONDS)

Bug: Evaluation interval counts loop iterations not training windows

The _windows_since_last_eval counter is now incremented per orchestration loop iteration (every 60 seconds) rather than per actual training window processed. This changes the semantics of window_interval from "evaluate every N training windows" to "evaluate every N minutes." Users configuring window_interval=20 expecting evaluation every 20 training windows will instead get evaluation approximately every 20 minutes regardless of training progress.
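
A hedged sketch of preserving the original semantics by counting processed windows rather than loop ticks (the class and method names are assumptions, not the orchestrator's real API):

class EvalScheduler:
    """Counts completed training windows, not loop iterations, so a
    window_interval of 20 still means 'evaluate every 20 training windows'."""

    def __init__(self, window_interval: int) -> None:
        self.window_interval = window_interval
        self._windows_since_last_eval = 0

    def on_window_processed(self) -> None:
        # Called only when a training window was actually consumed,
        # not on every 60-second orchestration tick.
        self._windows_since_last_eval += 1

    def should_evaluate(self) -> bool:
        return self._windows_since_last_eval >= self.window_interval

    def on_evaluation_success(self) -> None:
        self._windows_since_last_eval = 0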


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e9feb20 and ea6d15e.

📒 Files selected for processing (3)
  • grail/cli/mine.py (8 hunks)
  • grail/shared/chat_templates.py (1 hunks)
  • grail/trainer/config.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • grail/cli/mine.py
🧰 Additional context used
🧬 Code graph analysis (1)
grail/trainer/config.py (1)
grail/trainer/inference_server.py (1)
  • start_server (70-72)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Cursor Bugbot
  • GitHub Check: Run Tests (3.10)
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Run Tests (3.10)
🔇 Additional comments (5)
grail/trainer/config.py (2)

53-54: LGTM. The gradient checkpointing configuration follows the established pattern of sourcing defaults from constants, and the TRAINER_USE_GRADIENT_CHECKPOINTING constant is properly defined in grail/shared/constants.py (lines 83-84).


88-88: This change is justified and aligns with standard evaluation practices.

The default split of "val" is the appropriate choice for evaluation, as validation sets are typically more controlled and representative than test sets. All EvalConfig instantiations throughout the codebase work correctly with this default, and the dataset providers (GSM8K, MATH) fully support the "val" split. The TODO comment accurately reflects a future improvement: making split configurable per environment rather than a global default. No breaking changes result from this modification.

grail/shared/chat_templates.py (3)

22-26: LGTM!

Good use of constants for ChatML tokens (IM_START, IM_END) to centralize token definitions and improve maintainability.


29-51: LGTM on the signature change and default handling.

The optional system_prompt parameter with sensible defaulting to SYSTEM_PROMPT is a clean API design. The past review already flagged the external call sites that need updating.


65-86: LGTM!

The template construction is well-structured with clear inline comments explaining the critical behavior around add_generation_prompt. The conditional system message handling and message iteration are correctly implemented.

Comment on lines +53 to +54
# Escape single quotes in system prompt for safe embedding in template
escaped_prompt = system_prompt.replace("'", "\\'")

⚠️ Potential issue | 🟠 Major

Incomplete escaping may cause Jinja2 syntax errors.

Only single quotes are escaped, but backslashes are not. If system_prompt contains a backslash followed by a single quote (e.g., \'), the result \\' would be misinterpreted in the Jinja2 template as an escaped backslash followed by a string-terminating quote.

Apply this diff to escape backslashes before single quotes:

-    # Escape single quotes in system prompt for safe embedding in template
-    escaped_prompt = system_prompt.replace("'", "\\'")
+    # Escape backslashes and single quotes for safe embedding in Jinja2 template
+    escaped_prompt = system_prompt.replace("\\", "\\\\").replace("'", "\\'")
🤖 Prompt for AI Agents
In grail/shared/chat_templates.py around lines 53-54, the code only escapes
single quotes which fails when the system_prompt contains backslashes before
quotes; first escape backslashes (replace "\" with "\\") then escape single
quotes (replace "'" with "\'") so backslashes are doubled before you escape
quotes, ensuring Jinja2 receives a correctly escaped string.
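
A quick standalone check of why the order matters (escape backslashes first, then quotes); this only demonstrates the string transformation, not the full Jinja2 template:

def escape_for_jinja_single_quotes(text: str) -> str:
    # Backslashes must be doubled before quotes are escaped, otherwise an input
    # backslash-quote pair collapses into an escaped backslash plus a raw quote.
    return text.replace("\\", "\\\\").replace("'", "\\'")


assert escape_for_jinja_single_quotes("it's") == "it\\'s"
assert escape_for_jinja_single_quotes("a\\'b") == "a\\\\\\'b"
# With quote-only escaping the input backslash remains in place, so the template
# sees an escaped backslash followed by a string-terminating quote.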

Comment on lines +71 to +76
# Replay buffer configuration
replay_buffer_enabled: bool = True
replay_buffer_max_windows: int = 1 # Store last 1 windows (~6 min of data)
replay_buffer_recent_fraction: float = 0.5 # 50% samples from most recent window
replay_buffer_decay_factor: float = 0.7 # Exponential decay for older windows
replay_buffer_max_groups_per_epoch: int = 64 # Max groups to sample per epoch

🛠️ Refactor suggestion | 🟠 Major

Move replay buffer defaults to constants module for consistency.

These fields break the established pattern in TrainingConfig where all other defaults are sourced from grail.shared.constants. Hardcoding values here creates an inconsistent maintainability model.

Move these defaults to grail/shared/constants.py and reference them here:

     # Replay buffer configuration
-    replay_buffer_enabled: bool = True
-    replay_buffer_max_windows: int = 1  # Store last 1 windows (~6 min of data)
-    replay_buffer_recent_fraction: float = 0.5  # 50% samples from most recent window
-    replay_buffer_decay_factor: float = 0.7  # Exponential decay for older windows
-    replay_buffer_max_groups_per_epoch: int = 64  # Max groups to sample per epoch
+    replay_buffer_enabled: bool = constants.REPLAY_BUFFER_ENABLED
+    replay_buffer_max_windows: int = constants.REPLAY_BUFFER_MAX_WINDOWS
+    replay_buffer_recent_fraction: float = constants.REPLAY_BUFFER_RECENT_FRACTION
+    replay_buffer_decay_factor: float = constants.REPLAY_BUFFER_DECAY_FACTOR
+    replay_buffer_max_groups_per_epoch: int = constants.REPLAY_BUFFER_MAX_GROUPS_PER_EPOCH

Then add to grail/shared/constants.py:

# Replay buffer configuration
REPLAY_BUFFER_ENABLED: bool = True
REPLAY_BUFFER_MAX_WINDOWS: int = 1  # Store last 1 windows (~6 min of data)
REPLAY_BUFFER_RECENT_FRACTION: float = 0.5  # 50% samples from most recent window
REPLAY_BUFFER_DECAY_FACTOR: float = 0.7  # Exponential decay for older windows
REPLAY_BUFFER_MAX_GROUPS_PER_EPOCH: int = 64  # Max groups to sample per epoch
🤖 Prompt for AI Agents
In grail/trainer/config.py around lines 71 to 76, the replay buffer default
values are hardcoded and must be moved to grail/shared/constants.py for
consistency; add the five constants (REPLAY_BUFFER_ENABLED,
REPLAY_BUFFER_MAX_WINDOWS, REPLAY_BUFFER_RECENT_FRACTION,
REPLAY_BUFFER_DECAY_FACTOR, REPLAY_BUFFER_MAX_GROUPS_PER_EPOCH) to
grail/shared/constants.py with the provided values, then update
grail/trainer/config.py to import these constants and use them as the default
values for the corresponding replay_buffer_* fields in TrainingConfig.
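
Independent of where the defaults live, the fields quoted above imply a recency-weighted sampling scheme; a minimal sketch assuming windows are ordered newest-first (the function names and normalization are assumptions, not the real ReplayBuffer):

import random


def window_weights(num_windows: int, recent_fraction: float = 0.5, decay: float = 0.7) -> list[float]:
    """Newest window gets recent_fraction of the probability mass; older windows
    share the remainder with geometric decay, renormalized to sum to 1."""
    if num_windows <= 1:
        return [1.0] * num_windows
    older = [decay ** i for i in range(1, num_windows)]
    scale = (1.0 - recent_fraction) / sum(older)
    return [recent_fraction] + [w * scale for w in older]


def sample_groups(windows: list[list[str]], max_groups: int, seed: int | None = None) -> list[str]:
    """Sample up to max_groups group ids across windows (index 0 = most recent)."""
    if not windows:
        return []
    rng = random.Random(seed)
    weights = window_weights(len(windows))
    picks: list[str] = []
    for _ in range(max_groups):
        window = rng.choices(windows, weights=weights, k=1)[0]
        if window:
            picks.append(rng.choice(window))
    return picks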

Comment on lines 100 to +105
# sgLang server options (used when backend == "sglang")
sglang_host: str = "127.0.0.1"
sglang_port: int = 30000
sglang_start_server: bool = True # Server runs in subprocess (avoids Gloo socket issues)
sglang_server_timeout_s: float = 120.0
sglang_trust_remote_code: bool = False
server_host: str = "127.0.0.1"
server_port: int = 30000
start_server: bool = True # Server runs in subprocess (avoids Gloo socket issues)
server_timeout: float = 120.0
server_trust_remote_code: bool = False

⚠️ Potential issue | 🟡 Minor

Update stale comment to reflect generic server naming.

The comment on line 100 still references "sgLang server options" but the fields below have been renamed to generic server_* naming (not sglang-specific). This is misleading.

Apply this diff to update the comment:

-    # sgLang server options (used when backend == "sglang")
+    # Server options (used when backend == "sglang" or other server backends)
     server_host: str = "127.0.0.1"
     server_port: int = 30000
     start_server: bool = True  # Server runs in subprocess (avoids Gloo socket issues)
     server_timeout: float = 120.0
     server_trust_remote_code: bool = False
🤖 Prompt for AI Agents
In grail/trainer/config.py around lines 100 to 105, the comment "sgLang server
options (used when backend == \"sglang\")" is stale and misleading given the
fields are now generic server_* names; update the comment to a generic
description like "Server options (used when backend == \"server\" or generic
server backends)" or similar to reflect generic server naming, keeping brief and
accurate.

idle_duration,
method_name,
)
await self._restart_subtensor()

Bug: Comment and log message claim timeout doubling that doesn't happen

The comment at line 230 and log message at line 240 claim the timeout is doubled when the connection has been idle, but the code only restarts the subtensor - it never actually modifies the timeout variable. This could lead to unexpected behavior where connections time out faster than intended after idle periods, despite the documentation suggesting otherwise.


if self._should_run_evaluation():
logger.info("Evaluation due, coordinating with training process...")
await self._coordinate_evaluation(current_window)
self._windows_since_last_eval = 0

Bug: Evaluation counter resets even when evaluation fails to start

In _orchestration_loop, _windows_since_last_eval is unconditionally reset to 0 after calling _coordinate_evaluation, even when _coordinate_evaluation returns early without running evaluation (e.g., if training process failed to pause within timeout). This causes evaluation to be incorrectly deferred for another full interval instead of being retried sooner.
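
One way to express the fix, assuming _coordinate_evaluation can report whether evaluation actually ran (the bool return is an assumption; the helper below is a sketch, not the orchestrator's real method):

from collections.abc import Awaitable, Callable


async def maybe_evaluate(
    coordinate_evaluation: Callable[[int], Awaitable[bool]],
    current_window: int,
    windows_since_last_eval: int,
) -> int:
    """Return the updated counter: reset to 0 only if evaluation actually ran,
    otherwise keep counting so the next loop iteration retries promptly."""
    ran = await coordinate_evaluation(current_window)
    return 0 if ran else windows_since_last_eval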


)

logger.info("Server shutdown complete")
return metrics

Bug: Evaluator shutdown missing in server evaluation path

In _run_server_evaluation, the evaluator created at line 690 is never shut down after use. Unlike _run_direct_evaluation which calls _cleanup_evaluation_resources (including evaluator.shutdown()) in a finally block, _run_server_evaluation exits without any cleanup. The EvaluatorService.shutdown() method is documented as "Critical for freeing GPU memory" and "Must be called after evaluation completes."
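
The usual guard looks like the sketch below; only EvaluatorService.shutdown() is confirmed by the review, so the evaluate signature and protocol are hypothetical:

from typing import Protocol


class EvaluatorLike(Protocol):
    async def evaluate(self, window: int) -> dict[str, float]: ...
    def shutdown(self) -> None: ...


async def run_server_evaluation(evaluator: EvaluatorLike, window: int) -> dict[str, float]:
    """Mirror _run_direct_evaluation: always release GPU memory, even when the
    server-backed evaluation raises."""
    try:
        return await evaluator.evaluate(window)
    finally:
        evaluator.shutdown()  # documented as critical for freeing GPU memory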


@erfanMhi erfanMhi merged commit 40cef73 into main Dec 8, 2025
7 of 8 checks passed
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (4)
grail/shared/constants.py (1)

50-85: Constants wiring and new async knobs look sound; just verify external expectations

The new defaults and flags (TRAINER_MAX_LENGTH=1024, TRAINER_USE_GRADIENT_CHECKPOINTING, snapshot/heartbeat/upload retry constants) are wired correctly and typed safely. This will change default sequence length and enable gradient checkpointing + async worker behaviour out of the box; please double‑check that any docs/CLI help and deployment configs are updated to reflect these new defaults and env names.

Also applies to: 247-257

grail/trainer/training_process.py (3)

245-277: Consider basing KL/ref‑model enablement on TrainingConfig instead of globals

Both _initialize_resources and _prepare_models_with_accelerator gate ref‑model loading and accelerator wrapping on the module‑level is_kl_enabled() helper, which reads TRAINER_KL_COEF from env. If callers ever override TrainingConfig.kl_coef at runtime (without changing the env), the env‑based check and the config value can diverge.

A slightly more robust pattern would be to make the decision based on the effective config:

-        load_ref_model = is_kl_enabled()
+        load_ref_model = self.config.kl_coef > 0.0
...
-        kl_enabled = is_kl_enabled()
+        kl_enabled = self.config.kl_coef > 0.0

This keeps behaviour consistent with whatever TrainingConfig the orchestrator passes in, while still defaulting from the shared constants.

Also applies to: 583-599


821-839: Optional: reuse metagraph to avoid redundant chain calls in _load_grpo_groups

_load_grpo_groups always fetches a fresh metagraph via self.subtensor.metagraph(NETUID) even though _get_trusted_miners does the same in the same iteration. For long training runs this is probably fine, but if metagraph fetches are expensive or rate‑limited, you could reduce load by reusing a single metagraph per loop iteration (e.g., pass it from _get_trusted_miners or cache it briefly in the service).

Not critical, but worth considering if you see metagraph calls becoming a bottleneck.

Also applies to: 850-873


448-481: Docstring and stop_event usage around initial checkpoint could be tightened

_upload_initial_checkpoint’s docstring says "Upload initial checkpoint and wait for miners to download", but the method now just saves the snapshot (upload handled by the worker) and does not touch stop_event. The actual wait happens in _wait_for_miners, where WAIT_WINDOW_LENGTH is currently hardcoded to 0.0 (no wait) with a TODO.

Not a functional bug, but to avoid confusion you might:

  • Update _upload_initial_checkpoint’s docstring to describe that it only saves the initial snapshot and lets the upload worker handle publishing.
  • Optionally use or remove the unused stop_event parameter there, or wire it into a future early‑exit hook.

This will make the intent of the startup sequence clearer to future readers.

Also applies to: 201-233

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ea6d15e and aca6ed2.

📒 Files selected for processing (4)
  • grail/shared/constants.py (3 hunks)
  • grail/trainer/__init__.py (1 hunks)
  • grail/trainer/service.py (0 hunks)
  • grail/trainer/training_process.py (1 hunks)
💤 Files with no reviewable changes (1)
  • grail/trainer/service.py
🧰 Additional context used
🧬 Code graph analysis (1)
grail/trainer/training_process.py (13)
grail/infrastructure/checkpoint_consumer.py (1)
  • CheckpointManager (87-562)
grail/infrastructure/network.py (1)
  • create_subtensor (330-404)
grail/model/train_loading.py (2)
  • ModelLoadSpec (30-41)
  • load_training_artifacts (149-190)
grail/monitoring/manager.py (2)
  • get_monitoring_manager (366-375)
  • initialize_monitoring (378-404)
grail/shared/constants.py (1)
  • is_kl_enabled (262-268)
grail/trainer/algorithms/base.py (1)
  • TrainingAlgorithm (14-40)
grail/trainer/config.py (1)
  • TrainingConfig (15-76)
grail/trainer/ipc.py (4)
  • IPCChannels (23-142)
  • update_heartbeat (62-68)
  • is_pause_requested (103-105)
  • queue_snapshot (122-142)
grail/trainer/replay_buffer.py (8)
  • ReplayBuffer (20-63)
  • create_replay_buffer (253-280)
  • get_stats (50-55)
  • get_stats (212-233)
  • add_window (29-35)
  • add_window (110-133)
  • sample_groups (38-47)
  • sample_groups (135-180)
grail/trainer/snapshot_manager.py (2)
  • SnapshotManager (26-223)
  • save_snapshot_atomic (59-127)
grail/neurons/miner.py (1)
  • run (52-308)
grail/monitoring/backends/wandb_backend.py (2)
  • initialize (99-137)
  • start_run (704-743)
grail/logging_utils.py (1)
  • configure_process_logging (227-296)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Run Tests (3.10)
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Cursor Bugbot
  • GitHub Check: Run Tests (3.11)
  • GitHub Check: Run Tests (3.10)
🔇 Additional comments (1)
grail/trainer/__init__.py (1)

5-11: Updated trainer package surface looks consistent

Exporting checkpoint_publisher and replay_buffer (and dropping the old service/checkpointing modules) matches the new async trainer architecture; this keeps the public API aligned with how downstream code now imports these components.

Comment on lines +306 to +408
async def _initialize_monitoring(self) -> None:
"""Initialize monitoring after heavy resources are loaded.
This is called AFTER models and chain manager are initialized to avoid
resource contention during WandB connection.
"""
if not self.monitor_config:
logger.info("No monitoring config provided, skipping monitoring setup")
return

# FIRST: Test direct WandB connection (bypasses monitoring class)
import os

os.environ["WANDB_DISABLE_SERVICE"] = "false" # Ensure service is enabled
os.environ["WANDB_SERVICE"] = "" # Clear any inherited service path
try:
backend_type = self.monitor_config.get("backend_type", "wandb")
init_config = {k: v for k, v in self.monitor_config.items() if k != "backend_type"}

logger.debug(
"Initializing monitoring: backend_type=%s run_name=%s run_id=%s entity=%s project=%s mode=%s",
backend_type,
init_config.get("run_name"),
init_config.get("run_id"),
init_config.get("entity"),
init_config.get("project"),
init_config.get("mode"),
)
logger.debug(
"Full init_config keys received in training process: %s", list(init_config.keys())
)
# Verify critical parameters are present
if not init_config.get("entity"):
logger.warning("⚠️ WandB entity not set in training process - will use default")
if not init_config.get("project"):
logger.warning("⚠️ WandB project not set in training process - will use default")

if "run_id" in init_config:
logger.info(
"Resuming W&B run %s in training process for multi-process logging",
init_config["run_id"],
)

initialize_monitoring(backend_type=backend_type, **init_config)
self.monitor = get_monitoring_manager()
logger.info("Monitoring initialized in training process")

# Actually start the WandB run (connects to API, resumes run)
# This is critical - without this, the run is not actually initialized
run_name = init_config.get("run_name")
if run_name:
logger.info("Starting WandB run: %s", run_name)
start_time = time.time()
try:
actual_run_id = await self.monitor.start_run(run_name, init_config)
start_duration = time.time() - start_time
logger.info(
"✅ WandB run started successfully in %.1fs (run_id=%s)",
start_duration,
actual_run_id,
)
except Exception as start_exc:
start_duration = time.time() - start_time
logger.error(
"❌ Failed to start WandB run after %.1fs: %s",
start_duration,
start_exc,
)
logger.info(
"Hint: Increase WANDB_INIT_TIMEOUT (current: %s) if shared mode workers timeout",
init_config.get("init_timeout", "120"),
)
# Continue without monitoring
return

# Test WandB metric logging with a simple gauge
if self.monitor and run_name:
logger.info("Testing WandB metric logging...")
test_start = time.time()
try:
# Set current block context so metric appears in WandB properly
current_block = await self.subtensor.get_current_block()
current_window = (current_block // WINDOW_LENGTH) * WINDOW_LENGTH
self.monitor.set_block_context(current_block, current_window)
await self.monitor.log_gauge("training_process/connection_test", 1.0)
await self.monitor.flush_metrics()
test_duration = time.time() - test_start
logger.info(
"✅ WandB metric logging successful in %.3fs",
test_duration,
)
except Exception as test_exc:
test_duration = time.time() - test_start
logger.warning(
"⚠️ WandB metric logging test failed after %.3fs: %s (training will continue)",
test_duration,
test_exc,
)

except Exception as exc:
logger.warning("Failed to initialize monitoring in training process: %s", exc)
self.monitor = None


⚠️ Potential issue | 🟠 Major

On WandB start failure, monitoring is not actually disabled

In _initialize_monitoring, if self.monitor.start_run(...) raises, the code logs "Continue without monitoring" and returns, but self.monitor still points to a live MonitoringManager instance. The rest of the training flow checks if self.monitor: and will continue to call set_block_context, log_gauge, etc., which may keep failing or behave inconsistently because the run was never started.

Consider explicitly disabling monitoring in this failure path, e.g.:

-                except Exception as start_exc:
+                except Exception as start_exc:
                     start_duration = time.time() - start_time
                     logger.error(
                         "❌ Failed to start WandB run after %.1fs: %s",
                         start_duration,
                         start_exc,
                     )
                     logger.info(
                         "Hint: Increase WANDB_INIT_TIMEOUT (current: %s) if shared mode workers timeout",
                         init_config.get("init_timeout", "120"),
                     )
-                    # Continue without monitoring
-                    return
+                    # Continue without monitoring
+                    self.monitor = None
+                    return

This keeps later training logic from repeatedly exercising a half‑initialized monitoring backend.

🤖 Prompt for AI Agents
In grail/trainer/training_process.py around lines 306 to 408, when
self.monitor.start_run(...) raises the code currently logs and returns but
leaves self.monitor set to a live MonitoringManager; change the error path so
monitoring is explicitly disabled and cleaned up: call the monitor
shutdown/cleanup method if available (or drop any references) and set
self.monitor = None before returning so subsequent checks like if self.monitor:
won't attempt to use a half‑initialized monitoring backend; include a defensive
try/except around the cleanup call to ignore cleanup errors and log them at
debug level.

Comment on lines +571 to +600
def _prepare_models_with_accelerator(
self,
accelerator: Accelerator,
) -> tuple[Any, Any | None, torch.optim.Optimizer]:
"""Prepare models and optimizer with Accelerator for distributed training.
Args:
accelerator: Accelerator instance
Returns:
Tuple of (train_model, ref_model, optimizer)
"""
kl_enabled = is_kl_enabled()

if kl_enabled and self.ref_model is not None:
train_model, ref_model, optimizer = accelerator.prepare(
self.train_model,
self.ref_model,
self.optimizer,
)
if hasattr(ref_model, "eval"):
ref_model.eval()
else:
train_model, optimizer = accelerator.prepare(
self.train_model,
self.optimizer,
)
ref_model = None

return train_model, ref_model, optimizer

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

accelerate library Accelerator.prepare optimizer in place mutation documentation

💡 Result:

Short answer: Accelerator.prepare (and prepare_optimizer) can replace or wrap the optimizer rather than mutate the original in-place — you should use the object returned by prepare/prepare_optimizer. The official docs show prepare returns the prepared optimizer (and recommend assigning it), and GH issues confirm optimizer state/identity can change in some distributed/offload/zero/DeepSpeed cases. [1][2]

Sources:

  • accelerate docs (Accelerator.prepare / prepare_optimizer) — shows prepare returns prepared objects and examples assigning the returned optimizer. [1]
  • GitHub issue reporting save/load problems after accelerator.prepare (shows optimizer can be altered by prepare in some setups). [2]

References:
[1] Hugging Face Accelerate — Accelerator.prepare / prepare_optimizer (docs).
[2] Issue: "Unable to save/load optimizer state after accelerator.prepare()" — huggingface/accelerate GitHub.


🏁 Script executed:

# First, let's examine the actual code at the mentioned lines
cat -n grail/trainer/training_process.py | sed -n '571,600p'

Repository: one-covenant/grail

Length of output: 1228


🏁 Script executed:

# Also check the _handle_pause method at lines 885-942
cat -n grail/trainer/training_process.py | sed -n '885,942p'

Repository: one-covenant/grail

Length of output: 2278


🏁 Script executed:

# Let's check if there's any direct reference to how self.optimizer is used after _prepare_models_with_accelerator
rg "self\.optimizer" grail/trainer/training_process.py -n -B2 -A2 | head -100

Repository: one-covenant/grail

Length of output: 1485


Use the optimizer returned by accelerator.prepare()—don't rely on in-place mutation

The accelerator.prepare() method can wrap or replace the optimizer rather than mutate it in place. The current code captures the returned optimizer in _prepare_models_with_accelerator but continues to use self.optimizer in _train_epoch (line 933), which may diverge from the prepared optimizer in distributed, DeepSpeed, or offload configurations.

Assign the prepared optimizer back to self.optimizer to ensure the rest of the code uses the correct instance:

    def _prepare_models_with_accelerator(
        self,
        accelerator: Accelerator,
    ) -> tuple[Any, Any | None, torch.optim.Optimizer]:
        ...
        if kl_enabled and self.ref_model is not None:
            train_model, ref_model, optimizer = accelerator.prepare(
                self.train_model,
                self.ref_model,
                self.optimizer,
            )
        else:
            train_model, optimizer = accelerator.prepare(
                self.train_model,
                self.optimizer,
            )
            ref_model = None
+
+        self.optimizer = optimizer
         return train_model, ref_model, optimizer

Apply the same fix in _handle_pause where the optimizer is re-prepared after pause events.

🤖 Prompt for AI Agents
In grail/trainer/training_process.py around lines 571-600, the prepared
optimizer returned from accelerator.prepare() is captured locally but not
assigned back to the instance; update the method to assign the returned
optimizer to self.optimizer so subsequent code uses the wrapped/replaced
optimizer (e.g., self.optimizer = optimizer), and do the same in the
pause-handling code path where the optimizer is re-prepared after pauses to
ensure the instance holds the prepared optimizer rather than the original.

Comment on lines +677 to +717
# Close subtensor connection before long idle period (prevents 10s timeout)
if self.subtensor:
try:
if hasattr(self.subtensor, "_subtensor"):
# Unwrap ResilientSubtensor to get underlying subtensor
inner = object.__getattribute__(self.subtensor, "_subtensor")
if hasattr(inner, "close"):
await inner.close()
elif hasattr(self.subtensor, "close"):
await self.subtensor.close()
logger.info("Closed subtensor connection during pause to prevent idle timeout")
except Exception as e:
logger.warning("Failed to close subtensor during pause: %s", e)
self.subtensor = None

# Wait for pause flag to be cleared via IPC (primary) or filesystem (fallback)
while not stop_event.is_set():
# Check if resume signal received
pause_still_active = (
self._ipc.is_pause_requested()
if self._ipc is not None
else self.snapshot_manager.check_pause_flag()
)
if not pause_still_active:
break
await asyncio.sleep(PAUSE_CHECK_INTERVAL_SECONDS)

if stop_event.is_set():
return train_model_cpu, ref_model_cpu, optimizer

# Clear the confirmed event for next cycle
if self._ipc is not None:
self._ipc.clear_pause_confirmed()

logger.info("🔄 STATE: resume_requested - pause signal cleared")

# Recreate subtensor connection after pause
logger.info("Recreating subtensor connection after pause...")
self.subtensor = await create_subtensor(resilient=True)
logger.info("Subtensor connection recreated successfully")


⚠️ Potential issue | 🟠 Major

Recreating subtensor on pause does not refresh GrailChainManager

During _handle_pause, you correctly close the existing subtensor and then recreate a fresh one:

self.subtensor = await create_subtensor(resilient=True)

However, self.chain_manager was constructed once in _initialize_resources with the original subtensor and metagraph, and is not reinitialized here. After a pause/resume cycle, any chain interactions inside load_grpo_groups (via self.chain_manager) may still be using the old, closed subtensor, leading to repeated GRPO load failures and effectively stalling training after the first pause.

To keep chain interactions aligned with the new subtensor, you likely want to reinitialize the chain manager in this block, similar to _initialize_resources:

         logger.info("Recreating subtensor connection after pause...")
         self.subtensor = await create_subtensor(resilient=True)
-        logger.info("Subtensor connection recreated successfully")
+        logger.info("Subtensor connection recreated successfully")
+
+        # Reinitialize chain manager with fresh subtensor/metagraph
+        try:
+            logger.info("Reinitializing chain manager after pause...")
+            metagraph = await self.subtensor.metagraph(NETUID)
+            chain_config = SimpleNamespace(netuid=NETUID)
+            self.chain_manager = GrailChainManager(
+                chain_config,
+                self.wallet,
+                metagraph,
+                self.subtensor,
+                self.credentials,
+            )
+            await self.chain_manager.initialize()
+            logger.info("Chain manager reinitialized after pause")
+        except Exception as exc:  # noqa: BLE001
+            logger.error("Failed to reinitialize chain manager after pause: %s", exc)

Without this, pause/resume may quietly break subsequent GRPO loading and chain‑dependent logic.
