fix race conditions in distributed checkpoint loading and prevent memory leaks in gradient gathering #664
Conversation
Walkthrough

Adds device-aware, no-grad aggregation with centralized per-UID validation and cleanup in comms; introduces process-group-aware barriers and retry-wrapped sidecar reads in distributed checkpoint download/load; and makes test environment variable assignments explicit.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Rank as Worker Rank (each)
    participant PG as ProcessGroup / Barrier
    participant Disk as Filesystem (sidecar)
    participant Loader as Model Loader
    Note over Rank,PG: download_distributed start
    Rank->>PG: barrier(start)
    PG-->>Rank: proceed
    Rank->>Disk: download/check for sidecar
    Disk-->>Rank: file or error
    alt sidecar read fails
        Rank->>Disk: retry read (asyncio.to_thread)
        Disk-->>Rank: parse error or success
    end
    Rank->>PG: barrier(after-download)
    PG-->>Rank: proceed
    Rank->>PG: barrier(before-load)
    PG-->>Rank: proceed
    Rank->>Loader: load checkpoint
    Loader-->>Rank: loaded
    Rank->>PG: barrier(after-load)
    PG-->>Rank: proceed
```
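For readers who prefer code, here is a minimal sketch of the barrier ordering in the diagram, assuming an initialized torch.distributed process group; `download_assigned_shards` and `load_checkpoint` are hypothetical stand-ins for the PR's actual functions, not its real API.

```python
from typing import Callable

import torch.distributed as dist


def download_and_load_sketch(
    download_assigned_shards: Callable[[], None],  # hypothetical: rank-local shard fetch
    load_checkpoint: Callable[[], None],           # hypothetical: DCP load
    process_group=None,
) -> None:
    dist.barrier(group=process_group)   # barrier(start): window agreed, local dir created
    download_assigned_shards()
    dist.barrier(group=process_group)   # barrier(after-download): all files visible to all ranks
    dist.barrier(group=process_group)   # barrier(before-load): enter the load phase together
    load_checkpoint()
    dist.barrier(group=process_group)   # barrier(after-load): leave the load phase together
```

The key point is that every barrier takes the same `group` argument, so ranks never synchronize against mismatched process groups.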
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/tplr/comms.py (2 hunks)
src/tplr/dcp_checkpoint.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (5)
src/tplr/logging.py (4)
  log_with_context (207-223), log_with_context (290-309), P (51-62), T (41-48)
tests/test_prepare_gradient_dict.py (1)
  info (45-46)
src/tplr/schemas.py (1)
  CommsGetResult (49-65)
tests/test_dcp_checkpoint.py (1)
  warning (117-118)
tests/conftest.py (1)
  totalks (141-144)
🔇 Additional comments (7)
src/tplr/comms.py (4)
1673-1676: LGTM – Explicit type annotations improve clarity.

The type annotations for aggregation state make the expected data structures clear and help catch type errors early.
1707-1708: LGTM – `torch.no_grad()` wrapper for memory efficiency.

Wrapping the aggregation loop in `no_grad()` is the right approach since gradient tracking is unnecessary during gather operations. This addresses the memory leak concern mentioned in the PR objectives.
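As a minimal illustration of the pattern (not the file's actual code), the aggregation loop simply runs under `torch.no_grad()` so nothing appended or converted inside it retains autograd history:

```python
import torch


def aggregate_no_grad(responses: list[dict[str, torch.Tensor]]) -> dict[str, list[torch.Tensor]]:
    aggregated: dict[str, list[torch.Tensor]] = {}
    with torch.no_grad():  # no computation graph is built or retained inside this block
        for resp in responses:
            for key, tensor in resp.items():
                aggregated.setdefault(key, []).append(tensor)
    return aggregated
```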
1884-1911: LGTM – Clean aggregation with explicit resource cleanup.

The aggregation correctly accumulates compressed tensors without unnecessary copies, and the explicit `del` statements drop per-UID references immediately after processing to prevent memory accumulation.
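Schematically (with illustrative names, not the module's real structures), the cleanup amounts to dropping each per-UID reference as soon as its tensors have been appended:

```python
import torch


def drain_batch(batch_responses: dict[str, dict[str, torch.Tensor]]) -> list[torch.Tensor]:
    gathered: list[torch.Tensor] = []
    for uid in list(batch_responses):
        state_dict_resp = batch_responses.pop(uid)  # detach the entry from the container
        gathered.extend(state_dict_resp.values())   # keep only what the caller needs
        del state_dict_resp                         # drop the per-UID reference immediately
    del batch_responses                             # release the (now empty) container itself
    return gathered
```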
1917-1918: LGTM – Explicit cleanup of batch responses.

Deleting `batch_responses` after processing helps with immediate memory reclamation rather than waiting for garbage collection, which is important for long-running training processes.

src/tplr/dcp_checkpoint.py (3)
756-758: LGTM – Initial barrier ensures window agreement across ranks.

The barrier before file enumeration ensures all ranks have agreed on the target window and created the local directory, preventing race conditions where some ranks might start downloading before others are ready.
878-880: LGTM – Consistent process group for barriers and load.

Passing `process_group` to `download_distributed` ensures that all barriers use the same process group as the subsequent `load()` call, which is essential for HSDP and other topologies with multiple process groups.
895-904: LGTM – Robust retry logic for sidecar reads.

Wrapping the file read in `asyncio.to_thread` prevents blocking the event loop, and the retry mechanism with 5 attempts handles transient filesystem visibility issues common on NFS/Lustre.
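A sketch of the retry idea under stated assumptions (the helper name, attempt count, and delay here are illustrative; the real implementation may differ):

```python
import asyncio
import json


async def read_sidecar_with_retry(path: str, attempts: int = 5, delay: float = 0.5) -> dict:
    """Read a JSON sidecar off the event loop, retrying transient failures (e.g., NFS lag)."""
    def _read() -> dict:
        with open(path) as f:
            return json.load(f)

    for attempt in range(attempts):
        try:
            return await asyncio.to_thread(_read)  # blocking I/O runs in a worker thread
        except (OSError, json.JSONDecodeError):
            if attempt == attempts - 1:
                raise                              # out of retries: surface the error
            await asyncio.sleep(delay)
    raise RuntimeError("unreachable")
```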
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/tplr/comms.py (1)
1834-1853: Fix device comparison to avoid always taking the `.to()` branch.

In the `vals` validation path, the current check:

```python
target_device = torch.device(device)

if tensor.device.type == target_device:
    tensor_to_check = tensor
    needs_delete = False
else:
    tensor_to_check = tensor.to(target_device, non_blocking=True)
    needs_delete = True
```

compares `tensor.device.type` (a string like `"cuda"`) to a `torch.device` object, so it will always take the `else` branch and call `.to(...)`, even when the tensor is already on the requested device. This defeats the goal of avoiding unnecessary copies and makes `needs_delete` unreliable.

A more robust pattern:

```diff
- target_device = torch.device(device)
-
- if tensor.device.type == target_device:
-     tensor_to_check = tensor
-     needs_delete = False
- else:
-     tensor_to_check = tensor.to(target_device, non_blocking=True)
-     needs_delete = True
+ target_device = torch.device(device)
+
+ if tensor.device == target_device or tensor.device.type == target_device.type:
+     tensor_to_check = tensor
+     needs_delete = False
+ else:
+     tensor_to_check = tensor.to(target_device, non_blocking=True)
+     # only treat as temporary if a new tensor was created
+     needs_delete = tensor_to_check is not tensor
```

This keeps the NaN/Inf check device-aware without forcing redundant `.to()` calls when the tensor is already on the right device and avoids deleting aliases unnecessarily.
🧹 Nitpick comments (1)
src/tplr/dcp_checkpoint.py (1)
755-792: Align world/rank computation with the passed process_group
`download_distributed` now takes a `process_group`, but `world, r` are still derived from the default group via `_world()`/`_rank()` with no arguments. If you ever call this with a non-default PG (e.g., shard PG), assignment of `assigned` shards and logging will use the wrong `world`/`r`.

Consider threading `process_group` through:

```diff
- world, r = _world(), _rank()
+ world, r = _world(process_group), _rank(process_group)
```

This keeps both the work partitioning and the barriers consistently scoped to the same group.
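For reference, group-aware variants of such helpers just forward the group to torch.distributed, whose `get_world_size` and `get_rank` both accept an optional process group (a sketch; the module's real `_world`/`_rank` may differ):

```python
import torch.distributed as dist


def _world(group=None) -> int:
    # group=None means the default (global) process group
    return dist.get_world_size(group=group) if dist.is_initialized() else 1


def _rank(group=None) -> int:
    # returns the caller's rank within the given group, not the global rank
    return dist.get_rank(group=group) if dist.is_initialized() else 0
```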
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/tplr/comms.py (2 hunks)
src/tplr/dcp_checkpoint.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/logging.py (4)
  log_with_context (207-223), log_with_context (290-309), P (51-62), T (41-48)
src/tplr/schemas.py (1)
  CommsGetResult (49-65)
tests/conftest.py (1)
  totalks (141-144)
🔇 Additional comments (1)
src/tplr/comms.py (1)
1673-1925: Gather refactor (no_grad + per-UID validation + cleanup) looks solid.

The restructure around `torch.no_grad()`, explicit `aggregated_state_dict`/`valid_uids`/`skipped_uids`, and aggressive `del` of `state_dict_resp`, `response`, and `batch_responses` after aggregation is a nice improvement for both correctness (no autograd graphs) and memory footprint over long runs. Validation of `quant_params`, `idxs`, and `vals` before aggregation also reads coherent and defensive.

No issues spotted with control flow or early-return behavior here.
Actionable comments posted: 2
♻️ Duplicate comments (1)
src/tplr/dcp_checkpoint.py (1)
914-919: Fix `all_reduce` device and group parameters to prevent NCCL errors and deadlocks.

The sidecar read synchronization has two issues previously flagged:

- `ok_tensor` on CPU will fail with the NCCL backend (NCCL doesn't support CPU tensors).
- Missing `group=process_group` causes misalignment with other barriers when `download_and_load` is called on a subgroup.

Apply this diff:

```diff
 if dist.is_available() and dist.is_initialized():
     ok_tensor = torch.tensor([1 if read_ok else 0], dtype=torch.int32)
-    # always keep on CPU for safety
-    dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN)
+    # Use GPU when available (NCCL requires it), CPU is fine for Gloo
+    if torch.cuda.is_available():
+        ok_tensor = ok_tensor.to(torch.cuda.current_device())
+    dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN, group=process_group)
     read_ok = bool(ok_tensor.item())
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/tplr/comms.py (2 hunks)
src/tplr/dcp_checkpoint.py (3 hunks)
tests/conftest.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (5)
src/tplr/logging.py (5)
  log_with_context (207-223), log_with_context (290-309), debug (106-110), P (51-62), T (41-48)
tests/test_prepare_gradient_dict.py (1)
  info (45-46)
src/tplr/schemas.py (1)
  CommsGetResult (49-65)
tests/test_dcp_checkpoint.py (1)
  warning (117-118)
tests/conftest.py (1)
  totalks (141-144)
🔇 Additional comments (10)
tests/conftest.py (1)
12-33: Intentional override of environment variables for test reliability.

The switch from `setdefault()` to direct assignment correctly addresses the case where CI might set empty strings for these variables. This ensures mock values are consistently used during tests.

One consideration: developers with real R2 credentials in their shell environment will have them overridden during test runs. This is typically acceptable for test isolation, but worth noting in case debugging real S3 connectivity is needed.
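The behavioral difference in miniature (the variable name here is a placeholder, not necessarily one the suite uses):

```python
import os

os.environ["R2_ACCOUNT_ID"] = ""                      # e.g., CI exports an empty string
os.environ.setdefault("R2_ACCOUNT_ID", "mock-value")  # no-op: the key already exists
assert os.environ["R2_ACCOUNT_ID"] == ""              # empty value leaks into tests

os.environ["R2_ACCOUNT_ID"] = "mock-value"            # direct assignment always wins
assert os.environ["R2_ACCOUNT_ID"] == "mock-value"
```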
src/tplr/dcp_checkpoint.py (3)
756-758: Initial barrier ensures consistent state before download.

This barrier correctly synchronizes all ranks before proceeding with the download, ensuring they've all created `local_dir` and agreed on the window. Good addition for shared filesystem scenarios.
876-881: Correct propagation of process_group to download_distributed.

Passing `process_group` ensures that all barriers within `download_distributed` operate on the same group as the caller, preventing cross-group deadlock scenarios.
936-942: Pre/post load barriers correctly synchronize DCP load phase.

The barriers ensure all ranks enter and exit the load phase together, which is essential for consistent distributed state.

One edge case: if `load_local()` raises an exception on a subset of ranks, the remaining ranks will hang at the post-load barrier (line 942). Consider wrapping in try/except with failure synchronization similar to the sidecar read pattern if load failures are expected.
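A sketch of the failure-synchronization idea suggested above, mirroring the sidecar read pattern (`load_fn` is a placeholder for the actual load call, not the module's API):

```python
from typing import Callable

import torch
import torch.distributed as dist


def load_with_failure_sync(load_fn: Callable[[], None], process_group=None) -> None:
    ok = 1
    try:
        load_fn()
    except Exception:
        ok = 0  # record the local failure instead of letting peers hang at the barrier
    flag = torch.tensor([ok], dtype=torch.int32)
    if torch.cuda.is_available():
        flag = flag.to(torch.cuda.current_device())  # NCCL requires device tensors
    dist.all_reduce(flag, op=dist.ReduceOp.MIN, group=process_group)
    if flag.item() == 0:
        raise RuntimeError("checkpoint load failed on at least one rank")
```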
src/tplr/comms.py (6)

1673-1676: Good: Explicit type annotations improve code clarity.

The type annotations for `aggregated_state_dict`, `valid_uids`, `skipped_uids`, and `global_steps` make the data flow clearer and help with IDE support.
1707-1708: Critical memory fix: no_grad context prevents gradient graph accumulation.

Wrapping the aggregation loop in `torch.no_grad()` is essential for preventing memory leaks. Without it, operations on tensors could accumulate computation graphs, causing unbounded memory growth over long runs.
1759-1786: Quant params validation correctly handles both float and tensor scales.

The validation now properly checks:
- Finite shift value
- Float scale: finite, not too small (< 1e-12), not too large (> 1e4)
- Tensor scale: all finite, abs max not too small/large
- Lookup table: all finite values
The explicit parentheses fix the operator precedence issue from the past review.
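Schematically, the described rules look something like the following (a sketch of the checks as listed above, not the file's exact code or signatures):

```python
import math

import torch


def quant_params_valid(shift: float, scale: float | torch.Tensor, lookup: torch.Tensor) -> bool:
    if not math.isfinite(shift):                       # finite shift value
        return False
    if isinstance(scale, torch.Tensor):
        if not torch.isfinite(scale).all():            # tensor scale: all finite
            return False
        amax = scale.abs().max().item()
        if amax < 1e-12 or amax > 1e4:                 # abs max not too small/large
            return False
    else:
        if not math.isfinite(scale) or not (1e-12 <= abs(scale) <= 1e4):
            return False                               # float scale: finite and within bounds
    return bool(torch.isfinite(lookup).all())          # lookup table: all finite values
```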
1889-1917: Aggregation preserves tensors efficiently with proper cleanup.

The code correctly:
- Appends tensors directly without unnecessary copies
- Tracks download bytes for metrics
- Explicitly deletes per-UID references after aggregation
The tensors in `aggregated_state_dict` remain alive (as needed for the return value), but UID-local references are promptly released.
1923-1924: Proper cleanup of batch_responses list after processing.

Deleting `batch_responses` after the loop releases the list structure. Combined with the per-UID deletions (lines 1916-1917), this ensures no stale references persist after the gather operation.
1878-1886: Proper cleanup on validation failure prevents memory leaks from invalid responses.

The code correctly releases references for skipped UIDs before continuing, ensuring invalid responses don't accumulate in memory over the gather loop.
Description
This PR addresses two high-impact issues identified in the distributed training pipeline:
1. Race conditions in distributed checkpoint loading

Problem:
During download_distributed() and download_and_load(), ranks could begin reading checkpoint files (especially extra_metadata.json) before rank 0 finished downloading or performing mop-up.
On shared filesystems (NFS, Lustre, FUSE-S3), this caused intermittent failures, partial file reads, or corrupted loads.

Fixes:
- Added process-group-aware barriers before file enumeration, after download, and before/after the DCP load phase, so every rank enters each stage together.
- Moved sidecar (extra_metadata.json) reads off the event loop via asyncio.to_thread and wrapped them in retry logic for transient filesystem visibility issues.

Result:
Checkpoint downloads and model loads are now strictly ordered and fully race-free across all ranks.
2. Memory leaks in gradient gathering

Problem:
The previous implementation accumulated unnecessary tensor references: per-UID and batch responses stayed alive across the gather loop, and aggregation ran with autograd tracking enabled.
This led to GPU/CPU memory bloating and eventual OOM after many windows, especially on large models or high peer counts.

Fixes:
- Wrapped the aggregation loop in torch.no_grad() so no computation graphs are retained.
- Centralized per-UID validation and explicitly deleted per-UID and batch references (state_dict_resp, response, batch_responses) immediately after aggregation.

Result:
gather() now has stable memory usage across long training runs, preventing silent memory leaks.
Related Issue(s)
No related issues.
Contribution by Gittensor, learn more at https://gittensor.io/