
Conversation

@racimrl
Contributor

@racimrl racimrl commented Sep 13, 2025

Description

Adds an automatic CPU fallback when the GPU runs out of memory during gradient decompression for large parameters.
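
A minimal sketch of the intended fallback pattern (illustrative only; the compressor.decompress and transformer.decode call shapes follow the existing validator code, and exact signatures may differ):

import torch

def decompress_with_cpu_fallback(compressor, transformer, p, idxs, vals,
                                 xshape, totalk, quant_params, use_dct):
    try:
        # Fast path: decompress on the parameter's own (GPU) device.
        ref = torch.empty_like(p)
        dec = compressor.decompress(ref, idxs, vals, xshape, totalk, quant_params)
        return transformer.decode(dec, use_dct=use_dct).to(p.dtype)
    except torch.cuda.OutOfMemoryError:
        # Fallback: release cached blocks and redo the work on the CPU.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        ref_cpu = torch.empty_like(p, device="cpu")
        dec_cpu = compressor.decompress(ref_cpu, idxs, vals, xshape, totalk, quant_params)
        return transformer.decode(dec_cpu, use_dct=use_dct).to(p.dtype)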

Related Issue(s)

  • Closes #[issue number]

Type of Change

  • Feature (adding new functionality)
  • Fix (resolving a bug or issue)
  • Docs (documentation updates)
  • Refactor (code changes that don't affect functionality)
  • Maintenance (dependency updates or other maintenance)
  • Tests (adding or improving tests)
  • Breaking change (fix or feature with incompatible API changes)
  • Other: _____

Branch Naming

  • My branch follows the project's naming convention (e.g., feature/add-new-capability)

Commit Messages

  • My commits are small, atomic, and have proper commit messages
  • Commit messages are in imperative mood with a capitalized summary under 50 chars

Code Quality

  • I've performed a self-review of my code
  • I've added appropriate docstrings following the project's conventions
  • I've added proper logging where necessary (without trailing periods)
  • I've applied linting and formatting with Ruff
  • My code generates no new warnings

Testing

  • I've added tests for new functionality or bug fixes
  • All tests pass locally with my changes
  • Test coverage has not decreased

Documentation

  • I've updated documentation to reflect my changes
  • I've updated comments in hard-to-understand areas

If this is a breaking change

Screenshots/Examples

Additional Notes

Summary by CodeRabbit

  • New Features

    • Automatic GPU memory diagnostics and proactive cleanup across evaluation and checkpointing.
    • Clearer memory usage logging for easier troubleshooting on master nodes.
  • Bug Fixes

    • Reduced CUDA/NCCL out-of-memory errors and related hangs.
    • More reliable checkpoint save/load with graceful recovery and retries.
    • Evaluation now skips windows/UIDs when resources are insufficient, preventing crashes.
  • Chores

    • Safer default CUDA/NCCL runtime settings applied at startup to improve stability and reduce fragmentation.

@coderabbitai

coderabbitai bot commented Sep 13, 2025

Walkthrough

Introduces global CUDA/NCCL environment defaults, adds an internal GPU memory check/cleanup helper, and integrates proactive memory management across evaluation and checkpoint paths. Enhances OOM/NCCL error handling with retries/skip logic, adds pre/post-operation cache clears, and expands memory usage logging. Minor non-semantic code reshaping accompanies these changes.

Changes

  • Global CUDA/NCCL configuration (neurons/validator.py)
    Sets PYTORCH_CUDA_ALLOC_CONF and NCCL-related env vars at import/init time to standardize allocator and communication behavior.
  • Memory helper utility (neurons/validator.py)
    Adds Validator._check_memory_and_cleanup(context) to inspect allocated/reserved/total GPU memory, perform empty_cache/gc/sync, and gate operations based on utilization thresholds with diagnostics; see the usage sketch after this list.
  • Evaluation flow hardening (neurons/validator.py)
    Inserts pre-checks before baseline and per-UID evaluation; adds OOM retry-once logic; conditionally skips windows/UIDs; performs cleanup after each UID and after window completion; logs memory after key steps.
  • Checkpoint save/load safeguards (neurons/validator.py)
    Wraps FSDP checkpoint save and load with pre-checks, explicit cleanup, NCCL/OOM exception handling, retry/skip decisions, and post-operation cache/GC.
  • Diagnostics and minor refactors (neurons/validator.py)
    Adds debug-level memory usage logs at master rank; minor arithmetic line reflows and comments without semantic changes; explicit cleanup after state transitions.
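
A hypothetical wrapper illustrating the gate-run-cleanup pattern described above (the helper name run_if_memory_ok is not part of the PR):

import gc
import torch

def run_if_memory_ok(validator, context: str, operation):
    # Gate the operation on the validator's memory check.
    if not validator._check_memory_and_cleanup(context):
        return None  # caller treats None as "skipped"
    try:
        return operation()
    finally:
        # Post-operation cleanup, mirroring the per-UID / post-window cleanup.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()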

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Eval as Validator.evaluate_window
  participant GPU as Torch CUDA
  participant Log as Logger

  Eval->>Eval: _check_memory_and_cleanup("pre-baseline")
  alt Insufficient memory
    Eval->>Log: Skip baseline window
    Note right of Eval: Window evaluation skipped
  else Sufficient
    Eval->>Eval: run_baseline()
    Note over Eval: Wrapped in try/except CUDA OOM
    alt CUDA OOM on baseline
      Eval->>GPU: empty_cache + gc + sync
      Eval->>Eval: retry once
      alt OOM again
        Eval->>Log: Skip window due to OOM
      else Success
        Eval->>Log: Baseline complete
      end
    else Success
      Eval->>Log: Baseline complete
    end

    loop For each UID
      Eval->>Eval: _check_memory_and_cleanup("pre-uid")
      alt Insufficient memory
        Eval->>Log: Skip UID
      else Sufficient
        Eval->>Eval: evaluate_uid(uid)
        Eval->>GPU: empty_cache + gc + sync (post-UID)
        Eval->>Log: Report mem usage
      end
    end

    Eval->>GPU: empty_cache + gc + sync (post-window)
  end
sequenceDiagram
  autonumber
  participant V as Validator
  participant GPU as Torch CUDA
  participant IO as Checkpoint I/O
  participant Log as Logger

  rect rgba(200,230,255,0.3)
  Note over V,GPU: Checkpoint Load
  V->>GPU: empty_cache + gc
  V->>V: _check_memory_and_cleanup("pre-load")
  alt Insufficient memory
    V->>Log: Abort load (insufficient memory)
  else Sufficient
    V->>IO: load_checkpoint()
    alt OOM/NCCL error
      V->>GPU: empty_cache + gc + sync
      V->>IO: retry load once
      alt Fails again
        V->>Log: Load failed, escalate/skip
      else Success
        V->>Log: Load succeeded on retry
      end
    else Success
      V->>Log: Load succeeded
    end
    V->>GPU: empty_cache + gc (post-load)
  end
  end

  rect rgba(220,255,220,0.3)
  Note over V,GPU: Checkpoint Save
  V->>V: _check_memory_and_cleanup("pre-save")
  alt Insufficient memory
    V->>Log: Skip/Delay save
  else Sufficient
    V->>IO: save_checkpoint()
    alt OOM/NCCL error
      V->>GPU: empty_cache + gc + sync
      V->>Log: Save failed, continue safely
    else Success
      V->>Log: Save completed
    end
  end
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I tidied the VRAM with a broom of light,
Swept caches clean in the hush of night.
If OOM gremlins hop—retry, then shoo!
Checkpoints nest snug, NCCL in view.
Ears up, I log every byte and beat—
A careful bunny keeps GPUs neat. 🐇💾

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings, 1 inconclusive)
  • Description Check (⚠️ Warning)
    Explanation: The PR contains a brief Description asserting CPU fallback behavior but leaves the repository template checklist mostly unchecked and omits metadata and validation details: no related issue link, no Type of Change selected, no branch/commit confirmations, and no testing or documentation statements. The raw_summary also shows additional memory-management work that is not described in the PR body, so the description is incomplete for review.
    Resolution: Please complete the template before merging by selecting the appropriate "Type of Change", linking any related issue(s), confirming branch naming and commit message compliance, and describing testing performed (add tests if missing); also expand the Description with rationale, a short implementation summary (which files/functions implement CPU fallback versus memory cleanup), and any compatibility or performance considerations so reviewers can validate the change.
  • Docstring Coverage (⚠️ Warning)
    Explanation: Docstring coverage is 33.33%, which is below the required threshold of 80.00%.
    Resolution: You can run @coderabbitai generate docstrings to improve docstring coverage.
  • Title Check (❓ Inconclusive)
    Explanation: The PR title "CPU Fallback for Decompression on OOM" matches the author description and PR objectives claiming an automatic CPU fallback, but the provided file-level summary (neurons/validator.py) only documents broad memory-management additions and does not mention an explicit CPU-decompression fallback, creating a discrepancy that prevents confidently confirming the title reflects the primary code changes.
    Resolution: Please clarify by either updating the title to reflect the visible memory-management changes in neurons/validator.py or by expanding the PR description/raw_summary to point to the exact files and functions that implement the CPU fallback (include code references or diffs) so reviewers can verify the title is correct.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
neurons/validator.py (2)

1697-1697: Guard torch.cuda.empty_cache() when CUDA is unavailable.

Direct calls will raise on CPU‑only runs.

-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

Apply similarly to the other occurrences in these ranges.

Also applies to: 1811-1811, 2351-2351, 2407-2407


3135-3159: Missing core PR objective: add CPU fallback for gradient decompression on OOM (and avoid per‑param empty_cache).

Currently, decompress/transform always run on the GPU; on OOM you only clear the cache. Implement a CPU fallback with a chunked update for non‑DT tensors and a CPU source for DT distribution.

-                    # Check if we're in distributed mode
-                    # Use empty_like to avoid copying the param; just provide dtype/device/shape
-                    ref = torch.empty_like(p, device=self.device, dtype=p.dtype)
-
-                    decompressed = self.compressor.decompress(
-                        ref,
-                        idxs,
-                        vals,
-                        self.xshapes[n],
-                        self.totalks[n],
-                        quant_params,
-                    )
-
-                    full_grad_src = self.transformer.decode(
-                        decompressed, use_dct=self.hparams.use_dct
-                    )
-                    # Single conversion to target dtype+device to avoid extra temporaries
-                    full_grad_src = full_grad_src.to(
-                        dtype=p.dtype, device=p.device, non_blocking=True
-                    )
-
-                    # Free intermediate pieces ASAP
-                    del ref, decompressed
-                    # Force immediate cleanup
-                    torch.cuda.empty_cache()
+                    # Try GPU path first, then CPU fallback on OOM
+                    applied_update_eagerly = False  # non‑DT CPU fallback updates in‑place
+                    try:
+                        ref = torch.empty_like(p, device=p.device, dtype=p.dtype)
+                        decompressed = self.compressor.decompress(
+                            ref, idxs, vals, self.xshapes[n], self.totalks[n], quant_params
+                        )
+                        full_grad_src = self.transformer.decode(
+                            decompressed, use_dct=self.hparams.use_dct
+                        ).to(dtype=p.dtype, device=p.device, non_blocking=True)
+                        del ref, decompressed
+                    except (torch.cuda.OutOfMemoryError, RuntimeError) as oom:
+                        if "out of memory" not in str(oom).lower():
+                            raise
+                        if self.is_master:
+                            tplr.log_with_context(
+                                level="warning",
+                                message=f"GPU OOM while decompressing {n}; falling back to CPU.",
+                                sync_window=self.sync_window,
+                                current_window=self.current_window,
+                                eval_uid=eval_uid,
+                            )
+                        if torch.cuda.is_available():
+                            torch.cuda.empty_cache()
+                        # CPU fallback: decompress+decode on CPU
+                        ref_cpu = torch.empty_like(p, device="cpu", dtype=p.dtype)
+                        decompressed_cpu = self.compressor.decompress(
+                            ref_cpu, idxs, vals, self.xshapes[n], self.totalks[n], quant_params
+                        )
+                        full_grad_cpu = self.transformer.decode(
+                            decompressed_cpu, use_dct=self.hparams.use_dct
+                        ).to(dtype=p.dtype, device="cpu")
+                        del ref_cpu, decompressed_cpu
+                        if isinstance(p, DT):
+                            # DT case: keep CPU tensor for distribute_tensor below
+                            full_grad_src = full_grad_cpu
+                        else:
+                            # Non‑DT: apply update in chunks to avoid large GPU alloc
+                            alpha = self.lr * self.hparams.eval_lr_factor
+                            flat_cpu = full_grad_cpu.view(-1)
+                            flat_param = p.data.view(-1)
+                            elem_size = flat_cpu.element_size()
+                            CHUNK_BYTES = 64 * 1024 * 1024  # ~64MB per copy
+                            chunk_elems = max(1, CHUNK_BYTES // elem_size)
+                            for start in range(0, flat_cpu.numel(), chunk_elems):
+                                end = min(start + chunk_elems, flat_cpu.numel())
+                                chunk = flat_cpu[start:end].to(p.device, non_blocking=True)
+                                flat_param[start:end].sub_(chunk, alpha=alpha)
+                                del chunk
+                            del full_grad_cpu
+                            applied_update_eagerly = True

And skip the direct non‑DT subtract when the CPU fallback already applied:

-                else:
-                    # Single GPU case (non-DTensor)
-                    if on_src:
-                        p.data.sub_(
-                            full_grad_src,
-                            alpha=self.lr * self.hparams.eval_lr_factor,
-                        )
-                        del full_grad_src
+                else:
+                    # Single GPU case (non-DTensor)
+                    if on_src and not applied_update_eagerly:
+                        p.data.sub_(
+                            full_grad_src,
+                            alpha=self.lr * self.hparams.eval_lr_factor,
+                        )
+                        del full_grad_src

Also remove the per‑param empty_cache() calls in this hot path—they hurt perf without preventing fragmentation.

-                    # Force immediate cleanup
-                    torch.cuda.empty_cache()
@@
-                    # Force cleanup of large tensors
-                    torch.cuda.empty_cache()

Outside this range, add applied_update_eagerly = False at the start of each param iteration (see note below).

Also applies to: 3237-3243

🧹 Nitpick comments (5)
neurons/validator.py (5)

68-82: Don't set CUDA/NCCL env vars after importing torch; drop risky defaults and avoid hard‑coding NIC.

  • PYTORCH_CUDA_ALLOC_CONF is read when CUDA is initialized; here torch and torch.cuda are already imported/used, so this has no effect.
  • For NCCL, forcing IB off, P2P off, and IFNAME=eth0 can silently degrade perf or break multi‑node setups.
  • Duplicate import of os.

Recommend: move env setup to the process entrypoint before importing torch/torch.distributed, and gate via config flags (off by default). Remove the hard‑coded NIC and leave NCCL envs to deployment.

Apply at least this minimal cleanup:

-# Set CUDA memory allocator configuration to prevent fragmentation
-import os
-if not os.environ.get('PYTORCH_CUDA_ALLOC_CONF'):
-    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
-
-# Set NCCL environment variables for better memory management and debugging
-if not os.environ.get('NCCL_DEBUG'):
-    os.environ['NCCL_DEBUG'] = 'WARN'  # Change to INFO for more verbose debugging
-if not os.environ.get('NCCL_IB_DISABLE'):
-    os.environ['NCCL_IB_DISABLE'] = '1'  # Disable InfiniBand to use Ethernet
-if not os.environ.get('NCCL_SOCKET_IFNAME'):
-    os.environ['NCCL_SOCKET_IFNAME'] = 'eth0'  # Use ethernet interface
-if not os.environ.get('NCCL_P2P_DISABLE'):
-    os.environ['NCCL_P2P_DISABLE'] = '1'  # Disable P2P to reduce memory pressure
+# NOTE: CUDA/NCCL env should be set before importing torch/initializing CUDA.
+# Move these to the launcher/entrypoint behind explicit flags.
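
A minimal sketch of the suggested relocation, assuming a dedicated launcher entrypoint (the flag name nccl_overrides and its default are illustrative):

import os

def apply_runtime_env(nccl_overrides: bool = False) -> None:
    # Must run before torch / torch.distributed are imported anywhere.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    if nccl_overrides:  # keep NCCL tuning opt-in; leave NIC selection to deployment
        os.environ.setdefault("NCCL_DEBUG", "WARN")

if __name__ == "__main__":
    apply_runtime_env()
    import torch  # noqa: E402 -- imported only after the env is configured
    # ... construct and run the Validator here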

1330-1360: Retry path claims “smaller batch size” but does not change it.

Second attempt calls the exact same loader; likely OOMs again.

Either actually rebuild the dataloader with a smaller micro_batch_size or remove the misleading comment. I can wire a temporary halved micro_bs retry if you want.
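
A rough sketch of a halved micro-batch retry (build_eval_loader is a hypothetical factory; the real loader construction lives elsewhere in the validator):

import torch

def load_batches_with_retry(build_eval_loader, micro_bs: int):
    try:
        return list(build_eval_loader(micro_bs))
    except torch.cuda.OutOfMemoryError:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Second attempt genuinely uses a smaller micro batch size.
        return list(build_eval_loader(max(1, micro_bs // 2)))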


3085-3094: Support var for CPU‑fallback flow control.

Initialize applied_update_eagerly = False per parameter so the non‑DT path can skip the second subtract after chunked CPU update.

Add just after has_valid_gradient = True:

applied_update_eagerly = False

3158-3159: Remove empty_cache() in tight per‑param loop.

This thrashes the allocator and slows everything down; it doesn’t free reserved blocks anyway.

Also applies to: 3211-3212


245-264: Minor: consolidate GC/empty_cache fences.

These pre/post fences are fine but repeated; consider a small helper (noop on CPU) to reduce duplication.
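
Something along these lines would do (a sketch; placement and name are up to you):

import gc
import torch

def cuda_cleanup(sync: bool = False) -> None:
    """Best-effort GPU memory fence; a no-op on CPU-only hosts."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        if sync:
            torch.cuda.synchronize()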

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3dbf768 and df9a90a.

📒 Files selected for processing (1)
  • neurons/validator.py (17 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
neurons/validator.py (4)
src/tplr/logging.py (2)
  • log_with_context (207-223)
  • log_with_context (290-309)
src/tplr/sharded_sampler.py (1)
  • set_window_uid (74-79)
neurons/trainer.py (1)
  • evaluate_model (340-392)
src/tplr/dcp_checkpoint.py (3)
  • save_local_async (222-268)
  • upload (271-404)
  • download_and_load (657-680)

Comment on lines +184 to +240
    def _check_memory_and_cleanup(self, context: str = "unknown") -> bool:
        """Check GPU memory usage and perform cleanup if needed.
        Returns True if memory is available, False if critically low.
        """
        if not torch.cuda.is_available():
            return True

        # Get memory info
        memory_allocated = torch.cuda.memory_allocated()
        memory_reserved = torch.cuda.memory_reserved()
        memory_total = torch.cuda.get_device_properties(0).total_memory

        # Calculate usage percentages
        allocated_pct = (memory_allocated / memory_total) * 100
        reserved_pct = (memory_reserved / memory_total) * 100

        # Log memory usage
        if self.is_master:
            tplr.log_with_context(
                level="debug",
                message=f"Memory check ({context}) - Allocated: {allocated_pct:.1f}% ({memory_allocated/1024**3:.2f}GB), "
                f"Reserved: {reserved_pct:.1f}% ({memory_reserved/1024**3:.2f}GB)",
                sync_window=self.sync_window,
                current_window=self.current_window,
            )

        # If allocated memory is over 85%, perform cleanup
        if allocated_pct > 85.0:
            if self.is_master:
                tplr.log_with_context(
                    level="warning",
                    message=f"High memory usage detected ({allocated_pct:.1f}%), performing cleanup",
                    sync_window=self.sync_window,
                    current_window=self.current_window,
                )
            torch.cuda.empty_cache()
            import gc
            gc.collect()
            torch.cuda.synchronize()

            # Check again after cleanup
            new_allocated = torch.cuda.memory_allocated()
            new_allocated_pct = (new_allocated / memory_total) * 100

            # If still critically low after cleanup, return False
            if new_allocated_pct > 90.0:
                if self.is_master:
                    tplr.log_with_context(
                        level="error",
                        message=f"Critical memory usage even after cleanup ({new_allocated_pct:.1f}%)",
                        sync_window=self.sync_window,
                        current_window=self.current_window,
                    )
                return False

        return True


⚠️ Potential issue

The GPU memory check always queries device 0; make it rank-aware for multi‑GPU setups to avoid false signals.

Current code queries device 0 regardless of local rank, skewing thresholds and skip logic.

Suggested fix:

-        memory_allocated = torch.cuda.memory_allocated()
-        memory_reserved = torch.cuda.memory_reserved()
-        memory_total = torch.cuda.get_device_properties(0).total_memory
+        # Use the current CUDA device (or self.device) to be rank-correct
+        device_idx = (
+            self.device.index
+            if isinstance(self.device, torch.device)
+            and self.device.type == "cuda"
+            and self.device.index is not None
+            else torch.cuda.current_device()
+        )
+        memory_allocated = torch.cuda.memory_allocated(device_idx)
+        memory_reserved = torch.cuda.memory_reserved(device_idx)
+        memory_total = torch.cuda.get_device_properties(device_idx).total_memory
🤖 Prompt for AI Agents
In neurons/validator.py around lines 184 to 240, the GPU memory checks always
query device 0 which misreports usage on multi-GPU setups; change the code to
target the active device (e.g., use torch.cuda.current_device() or
self.local_rank/self.device if available), call memory APIs with that device
(torch.cuda.memory_allocated(device), torch.cuda.memory_reserved(device),
torch.cuda.get_device_properties(device)), perform
torch.cuda.synchronize(device) when cleaning, and include the device id in log
messages; also add a safe fallback (try/except) to device 0 if determining the
active device fails.

Comment on lines +1318 to +1328
            # Check memory before baseline evaluation
            if not self._check_memory_and_cleanup("before_baseline_eval"):
                if self.is_master:
                    tplr.log_with_context(
                        level="critical",
                        message="Insufficient memory for baseline evaluation. Skipping evaluation window.",
                        sync_window=self.sync_window,
                        current_window=self.current_window,
                    )
                continue


⚠️ Potential issue

Barrier/consensus needed when skipping baseline eval due to low memory.

A unilateral continue on one rank risks desync: other ranks will hit later barriers and deadlock.

Use an all‑reduce style consensus:

-            if not self._check_memory_and_cleanup("before_baseline_eval"):
-                if self.is_master:
-                    tplr.log_with_context(
-                        level="critical",
-                        message="Insufficient memory for baseline evaluation. Skipping evaluation window.",
-                        sync_window=self.sync_window,
-                        current_window=self.current_window,
-                    )
-                continue
+            mem_ok_local = self._check_memory_and_cleanup("before_baseline_eval")
+            mem_ok = dist_helper.all_ok(
+                mem_ok_local, self.device, tag="mem_ok_before_baseline"
+            )
+            if not mem_ok:
+                if self.is_master:
+                    tplr.log_with_context(
+                        level="critical",
+                        message="Insufficient memory for baseline evaluation on one or more ranks. Skipping this window.",
+                        sync_window=self.sync_window,
+                        current_window=self.current_window,
+                    )
+                continue
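
The dist_helper.all_ok call above assumes such a helper exists; if it does not, a minimal all-reduce consensus could look like this (a sketch, not the project's actual API):

import torch
import torch.distributed as dist

def all_ranks_ok(local_ok: bool, device: torch.device) -> bool:
    # Every rank must call this; a MIN-reduce yields 1 only if all ranks reported OK.
    if not (dist.is_available() and dist.is_initialized()):
        return local_ok
    flag = torch.tensor(1 if local_ok else 0, device=device, dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    return bool(flag.item())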

Comment on lines +1421 to +1432
                # Check memory before each UID evaluation
                if not self._check_memory_and_cleanup(f"before_uid_{eval_uid}"):
                    if self.is_master:
                        tplr.log_with_context(
                            level="warning",
                            message=f"Insufficient memory for UID {eval_uid} evaluation. Skipping this UID.",
                            sync_window=self.sync_window,
                            current_window=self.current_window,
                            eval_uid=eval_uid,
                        )
                    continue


⚠️ Potential issue

Per‑UID memory gate must be rank‑aligned to avoid broadcast/barrier hangs.

A single‑rank continue before gradient validity broadcast will deadlock others.

-                if not self._check_memory_and_cleanup(f"before_uid_{eval_uid}"):
+                mem_ok_local = self._check_memory_and_cleanup(f"before_uid_{eval_uid}")
+                mem_ok = dist_helper.all_ok(
+                    mem_ok_local, self.device, tag=f"mem_ok_uid_{eval_uid}"
+                )
+                if not mem_ok:
                     if self.is_master:
                         tplr.log_with_context(
                             level="warning",
                             message=f"Insufficient memory for UID {eval_uid} evaluation. Skipping this UID.",
                             sync_window=self.sync_window,
                             current_window=self.current_window,
                             eval_uid=eval_uid,
                         )
                     continue

Comment on lines +2368 to +2399
                try:
                    # Check memory before FSDP checkpoint save
                    self._check_memory_and_cleanup(operation="FSDP_checkpoint_save")

                    handle = await self.ckpt.save_local_async(
                        model=self.model,
                        window=self.sync_window,
                        sync_window=self.sync_window,
                        topology="FSDP",
                    )

                    # Schedule an upload that will wait for the save to finish, then upload in background
                    await self.ckpt.upload(
                        window=self.sync_window,
                        background=True,
                        delete_local_on_success=True,
                        wait_for=handle,
                    )
                    # Schedule an upload that will wait for the save to finish, then upload in background
                    await self.ckpt.upload(
                        window=self.sync_window,
                        background=True,
                        delete_local_on_success=True,
                        wait_for=handle,
                    )

                    # Clean up after checkpoint save
                    torch.cuda.empty_cache()

                except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                    if "CUDA out of memory" in str(e) or "NCCL" in str(e):
                        tplr.logger.warning(f"CUDA/NCCL error during FSDP checkpoint save at step {self.global_step}: {e}")
                        # Force memory cleanup and continue
                        torch.cuda.empty_cache()
                        if torch.cuda.is_available():
                            torch.cuda.synchronize()
                        tplr.logger.info("Continuing without checkpoint save due to memory/NCCL constraints")
                    else:
                        raise

⚠️ Potential issue

Wrong keyword passed to _check_memory_and_cleanup().

_check_memory_and_cleanup() takes context, not operation; this will raise TypeError at runtime.

-                    self._check_memory_and_cleanup(operation="FSDP_checkpoint_save")
+                    self._check_memory_and_cleanup(context="FSDP_checkpoint_save")
🤖 Prompt for AI Agents
In neurons/validator.py around lines 2368 to 2399, the call to
_check_memory_and_cleanup uses the wrong keyword argument name (operation) which
will raise a TypeError; change the call to use the expected keyword context,
e.g. self._check_memory_and_cleanup(context="FSDP_checkpoint_save"), leaving the
rest of the logic unchanged so memory is checked/cleaned before the FSDP
checkpoint save.
