CPU Fallback for Decompression on OOM #587
base: main
Conversation
Enhance memory management for CUDA and NCCL
Walkthrough

Introduces global CUDA/NCCL environment defaults, adds an internal GPU memory check/cleanup helper, and integrates proactive memory management across evaluation and checkpoint paths. Enhances OOM/NCCL error handling with retries/skip logic, adds pre/post-operation cache clears, and expands memory usage logging. Minor non-semantic code reshaping accompanies these changes.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Eval as Validator.evaluate_window
    participant GPU as Torch CUDA
    participant Log as Logger
    Eval->>Eval: _check_memory_and_cleanup("pre-baseline")
    alt Insufficient memory
        Eval->>Log: Skip baseline window
        Note right of Eval: Window evaluation skipped
    else Sufficient
        Eval->>Eval: run_baseline()
        Note over Eval: Wrapped in try/except CUDA OOM
        alt CUDA OOM on baseline
            Eval->>GPU: empty_cache + gc + sync
            Eval->>Eval: retry once
            alt OOM again
                Eval->>Log: Skip window due to OOM
            else Success
                Eval->>Log: Baseline complete
            end
        else Success
            Eval->>Log: Baseline complete
        end
        loop For each UID
            Eval->>Eval: _check_memory_and_cleanup("pre-uid")
            alt Insufficient memory
                Eval->>Log: Skip UID
            else Sufficient
                Eval->>Eval: evaluate_uid(uid)
                Eval->>GPU: empty_cache + gc + sync (post-UID)
                Eval->>Log: Report mem usage
            end
        end
        Eval->>GPU: empty_cache + gc + sync (post-window)
    end
```
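The baseline retry shown above is essentially a try/except wrapper around the baseline pass. A minimal sketch of that pattern, assuming the baseline can be invoked as a plain callable (`run_with_oom_retry` and its arguments are illustrative, not the validator's actual API):

```python
import gc
import torch

def run_with_oom_retry(fn, *args, **kwargs):
    """Run `fn`; on a CUDA OOM, clear caches and retry exactly once."""
    try:
        return fn(*args, **kwargs)
    except torch.cuda.OutOfMemoryError:
        # Release Python garbage and cached allocator blocks before the retry.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        # A second OOM propagates to the caller, which then skips the window.
        return fn(*args, **kwargs)
```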
```mermaid
sequenceDiagram
    autonumber
    participant V as Validator
    participant GPU as Torch CUDA
    participant IO as Checkpoint I/O
    participant Log as Logger
    rect rgba(200,230,255,0.3)
        Note over V,GPU: Checkpoint Load
        V->>GPU: empty_cache + gc
        V->>V: _check_memory_and_cleanup("pre-load")
        alt Insufficient memory
            V->>Log: Abort load (insufficient memory)
        else Sufficient
            V->>IO: load_checkpoint()
            alt OOM/NCCL error
                V->>GPU: empty_cache + gc + sync
                V->>IO: retry load once
                alt Fails again
                    V->>Log: Load failed, escalate/skip
                else Success
                    V->>Log: Load succeeded on retry
                end
            else Success
                V->>Log: Load succeeded
            end
            V->>GPU: empty_cache + gc (post-load)
        end
    end
    rect rgba(220,255,220,0.3)
        Note over V,GPU: Checkpoint Save
        V->>V: _check_memory_and_cleanup("pre-save")
        alt Insufficient memory
            V->>Log: Skip/Delay save
        else Sufficient
            V->>IO: save_checkpoint()
            alt OOM/NCCL error
                V->>GPU: empty_cache + gc + sync
                V->>Log: Save failed, continue safely
            else Success
                V->>Log: Save completed
            end
        end
    end
```
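The checkpoint-load path follows the same retry-once shape, but additionally treats NCCL failures as retryable. A sketch under the assumption that the loader is a zero-argument callable (`load_fn`, `logger`, and `max_attempts` are placeholder names, not the repository's actual API):

```python
import gc
import torch

def load_checkpoint_with_retry(load_fn, logger, max_attempts: int = 2):
    """Attempt a checkpoint load; on CUDA OOM or NCCL errors, clean up and retry once."""
    for attempt in range(1, max_attempts + 1):
        try:
            return load_fn()
        except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
            msg = str(e)
            if "out of memory" not in msg.lower() and "NCCL" not in msg:
                raise  # unrelated RuntimeError: do not swallow it
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            if attempt == max_attempts:
                logger.error(f"Checkpoint load failed after {max_attempts} attempts: {e}")
                return None
            logger.warning(f"Checkpoint load error ({e}); retrying once after cleanup")
```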
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings, 1 inconclusive)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
neurons/validator.py (2)
1697-1697: Guard torch.cuda.empty_cache() when CUDA is unavailable.

Direct calls will raise on CPU‑only runs.
```diff
-torch.cuda.empty_cache()
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
```

Apply similarly to the other occurrences in these ranges.
Also applies to: 1811-1811, 2351-2351, 2407-2407
3135-3159: Missing core PR objective: add CPU fallback for gradient decompression on OOM (and avoid per‑param empty_cache).

Currently decompress/transform always run on the GPU; on OOM you only clear the cache. Implement a CPU fallback with a chunked update for non‑DT tensors and a CPU source for DT distribution.
```diff
-# Check if we're in distributed mode
-# Use empty_like to avoid copying the param; just provide dtype/device/shape
-ref = torch.empty_like(p, device=self.device, dtype=p.dtype)
-
-decompressed = self.compressor.decompress(
-    ref,
-    idxs,
-    vals,
-    self.xshapes[n],
-    self.totalks[n],
-    quant_params,
-)
-
-full_grad_src = self.transformer.decode(
-    decompressed, use_dct=self.hparams.use_dct
-)
-# Single conversion to target dtype+device to avoid extra temporaries
-full_grad_src = full_grad_src.to(
-    dtype=p.dtype, device=p.device, non_blocking=True
-)
-
-# Free intermediate pieces ASAP
-del ref, decompressed
-# Force immediate cleanup
-torch.cuda.empty_cache()
+# Try GPU path first, then CPU fallback on OOM
+applied_update_eagerly = False  # non-DT CPU fallback updates in-place
+try:
+    ref = torch.empty_like(p, device=p.device, dtype=p.dtype)
+    decompressed = self.compressor.decompress(
+        ref, idxs, vals, self.xshapes[n], self.totalks[n], quant_params
+    )
+    full_grad_src = self.transformer.decode(
+        decompressed, use_dct=self.hparams.use_dct
+    ).to(dtype=p.dtype, device=p.device, non_blocking=True)
+    del ref, decompressed
+except (torch.cuda.OutOfMemoryError, RuntimeError) as oom:
+    if "out of memory" not in str(oom).lower():
+        raise
+    if self.is_master:
+        tplr.log_with_context(
+            level="warning",
+            message=f"GPU OOM while decompressing {n}; falling back to CPU.",
+            sync_window=self.sync_window,
+            current_window=self.current_window,
+            eval_uid=eval_uid,
+        )
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    # CPU fallback: decompress+decode on CPU
+    ref_cpu = torch.empty_like(p, device="cpu", dtype=p.dtype)
+    decompressed_cpu = self.compressor.decompress(
+        ref_cpu, idxs, vals, self.xshapes[n], self.totalks[n], quant_params
+    )
+    full_grad_cpu = self.transformer.decode(
+        decompressed_cpu, use_dct=self.hparams.use_dct
+    ).to(dtype=p.dtype, device="cpu")
+    del ref_cpu, decompressed_cpu
+    if isinstance(p, DT):
+        # DT case: keep CPU tensor for distribute_tensor below
+        full_grad_src = full_grad_cpu
+    else:
+        # Non-DT: apply update in chunks to avoid large GPU alloc
+        alpha = self.lr * self.hparams.eval_lr_factor
+        flat_cpu = full_grad_cpu.view(-1)
+        flat_param = p.data.view(-1)
+        elem_size = flat_cpu.element_size()
+        CHUNK_BYTES = 64 * 1024 * 1024  # ~64MB per copy
+        chunk_elems = max(1, CHUNK_BYTES // elem_size)
+        for start in range(0, flat_cpu.numel(), chunk_elems):
+            end = min(start + chunk_elems, flat_cpu.numel())
+            chunk = flat_cpu[start:end].to(p.device, non_blocking=True)
+            flat_param[start:end].sub_(chunk, alpha=alpha)
+            del chunk
+        del full_grad_cpu
+        applied_update_eagerly = True
```

And skip the direct non‑DT subtract when the CPU fallback already applied:

```diff
-else:
-    # Single GPU case (non-DTensor)
-    if on_src:
-        p.data.sub_(
-            full_grad_src,
-            alpha=self.lr * self.hparams.eval_lr_factor,
-        )
-    del full_grad_src
+else:
+    # Single GPU case (non-DTensor)
+    if on_src and not applied_update_eagerly:
+        p.data.sub_(
+            full_grad_src,
+            alpha=self.lr * self.hparams.eval_lr_factor,
+        )
+    del full_grad_src
```

Also remove the per‑param `empty_cache()` calls in this hot path; they hurt perf without preventing fragmentation.

```diff
-# Force immediate cleanup
-torch.cuda.empty_cache()
@@
-# Force cleanup of large tensors
-torch.cuda.empty_cache()
```

Outside this range, add `applied_update_eagerly = False` at the start of each param iteration (see note below).

Also applies to: 3237-3243
🧹 Nitpick comments (5)
neurons/validator.py (5)
68-82: Don't set CUDA/NCCL env vars after importing torch; drop risky defaults and avoid hard‑coding NIC.
- PYTORCH_CUDA_ALLOC_CONF is read when CUDA is initialized; here torch and torch.cuda are already imported/used, so this has no effect.
- For NCCL, forcing IB off, P2P off, and IFNAME=eth0 can silently degrade perf or break multi‑node setups.
- Duplicate import of os.
Recommend: move env setup to the process entrypoint before importing torch/torch.distributed, and gate via config flags (off by default). Remove the hard‑coded NIC and leave NCCL envs to deployment.
Apply at least this minimal cleanup:
```diff
-# Set CUDA memory allocator configuration to prevent fragmentation
-import os
-if not os.environ.get('PYTORCH_CUDA_ALLOC_CONF'):
-    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
-
-# Set NCCL environment variables for better memory management and debugging
-if not os.environ.get('NCCL_DEBUG'):
-    os.environ['NCCL_DEBUG'] = 'WARN'  # Change to INFO for more verbose debugging
-if not os.environ.get('NCCL_IB_DISABLE'):
-    os.environ['NCCL_IB_DISABLE'] = '1'  # Disable InfiniBand to use Ethernet
-if not os.environ.get('NCCL_SOCKET_IFNAME'):
-    os.environ['NCCL_SOCKET_IFNAME'] = 'eth0'  # Use ethernet interface
-if not os.environ.get('NCCL_P2P_DISABLE'):
-    os.environ['NCCL_P2P_DISABLE'] = '1'  # Disable P2P to reduce memory pressure
+# NOTE: CUDA/NCCL env should be set before importing torch/initializing CUDA.
+# Move these to the launcher/entrypoint behind explicit flags.
```
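For illustration, entrypoint-side setup could look roughly like this; it is a sketch that assumes a `--expandable-segments` style flag and an importable `main()` in the validator module, neither of which is confirmed by this repository:

```python
# launcher.py -- hypothetical entrypoint; must run before `import torch`
import argparse
import os

def configure_runtime_env(enable_expandable_segments: bool) -> None:
    # Only takes effect if set before CUDA is initialized, i.e. before torch is imported.
    if enable_expandable_segments:
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    # Leave NCCL tuning (IB, P2P, SOCKET_IFNAME) to the deployment environment.

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--expandable-segments", action="store_true")
    args = parser.parse_args()
    configure_runtime_env(args.expandable_segments)

    # Import torch (and the validator) only after the environment is prepared.
    from neurons.validator import main  # hypothetical import path
    main()
```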
1330-1360: Retry path claims "smaller batch size" but does not change it.

The second attempt calls the exact same loader and will likely OOM again.
Either actually rebuild the dataloader with a smaller micro_batch_size or remove the misleading comment. I can wire a temporary halved micro_bs retry if you want.
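A halved-batch retry could look roughly like this; `make_loader` and `run_eval` are hypothetical stand-ins for the actual dataloader factory and evaluation call, so treat it as a sketch rather than a drop-in patch:

```python
import torch

def run_eval_with_batch_fallback(make_loader, run_eval, micro_batch_size: int):
    """Try evaluation at the configured micro-batch size; on CUDA OOM,
    rebuild the loader at half the size and retry once."""
    for bs in (micro_batch_size, max(1, micro_batch_size // 2)):
        loader = make_loader(batch_size=bs)
        try:
            return run_eval(loader)
        except torch.cuda.OutOfMemoryError:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if bs == 1:
                raise  # nothing smaller to try
    raise RuntimeError("evaluation failed even after halving the micro-batch size")
```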
3085-3094: Support var for CPU‑fallback flow control.

Initialize `applied_update_eagerly = False` per parameter so the non‑DT path can skip the second subtract after the chunked CPU update.

Add just after `has_valid_gradient = True`:

```python
applied_update_eagerly = False
```
3158-3159: Remove empty_cache() in tight per‑param loop.

This thrashes the allocator and slows everything down; it doesn't free reserved blocks anyway.
Also applies to: 3211-3212
245-264: Minor: consolidate GC/empty_cache fences.

These pre/post fences are fine but repeated; consider a small helper (a no-op on CPU) to reduce duplication, along the lines sketched below.
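A sketch of such a helper; the method name `_cuda_cleanup` is a suggestion, not existing code:

```python
import gc
import torch

def _cuda_cleanup(self, synchronize: bool = False) -> None:
    """Consolidated GC/empty_cache fence; a no-op on CPU-only runs."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        if synchronize:
            torch.cuda.synchronize()
```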
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
neurons/validator.py (17 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
neurons/validator.py (4)
src/tplr/logging.py (2)
- log_with_context (207-223)
- log_with_context (290-309)

src/tplr/sharded_sampler.py (1)
- set_window_uid (74-79)

neurons/trainer.py (1)
- evaluate_model (340-392)

src/tplr/dcp_checkpoint.py (3)
- save_local_async (222-268)
- upload (271-404)
- download_and_load (657-680)
```python
def _check_memory_and_cleanup(self, context: str = "unknown") -> bool:
    """Check GPU memory usage and perform cleanup if needed.

    Returns True if memory is available, False if critically low.
    """
    if not torch.cuda.is_available():
        return True

    # Get memory info
    memory_allocated = torch.cuda.memory_allocated()
    memory_reserved = torch.cuda.memory_reserved()
    memory_total = torch.cuda.get_device_properties(0).total_memory

    # Calculate usage percentages
    allocated_pct = (memory_allocated / memory_total) * 100
    reserved_pct = (memory_reserved / memory_total) * 100

    # Log memory usage
    if self.is_master:
        tplr.log_with_context(
            level="debug",
            message=f"Memory check ({context}) - Allocated: {allocated_pct:.1f}% ({memory_allocated/1024**3:.2f}GB), "
            f"Reserved: {reserved_pct:.1f}% ({memory_reserved/1024**3:.2f}GB)",
            sync_window=self.sync_window,
            current_window=self.current_window,
        )

    # If allocated memory is over 85%, perform cleanup
    if allocated_pct > 85.0:
        if self.is_master:
            tplr.log_with_context(
                level="warning",
                message=f"High memory usage detected ({allocated_pct:.1f}%), performing cleanup",
                sync_window=self.sync_window,
                current_window=self.current_window,
            )
        torch.cuda.empty_cache()
        import gc
        gc.collect()
        torch.cuda.synchronize()

        # Check again after cleanup
        new_allocated = torch.cuda.memory_allocated()
        new_allocated_pct = (new_allocated / memory_total) * 100

        # If still critically low after cleanup, return False
        if new_allocated_pct > 90.0:
            if self.is_master:
                tplr.log_with_context(
                    level="error",
                    message=f"Critical memory usage even after cleanup ({new_allocated_pct:.1f}%)",
                    sync_window=self.sync_window,
                    current_window=self.current_window,
                )
            return False

    return True
```
GPU memory check uses device 0; fix for multi‑GPU and avoid false signals.
Current code queries device 0 regardless of local rank, skewing thresholds and skip logic.
Suggested fix:
```diff
-memory_allocated = torch.cuda.memory_allocated()
-memory_reserved = torch.cuda.memory_reserved()
-memory_total = torch.cuda.get_device_properties(0).total_memory
+# Use the current CUDA device (or self.device) to be rank-correct
+device_idx = (
+    self.device.index
+    if isinstance(self.device, torch.device)
+    and self.device.type == "cuda"
+    and self.device.index is not None
+    else torch.cuda.current_device()
+)
+memory_allocated = torch.cuda.memory_allocated(device_idx)
+memory_reserved = torch.cuda.memory_reserved(device_idx)
+memory_total = torch.cuda.get_device_properties(device_idx).total_memory
```

📝 Committable suggestion
```python
def _check_memory_and_cleanup(self, context: str = "unknown") -> bool:
    """Check GPU memory usage and perform cleanup if needed.

    Returns True if memory is available, False if critically low.
    """
    if not torch.cuda.is_available():
        return True

    # Use the current CUDA device (or self.device) to be rank-correct
    device_idx = (
        self.device.index
        if isinstance(self.device, torch.device)
        and self.device.type == "cuda"
        and self.device.index is not None
        else torch.cuda.current_device()
    )

    # Get memory info
    memory_allocated = torch.cuda.memory_allocated(device_idx)
    memory_reserved = torch.cuda.memory_reserved(device_idx)
    memory_total = torch.cuda.get_device_properties(device_idx).total_memory

    # Calculate usage percentages
    allocated_pct = (memory_allocated / memory_total) * 100
    reserved_pct = (memory_reserved / memory_total) * 100

    # Log memory usage
    if self.is_master:
        tplr.log_with_context(
            level="debug",
            message=f"Memory check ({context}) - Allocated: {allocated_pct:.1f}% ({memory_allocated/1024**3:.2f}GB), "
            f"Reserved: {reserved_pct:.1f}% ({memory_reserved/1024**3:.2f}GB)",
            sync_window=self.sync_window,
            current_window=self.current_window,
        )

    # If allocated memory is over 85%, perform cleanup
    if allocated_pct > 85.0:
        if self.is_master:
            tplr.log_with_context(
                level="warning",
                message=f"High memory usage detected ({allocated_pct:.1f}%), performing cleanup",
                sync_window=self.sync_window,
                current_window=self.current_window,
            )
        torch.cuda.empty_cache()
        import gc
        gc.collect()
        torch.cuda.synchronize()

        # Check again after cleanup
        new_allocated = torch.cuda.memory_allocated(device_idx)
        new_allocated_pct = (new_allocated / memory_total) * 100

        # If still critically low after cleanup, return False
        if new_allocated_pct > 90.0:
            if self.is_master:
                tplr.log_with_context(
                    level="error",
                    message=f"Critical memory usage even after cleanup ({new_allocated_pct:.1f}%)",
                    sync_window=self.sync_window,
                    current_window=self.current_window,
                )
            return False

    return True
```
🤖 Prompt for AI Agents
In neurons/validator.py around lines 184 to 240, the GPU memory checks always
query device 0 which misreports usage on multi-GPU setups; change the code to
target the active device (e.g., use torch.cuda.current_device() or
self.local_rank/self.device if available), call memory APIs with that device
(torch.cuda.memory_allocated(device), torch.cuda.memory_reserved(device),
torch.cuda.get_device_properties(device)), perform
torch.cuda.synchronize(device) when cleaning, and include the device id in log
messages; also add a safe fallback (try/except) to device 0 if determining the
active device fails.
```python
# Check memory before baseline evaluation
if not self._check_memory_and_cleanup("before_baseline_eval"):
    if self.is_master:
        tplr.log_with_context(
            level="critical",
            message="Insufficient memory for baseline evaluation. Skipping evaluation window.",
            sync_window=self.sync_window,
            current_window=self.current_window,
        )
    continue
```
Barrier/consensus needed when skipping baseline eval due to low memory.
A unilateral `continue` risks desync: other ranks will hit later barriers and deadlock.
Use an all‑reduce style consensus:
```diff
-if not self._check_memory_and_cleanup("before_baseline_eval"):
-    if self.is_master:
-        tplr.log_with_context(
-            level="critical",
-            message="Insufficient memory for baseline evaluation. Skipping evaluation window.",
-            sync_window=self.sync_window,
-            current_window=self.current_window,
-        )
-    continue
+mem_ok_local = self._check_memory_and_cleanup("before_baseline_eval")
+mem_ok = dist_helper.all_ok(
+    mem_ok_local, self.device, tag="mem_ok_before_baseline"
+)
+if not mem_ok:
+    if self.is_master:
+        tplr.log_with_context(
+            level="critical",
+            message="Insufficient memory for baseline evaluation on one or more ranks. Skipping this window.",
+            sync_window=self.sync_window,
+            current_window=self.current_window,
+        )
+    continue
```

📝 Committable suggestion
```python
# Check memory before baseline evaluation
mem_ok_local = self._check_memory_and_cleanup("before_baseline_eval")
mem_ok = dist_helper.all_ok(
    mem_ok_local, self.device, tag="mem_ok_before_baseline"
)
if not mem_ok:
    if self.is_master:
        tplr.log_with_context(
            level="critical",
            message="Insufficient memory for baseline evaluation on one or more ranks. Skipping this window.",
            sync_window=self.sync_window,
            current_window=self.current_window,
        )
    continue
```
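The `dist_helper.all_ok` call above assumes a helper that turns per-rank booleans into a global verdict. If no such helper exists yet, a minimal sketch built on `torch.distributed.all_reduce` with a MIN reduction could look like this (the `tag` argument is accepted only to mirror the call sites above and is unused here):

```python
import torch
import torch.distributed as dist

def all_ok(local_ok: bool, device: torch.device, tag: str = "") -> bool:
    """Return True only if every rank reports local_ok=True.

    Falls back to the local value when the process group is not initialized,
    so single-process runs behave unchanged.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return local_ok
    flag = torch.tensor([1 if local_ok else 0], dtype=torch.int32, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)  # any rank reporting 0 drags the result to 0
    return bool(flag.item())
```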
```python
# Check memory before each UID evaluation
if not self._check_memory_and_cleanup(f"before_uid_{eval_uid}"):
    if self.is_master:
        tplr.log_with_context(
            level="warning",
            message=f"Insufficient memory for UID {eval_uid} evaluation. Skipping this UID.",
            sync_window=self.sync_window,
            current_window=self.current_window,
            eval_uid=eval_uid,
        )
    continue
```
Per‑UID memory gate must be rank‑aligned to avoid broadcast/barrier hangs.
A single‑rank continue before gradient validity broadcast will deadlock others.
```diff
-if not self._check_memory_and_cleanup(f"before_uid_{eval_uid}"):
+mem_ok_local = self._check_memory_and_cleanup(f"before_uid_{eval_uid}")
+mem_ok = dist_helper.all_ok(
+    mem_ok_local, self.device, tag=f"mem_ok_uid_{eval_uid}"
+)
+if not mem_ok:
     if self.is_master:
         tplr.log_with_context(
             level="warning",
             message=f"Insufficient memory for UID {eval_uid} evaluation. Skipping this UID.",
             sync_window=self.sync_window,
             current_window=self.current_window,
             eval_uid=eval_uid,
         )
     continue
```

📝 Committable suggestion
```python
# Check memory before each UID evaluation
mem_ok_local = self._check_memory_and_cleanup(f"before_uid_{eval_uid}")
mem_ok = dist_helper.all_ok(
    mem_ok_local, self.device, tag=f"mem_ok_uid_{eval_uid}"
)
if not mem_ok:
    if self.is_master:
        tplr.log_with_context(
            level="warning",
            message=f"Insufficient memory for UID {eval_uid} evaluation. Skipping this UID.",
            sync_window=self.sync_window,
            current_window=self.current_window,
            eval_uid=eval_uid,
        )
    continue
```
```python
try:
    # Check memory before FSDP checkpoint save
    self._check_memory_and_cleanup(operation="FSDP_checkpoint_save")

    handle = await self.ckpt.save_local_async(
        model=self.model,
        window=self.sync_window,
        sync_window=self.sync_window,
        topology="FSDP",
    )

    # Schedule an upload that will wait for the save to finish, then upload in background
    await self.ckpt.upload(
        window=self.sync_window,
        background=True,
        delete_local_on_success=True,
        wait_for=handle,
    )

    # Clean up after checkpoint save
    torch.cuda.empty_cache()

except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
    if "CUDA out of memory" in str(e) or "NCCL" in str(e):
        tplr.logger.warning(f"CUDA/NCCL error during FSDP checkpoint save at step {self.global_step}: {e}")
        # Force memory cleanup and continue
        torch.cuda.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        tplr.logger.info("Continuing without checkpoint save due to memory/NCCL constraints")
    else:
        raise
```
Wrong keyword passed to _check_memory_and_cleanup().
_check_memory_and_cleanup() takes context, not operation; this will raise TypeError at runtime.
```diff
-    self._check_memory_and_cleanup(operation="FSDP_checkpoint_save")
+    self._check_memory_and_cleanup(context="FSDP_checkpoint_save")
```

📝 Committable suggestion
```python
try:
    # Check memory before FSDP checkpoint save
    self._check_memory_and_cleanup(context="FSDP_checkpoint_save")

    handle = await self.ckpt.save_local_async(
        model=self.model,
        window=self.sync_window,
        sync_window=self.sync_window,
        topology="FSDP",
    )

    # Schedule an upload that will wait for the save to finish, then upload in background
    await self.ckpt.upload(
        window=self.sync_window,
        background=True,
        delete_local_on_success=True,
        wait_for=handle,
    )

    # Clean up after checkpoint save
    torch.cuda.empty_cache()

except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
    if "CUDA out of memory" in str(e) or "NCCL" in str(e):
        tplr.logger.warning(f"CUDA/NCCL error during FSDP checkpoint save at step {self.global_step}: {e}")
        # Force memory cleanup and continue
        torch.cuda.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        tplr.logger.info("Continuing without checkpoint save due to memory/NCCL constraints")
    else:
        raise
```
🤖 Prompt for AI Agents
In neurons/validator.py around lines 2368 to 2399, the call to
_check_memory_and_cleanup uses the wrong keyword argument name (operation) which
will raise a TypeError; change the call to use the expected keyword context,
e.g. self._check_memory_and_cleanup(context="FSDP_checkpoint_save"), leaving the
rest of the logic unchanged so memory is checked/cleaned before the FSDP
checkpoint save.
Description
Adds an automatic CPU fallback when the GPU runs out of memory during gradient decompression for large parameters.
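At its core the change is a try-GPU-then-CPU pattern. A minimal, generic sketch of that shape (the `decompress` callable and its `device` keyword are placeholders, not the validator's actual compressor/transformer API):

```python
import torch

def decompress_with_cpu_fallback(decompress, payload, gpu: torch.device):
    """Try decompression on the GPU; on OOM, clear the cache and redo it on CPU."""
    try:
        return decompress(payload, device=gpu)
    except torch.cuda.OutOfMemoryError:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Slower CPU path, but the evaluation window survives instead of crashing.
        return decompress(payload, device=torch.device("cpu"))
```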
Related Issue(s)
Type of Change
Branch Naming
Commit Messages
Code Quality
Testing
Documentation
If this is a breaking change
Screenshots/Examples
Additional Notes
Summary by CodeRabbit
New Features
Bug Fixes
Chores