@SmartDever02 SmartDever02 commented Dec 2, 2025

Description

This PR addresses two high-impact issues identified in the distributed training pipeline:

  1. Race Condition in Distributed Checkpoint Download & Load

Problem:
During download_distributed() and download_and_load(), ranks could begin reading checkpoint files (especially extra_metadata.json) before rank 0 finished downloading or performing mop-up.
On shared filesystems (NFS, Lustre, FUSE-S3), this caused intermittent failures, partial file reads, or corrupted loads.

Fixes:

  • Added deterministic synchronization barriers before file reads and after rank-0 mop-up.
  • Passed the correct process_group into download_distributed() from download_and_load().
  • Added robust file-visibility retries for extra_metadata.json.
  • Added pre- and post-load barriers to ensure consistent DCP state across ranks.
  • Added optional initialization barrier in download_distributed() for clarity.

Result:
Checkpoint downloads and model loads are now strictly ordered and fully race-free across all ranks.
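The ordering the barriers enforce can be illustrated with `threading.Barrier` as a local stand-in for `dist.barrier()` (the worker logic and event names here are illustrative only, not the PR's actual code):

```python
import threading

# Stand-in for torch.distributed.barrier(): every rank blocks until all arrive.
NUM_RANKS = 4
barrier = threading.Barrier(NUM_RANKS)
events: list[str] = []
lock = threading.Lock()

def worker(rank: int) -> None:
    if rank == 0:
        with lock:
            events.append("download")  # rank 0 downloads and performs mop-up
    barrier.wait()                     # post-download barrier: no rank reads early
    with lock:
        events.append(f"read-{rank}")  # now every rank may read extra_metadata.json
    barrier.wait()                     # post-load barrier before returning

threads = [threading.Thread(target=worker, args=(r,)) for r in range(NUM_RANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The download always precedes every read, regardless of thread scheduling.
assert events[0] == "download"
assert sorted(events[1:]) == sorted(f"read-{r}" for r in range(NUM_RANKS))
```

Without the first `barrier.wait()`, a non-zero rank could append its read before rank 0's download, which is exactly the race observed on shared filesystems.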

  2. Memory Leak / Long-run OOM in Gradient gather()

Problem:
The previous implementation accumulated unnecessary tensor references:

  • Temporary validation tensors (tensor.to(device))
  • Full state_dict_resp tensors per UID
  • Large batch_responses list retained until end of method

This led to steadily growing GPU/CPU memory usage and eventual OOM after many windows, especially on large models or high peer counts.

Fixes:

  • Wrapped validation and aggregation inside torch.no_grad().
  • Avoided unnecessary .to(device) copies.
  • Explicitly deleted tensor_to_check, state_dict_resp, response, and batch_responses.
  • Ensured UID-local references are dropped immediately after aggregation.
  • Preserved compressed tensors without duplicating them.

Result:
gather() now has stable memory usage across long training runs, preventing silent memory leaks.
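The no_grad-plus-explicit-del pattern described above can be sketched in toy form (the `aggregate` helper and the response shape are illustrative, not the PR's actual gather() code):

```python
import torch

def aggregate(responses):
    """Toy aggregation mirroring the gather() cleanup pattern (illustrative only)."""
    totals = []
    with torch.no_grad():                 # no autograd graph is ever built here
        for state_dict_resp in responses:
            vals = state_dict_resp["vals"]
            totals.append(vals.sum())
            del state_dict_resp, vals     # drop per-UID references immediately
    del responses                         # drop the batch list reference early
    return torch.stack(totals)

resp = [{"vals": torch.ones(3, requires_grad=True)} for _ in range(2)]
out = aggregate(resp)
assert out.grad_fn is None and not out.requires_grad  # no graph retained
```

Even though the inputs have `requires_grad=True`, everything computed under `torch.no_grad()` is detached, so no computation graph (and none of its intermediate tensors) survives the call.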

Related Issue(s)

No related issues.

Type of Change

  • Feature (adding new functionality)
  • Fix (resolving a bug or issue)
  • Docs (documentation updates)
  • Refactor (code changes that don't affect functionality)
  • Maintenance (dependency updates or other maintenance)
  • Tests (adding or improving tests)
  • Breaking change (fix or feature with incompatible API changes)
  • Other: _____

Branch Naming

  • My branch follows the project's naming convention (e.g., feature/add-new-capability)

Commit Messages

  • My commits are small, atomic, and have proper commit messages
  • Commit messages are in imperative mood with a capitalized summary under 50 chars

Code Quality

  • [x] I've performed a self-review of my code
  • I've added appropriate docstrings following the project's conventions
  • I've added proper logging where necessary (without trailing periods)
  • I've applied linting and formatting with Ruff
  • My code generates no new warnings

Testing

  • I've added tests for new functionality or bug fixes
  • All tests pass locally with my changes
  • Test coverage has not decreased

Documentation

  • I've updated documentation to reflect my changes
  • I've updated comments in hard-to-understand areas

Summary by CodeRabbit

  • Bug Fixes

    • Improved validation and error handling for per-peer aggregated updates with safer skip paths, device-aware checks, and clearer logging.
    • Ensured no-gradient processing during aggregation to avoid unintended gradient tracking.
    • Strengthened synchronization during distributed checkpoint download/load with extra barriers and deterministic retry handling for metadata reads.
  • Chores

    • Reduced memory retention after per-peer processing via explicit cleanup and metric refinements.
    • Made test environment setup explicit to ensure consistent CI configuration.


Contribution by Gittensor, learn more at https://gittensor.io/


coderabbitai bot commented Dec 2, 2025

Walkthrough

Adds device-aware, no-grad aggregation with centralized per-UID validation and cleanup in comms; introduces process-group-aware barriers and retry-wrapped sidecar reads in distributed checkpoint download/load; and makes test environment variable assignments explicit.

Changes

  • Aggregation hardening (src/tplr/comms.py): Added local type annotations; wrapped per-UID aggregation in torch.no_grad(); consolidated per-UID response validation/error handling; validated quant_params/idxs/vals (finite checks, 12-bit unpacking, device-aware NaN/Inf checks); updated aggregated_state_dict population and download-byte accounting; explicit per-UID and batch del cleanup; returns a SimpleNamespace with metrics, state_dict, valid_uids, skipped_uids.
  • Distributed checkpoint sync & robustness (src/tplr/dcp_checkpoint.py): download_distributed now accepts process_group and adds process-group-aware barriers at download start, after download/mop-up, before load, and after load; the extra_metadata.json read moved to a retry-wrapped asyncio.to_thread with JSON parse/error handling and deterministic failure returns.
  • Tests — env config (tests/conftest.py): Replaced setdefault environment initialization with direct assignments; added explicit R2-related write/read key env vars and DATASET_BINS_PATH; enforces explicit test-time env configuration.

Sequence Diagram(s)

sequenceDiagram
    participant Rank as Worker Rank (each)
    participant PG as ProcessGroup / Barrier
    participant Disk as Filesystem (sidecar)
    participant Loader as Model Loader

    Note over Rank,PG: download_distributed start
    Rank->>PG: barrier(start)
    PG-->>Rank: proceed
    Rank->>Disk: download/check for sidecar
    Disk-->>Rank: file or error
    alt sidecar read fails
        Rank->>Disk: retry read (asyncio.to_thread)
        Disk-->>Rank: parse error or success
    end
    Rank->>PG: barrier(after-download)
    PG-->>Rank: proceed
    Rank->>PG: barrier(before-load)
    PG-->>Rank: proceed
    Rank->>Loader: load checkpoint
    Loader-->>Rank: loaded
    Rank->>PG: barrier(after-load)
    PG-->>Rank: proceed

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Focus areas:
    • src/tplr/comms.py: 12-bit unpacking correctness, device transfers and no-grad context, NaN/Inf checks on CUDA vs CPU, proper tensor deletion to avoid memory leaks.
    • src/tplr/dcp_checkpoint.py: barrier placement and process_group usage to avoid deadlocks; retry/backoff and thread-based sidecar reads behavior.
    • tests/conftest.py: ensure CI tests tolerate explicit environment overrides and added variables.


Poem

🐰 I hopped through tensors, light and spry,
No-grad on my paws as gradients passed by.
Barriers kept ranks marching true,
I retried the sidecar and swept metadata through.
Bytes cleaned, hops done — back to my burrow I flew.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 60.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title directly addresses the main changes: race condition fixes and memory leak prevention in gradient gathering, matching the core objectives.
  • Description check: ✅ Passed. The description comprehensively covers both issues, their fixes, and results, with detailed problem statements and solutions clearly documented.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3b00b67 and 1914945.

📒 Files selected for processing (2)
  • src/tplr/comms.py (2 hunks)
  • src/tplr/dcp_checkpoint.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (2)
src/tplr/logging.py (4)
  • log_with_context (207-223)
  • log_with_context (290-309)
  • P (51-62)
  • T (41-48)
src/tplr/schemas.py (1)
  • CommsGetResult (49-65)
🔇 Additional comments (7)
src/tplr/comms.py (4)

1673-1676: LGTM: Clear type annotations improve readability.

Explicit typing for aggregation state variables makes the code more maintainable and helps catch type mismatches during development.


1707-1708: LGTM: torch.no_grad() prevents unnecessary gradient tracking.

Wrapping the aggregation loop in no_grad eliminates GPU memory growth from gradient computation graphs that were never needed for this inference-only path.


1893-1896: LGTM: Explicit reference cleanup prevents memory accumulation.

Deleting state_dict_resp, response, and batch_responses immediately after use ensures tensor memory is released promptly rather than waiting for GC cycles, which is critical for long training runs.

Also applies to: 1925-1927, 1933-1934


1757-1792: LGTM: Comprehensive quant_params validation with correct operator precedence.

The validation now correctly handles both float and tensor scale cases with explicit parentheses. The checks for shift finiteness, scale bounds (1e-12 to 1e4), and lookup table validity provide robust protection against malformed gradients.

src/tplr/dcp_checkpoint.py (3)

756-758: LGTM: Initial barrier ensures all ranks agree before downloading.

Adding a barrier after window discovery and local_dir creation prevents races where some ranks might start downloading before others have set up their target directories.


878-880: LGTM: Process group passed consistently to download phase.

Passing process_group to download_distributed ensures barriers are scoped to the same group used in load_local, preventing deadlocks when operating on a subgroup.


937-943: LGTM: Barriers before and after load ensure consistent state across ranks.

The pre-load barrier at line 938 ensures all ranks are ready before DCP load begins. The post-load barrier at line 943 guarantees all ranks complete loading before any rank returns, preventing races in downstream code that depends on the loaded model state.



coderabbitai bot left a comment

Actionable comments posted: 3

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between f7a67c5 and 565a078.

📒 Files selected for processing (2)
  • src/tplr/comms.py (2 hunks)
  • src/tplr/dcp_checkpoint.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (5)
src/tplr/logging.py (4)
  • log_with_context (207-223)
  • log_with_context (290-309)
  • P (51-62)
  • T (41-48)
tests/test_prepare_gradient_dict.py (1)
  • info (45-46)
src/tplr/schemas.py (1)
  • CommsGetResult (49-65)
tests/test_dcp_checkpoint.py (1)
  • warning (117-118)
tests/conftest.py (1)
  • totalks (141-144)
🔇 Additional comments (7)
src/tplr/comms.py (4)

1673-1676: LGTM – Explicit type annotations improve clarity.

The type annotations for aggregation state make the expected data structures clear and help catch type errors early.


1707-1708: LGTM – torch.no_grad() wrapper for memory efficiency.

Wrapping the aggregation loop in no_grad() is the right approach since gradient tracking is unnecessary during gather operations. This addresses the memory leak concern mentioned in the PR objectives.


1884-1911: LGTM – Clean aggregation with explicit resource cleanup.

The aggregation correctly accumulates compressed tensors without unnecessary copies, and the explicit del statements drop per-UID references immediately after processing to prevent memory accumulation.


1917-1918: LGTM – Explicit cleanup of batch responses.

Deleting batch_responses after processing helps with immediate memory reclamation rather than waiting for garbage collection, which is important for long-running training processes.

src/tplr/dcp_checkpoint.py (3)

756-758: LGTM – Initial barrier ensures window agreement across ranks.

The barrier before file enumeration ensures all ranks have agreed on the target window and created the local directory, preventing race conditions where some ranks might start downloading before others are ready.


878-880: LGTM – Consistent process group for barriers and load.

Passing process_group to download_distributed ensures that all barriers use the same process group as the subsequent load() call, which is essential for HSDP and other topologies with multiple process groups.


895-904: LGTM – Robust retry logic for sidecar reads.

Wrapping the file read in asyncio.to_thread prevents blocking the event loop, and the retry mechanism with 5 attempts handles transient filesystem visibility issues common on NFS/Lustre.
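A minimal sketch of such a retry-wrapped sidecar read, assuming a linear backoff across a handful of attempts (the helper name and exact retry policy are illustrative, not the PR's actual values):

```python
import asyncio
import json
import tempfile
from pathlib import Path

async def read_sidecar(path: Path, attempts: int = 5, delay: float = 0.2):
    """Read a JSON sidecar off the event loop, retrying transient read failures.

    Attempt count, backoff, and helper name are assumptions for illustration.
    """
    for i in range(attempts):
        try:
            # Blocking file I/O runs in a worker thread, keeping the loop free.
            text = await asyncio.to_thread(path.read_text)
            return json.loads(text)
        except (OSError, json.JSONDecodeError):
            if i == attempts - 1:
                return None  # deterministic failure the caller can act on
            await asyncio.sleep(delay * (i + 1))

tmp = Path(tempfile.mkdtemp()) / "extra_metadata.json"
tmp.write_text(json.dumps({"window": 42}))
meta = asyncio.run(read_sidecar(tmp))
assert meta == {"window": 42}
assert asyncio.run(read_sidecar(tmp.parent / "missing.json", attempts=2, delay=0.0)) is None
```

Returning `None` (rather than raising) on exhausted retries gives every rank a deterministic value it can feed into a collective, which is what makes the subsequent all-reduce consensus possible.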

coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
src/tplr/comms.py (1)

1834-1853: Fix device comparison to avoid always taking the .to() branch

In the vals validation path, the current check:

target_device = torch.device(device)

if tensor.device.type == target_device:
    tensor_to_check = tensor
    needs_delete = False
else:
    tensor_to_check = tensor.to(target_device, non_blocking=True)
    needs_delete = True

compares tensor.device.type (a string like "cuda") to a torch.device object, so it will always take the else branch and call .to(...), even when the tensor is already on the requested device. This defeats the goal of avoiding unnecessary copies and makes needs_delete unreliable.

A more robust pattern:

-                                target_device = torch.device(device)
-                                
-                                if tensor.device.type == target_device:
-                                    tensor_to_check = tensor
-                                    needs_delete = False
-                                else:
-                                    tensor_to_check = tensor.to(target_device, non_blocking=True)
-                                    needs_delete = True
+                                target_device = torch.device(device)
+
+                                if tensor.device == target_device or tensor.device.type == target_device.type:
+                                    tensor_to_check = tensor
+                                    needs_delete = False
+                                else:
+                                    tensor_to_check = tensor.to(target_device, non_blocking=True)
+                                    # only treat as temporary if a new tensor was created
+                                    needs_delete = tensor_to_check is not tensor

This keeps the NaN/Inf check device-aware without forcing redundant .to() calls when the tensor is already on the right device and avoids deleting aliases unnecessarily.
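To see the corrected branch end to end, here is a CPU-only sketch (variable names follow the diff above; behavior on CUDA is analogous):

```python
import torch

tensor = torch.randn(4)                 # lives on CPU
target_device = torch.device("cpu")

# Buggy form compared tensor.device.type (the string "cpu") against a
# torch.device object, so it never matched and always called .to().
# Corrected form compares device to device, or type to type:
if tensor.device == target_device or tensor.device.type == target_device.type:
    tensor_to_check = tensor
    needs_delete = False
else:
    tensor_to_check = tensor.to(target_device, non_blocking=True)
    # only treat as temporary if a new tensor was actually created
    needs_delete = tensor_to_check is not tensor

assert tensor_to_check is tensor        # no copy when already on the target device
assert needs_delete is False
```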

🧹 Nitpick comments (1)
src/tplr/dcp_checkpoint.py (1)

755-792: Align world/rank computation with the passed process_group

download_distributed now takes a process_group, but world, r are still derived from the default group via _world() / _rank() with no arguments. If you ever call this with a non-default PG (e.g., shard PG), assignment of assigned shards and logging will use the wrong world/r.

Consider threading process_group through:

-        world, r = _world(), _rank()
+        world, r = _world(process_group), _rank(process_group)

This keeps both the work partitioning and the barriers consistently scoped to the same group.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 565a078 and 5b6cc2e.

📒 Files selected for processing (2)
  • src/tplr/comms.py (2 hunks)
  • src/tplr/dcp_checkpoint.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (3)
src/tplr/logging.py (4)
  • log_with_context (207-223)
  • log_with_context (290-309)
  • P (51-62)
  • T (41-48)
src/tplr/schemas.py (1)
  • CommsGetResult (49-65)
tests/conftest.py (1)
  • totalks (141-144)
🔇 Additional comments (1)
src/tplr/comms.py (1)

1673-1925: Gather refactor (no_grad + per‑UID validation + cleanup) looks solid

The restructure around torch.no_grad(), explicit aggregated_state_dict / valid_uids / skipped_uids, and aggressive del of state_dict_resp, response, and batch_responses after aggregation is a nice improvement for both correctness (no autograd graphs) and memory footprint over long runs. Validation of quant_params, idxs, and vals before aggregation also reads coherent and defensive.

No issues spotted with control flow or early‑return behavior here.

coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
src/tplr/dcp_checkpoint.py (1)

914-919: Fix all_reduce device and group parameters to prevent NCCL errors and deadlocks.

The sidecar read synchronization has two issues previously flagged:

  1. ok_tensor on CPU will fail with NCCL backend (NCCL doesn't support CPU tensors).
  2. Missing group=process_group causes misalignment with other barriers when download_and_load is called on a subgroup.

Apply this diff:

         if dist.is_available() and dist.is_initialized():
             ok_tensor = torch.tensor([1 if read_ok else 0], dtype=torch.int32)
-            # always keep on CPU for safety
-            dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN)
+            # Use GPU when available (NCCL requires it), CPU is fine for Gloo
+            if torch.cuda.is_available():
+                ok_tensor = ok_tensor.to(torch.cuda.current_device())
+            dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN, group=process_group)
             read_ok = bool(ok_tensor.item())
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 5b6cc2e and 3b00b67.

📒 Files selected for processing (3)
  • src/tplr/comms.py (2 hunks)
  • src/tplr/dcp_checkpoint.py (3 hunks)
  • tests/conftest.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/comms.py (5)
src/tplr/logging.py (5)
  • log_with_context (207-223)
  • log_with_context (290-309)
  • debug (106-110)
  • P (51-62)
  • T (41-48)
tests/test_prepare_gradient_dict.py (1)
  • info (45-46)
src/tplr/schemas.py (1)
  • CommsGetResult (49-65)
tests/test_dcp_checkpoint.py (1)
  • warning (117-118)
tests/conftest.py (1)
  • totalks (141-144)
🔇 Additional comments (10)
tests/conftest.py (1)

12-33: Intentional override of environment variables for test reliability.

The switch from setdefault() to direct assignment correctly addresses the case where CI might set empty strings for these variables. This ensures mock values are consistently used during tests.

One consideration: developers with real R2 credentials in their shell environment will have them overridden during test runs. This is typically acceptable for test isolation, but worth noting in case debugging real S3 connectivity is needed.

src/tplr/dcp_checkpoint.py (3)

756-758: Initial barrier ensures consistent state before download.

This barrier correctly synchronizes all ranks before proceeding with the download, ensuring they've all created local_dir and agreed on the window. Good addition for shared filesystem scenarios.


876-881: Correct propagation of process_group to download_distributed.

Passing process_group ensures that all barriers within download_distributed operate on the same group as the caller, preventing cross-group deadlock scenarios.


936-942: Pre/post load barriers correctly synchronize DCP load phase.

The barriers ensure all ranks enter and exit the load phase together, which is essential for consistent distributed state.

One edge case: if load_local() raises an exception on a subset of ranks, the remaining ranks will hang at the post-load barrier (line 942). Consider wrapping in try/except with failure synchronization similar to the sidecar read pattern if load failures are expected.

src/tplr/comms.py (6)

1673-1676: Good: Explicit type annotations improve code clarity.

The type annotations for aggregated_state_dict, valid_uids, skipped_uids, and global_steps make the data flow clearer and help with IDE support.


1707-1708: Critical memory fix: no_grad context prevents gradient graph accumulation.

Wrapping the aggregation loop in torch.no_grad() is essential for preventing memory leaks. Without it, operations on tensors could accumulate computation graphs, causing unbounded memory growth over long runs.


1759-1786: Quant params validation correctly handles both float and tensor scales.

The validation now properly checks:

  • Finite shift value
  • Float scale: finite, not too small (< 1e-12), not too large (> 1e4)
  • Tensor scale: all finite, abs max not too small/large
  • Lookup table: all finite values

The explicit parentheses fix the operator precedence issue from the past review.
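The checks listed above can be sketched as a standalone predicate (the helper name and structure are assumptions; the real validation is inline in comms.py, with the 1e-12/1e4 bounds taken from this review):

```python
import math
import torch

def quant_params_ok(shift: float, scale, lookup: torch.Tensor) -> bool:
    """Illustrative sanity checks for quantization parameters."""
    if not math.isfinite(shift):
        return False
    if isinstance(scale, float):
        # float scale must be finite and within magnitude bounds
        if not math.isfinite(scale) or abs(scale) < 1e-12 or abs(scale) > 1e4:
            return False
    else:
        # tensor scale: all entries finite, abs max within bounds
        if not torch.isfinite(scale).all():
            return False
        m = scale.abs().max().item()
        if m < 1e-12 or m > 1e4:
            return False
    # lookup table must contain only finite values
    return bool(torch.isfinite(lookup).all())

lut = torch.linspace(-1.0, 1.0, 16)
assert quant_params_ok(0.5, 0.01, lut)
assert not quant_params_ok(float("nan"), 0.01, lut)             # non-finite shift
assert not quant_params_ok(0.0, 1e-15, lut)                     # scale too small
assert not quant_params_ok(0.0, torch.tensor([0.1, 2e5]), lut)  # tensor scale too large
```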


1889-1917: Aggregation preserves tensors efficiently with proper cleanup.

The code correctly:

  1. Appends tensors directly without unnecessary copies
  2. Tracks download bytes for metrics
  3. Explicitly deletes per-UID references after aggregation

The tensors in aggregated_state_dict remain alive (as needed for the return value), but UID-local references are promptly released.


1923-1924: Proper cleanup of batch_responses list after processing.

Deleting batch_responses after the loop releases the list structure. Combined with the per-UID deletions (lines 1916-1917), this ensures no stale references persist after the gather operation.


1878-1886: Proper cleanup on validation failure prevents memory leaks from invalid responses.

The code correctly releases references for skipped UIDs before continuing, ensuring invalid responses don't accumulate in memory over the gather loop.
