Skip to content

Commit 4ea6032

Browse files
committed
feat(checkpoint): implement background FULL checkpoint upload and enhance retention strategy

- **Background FULL Upload**: Introduced a non-blocking method to upload FULL checkpoints during anchor windows, improving efficiency for new miners.
- **Retention Policy Enhancement**: Updated retention logic to combine chain-based and dependency-based strategies for better checkpoint management.
- **Dynamic Dtype Handling**: Modified delta application to infer data types from current states, enhancing flexibility.
- **Logging Improvements**: Enhanced logging for recovery mode and upload processes to provide clearer insights into operations.
- **Code Cleanup**: Removed unused imports and streamlined retention window calculations for improved readability.
1 parent 809afe2 commit 4ea6032

File tree

4 files changed

+379
-83
lines changed

4 files changed

+379
-83
lines changed

grail/infrastructure/checkpoint_consumer.py

Lines changed: 30 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,7 @@
4747

4848
from ..shared.constants import (
4949
BASE_CHECKPOINT_RETENTION_LIMIT,
50-
CHECKPOINT_MILESTONE_INTERVAL,
51-
DELTA_BASE_INTERVAL,
5250
GRAIL_CHECKPOINT_MOD10,
53-
WINDOW_LENGTH,
5451
)
5552
from . import comms
5653
from .delta_checkpoint import apply_sparse_delta, compute_weights_hash
@@ -304,8 +301,6 @@ async def apply_delta_in_place(
304301
Returns:
305302
True if delta was applied successfully, False if fallback to full load needed
306303
"""
307-
import torch
308-
309304
# Validate inputs
310305
if target_window <= current_window:
311306
logger.debug(
@@ -358,7 +353,7 @@ async def apply_delta_in_place(
358353
# Get model's current state dict (on device)
359354
current_state = model.state_dict()
360355

361-
# Apply delta in float32, cast back to bf16
356+
# Apply delta - dtype is inferred from current_state
362357
logger.debug(
363358
"Applying delta: %.2f%% sparse, %d params changed",
364359
delta_info.get("sparsity_ratio", 0) * 100,
@@ -369,7 +364,7 @@ async def apply_delta_in_place(
369364
current_state,
370365
sparse_tensors,
371366
shapes,
372-
target_dtype=torch.bfloat16,
367+
target_dtype=None, # Infer from current_state
373368
)
374369

375370
# Verify hash if available
@@ -851,13 +846,24 @@ async def _handle_delta_checkpoint(
851846

852847
base_path, delta_chain = chain
853848

854-
logger.info(
855-
"Built delta chain: base=%s (cached=%s), chain_length=%d, target=%s",
856-
delta_chain[0].prev_window if delta_chain else "N/A",
857-
base_path is not None,
858-
len(delta_chain),
859-
metadata.window,
860-
)
849+
# Log recovery mode when catching up multiple missed windows
850+
if len(delta_chain) > 1:
851+
first_delta = delta_chain[0]
852+
last_delta = delta_chain[-1]
853+
logger.info(
854+
"🔄 Recovery mode: catching up %d missed windows (%s -> %s)",
855+
len(delta_chain),
856+
first_delta.prev_window,
857+
last_delta.window,
858+
)
859+
else:
860+
logger.info(
861+
"Built delta chain: base=%s (cached=%s), chain_length=%d, target=%s",
862+
delta_chain[0].prev_window if delta_chain else "N/A",
863+
base_path is not None,
864+
len(delta_chain),
865+
metadata.window,
866+
)
861867

862868
if base_path is None:
863869
logger.error("Cannot find base checkpoint for chain reconstruction")
@@ -882,8 +888,6 @@ async def _apply_single_delta(
882888
Returns:
883889
Path to reconstructed checkpoint, or None on failure
884890
"""
885-
import torch
886-
887891
try:
888892
# Download and load delta
889893
delta_data = await self._download_and_load_delta(delta_metadata)
@@ -906,12 +910,12 @@ async def _apply_single_delta(
906910
delta_info.get("sparsity_ratio", 0) * 100,
907911
)
908912

909-
# Apply delta (float32 computation, bf16 output)
913+
# Apply delta - dtype is inferred from prev_state
910914
reconstructed = apply_sparse_delta(
911915
prev_state,
912916
sparse_tensors,
913917
shapes,
914-
target_dtype=torch.bfloat16,
918+
target_dtype=None, # Infer from prev_state
915919
)
916920

917921
# Verify hash
@@ -1014,8 +1018,6 @@ async def _apply_delta_chain(
10141018
Returns:
10151019
Path to reconstructed checkpoint directory, or None on failure
10161020
"""
1017-
import torch
1018-
10191021
try:
10201022
# Load anchor weights
10211023
current_state = load_model_state_dict(anchor_path)
@@ -1040,12 +1042,12 @@ async def _apply_delta_chain(
10401042

10411043
sparse_tensors, shapes, delta_info = delta_data
10421044

1043-
# Apply sparse delta and cast to bf16 (bit-exact as analyzed)
1045+
# Apply sparse delta - dtype is inferred from current_state
10441046
current_state = apply_sparse_delta(
10451047
current_state,
10461048
sparse_tensors,
10471049
shapes,
1048-
target_dtype=torch.bfloat16,
1050+
target_dtype=None, # Infer from current_state
10491051
)
10501052

10511053
logger.debug(
@@ -1229,6 +1231,9 @@ async def _write_reconstructed_checkpoint(
12291231
def _compute_keep_windows(self, current_window: int) -> set[int]:
12301232
"""Calculate which checkpoint windows should be retained.
12311233
1234+
Delegates to the shared retention utility for consistent behavior
1235+
between publisher (remote cleanup) and consumer (local cache cleanup).
1236+
12321237
For chained deltas, we must keep the entire chain from the current
12331238
anchor (FULL) to now, plus the previous anchor for miners catching up.
12341239
@@ -1243,44 +1248,9 @@ def _compute_keep_windows(self, current_window: int) -> set[int]:
12431248
Returns:
12441249
Set of window numbers to retain
12451250
"""
1246-
keep: set[int] = set()
1247-
if current_window < 0:
1248-
return keep
1249-
1250-
# Always keep windows 0-9 (bootstrap)
1251-
keep.update(range(10))
1252-
1253-
# Calculate current anchor (last FULL boundary)
1254-
delta_base_interval_windows = max(1, int(DELTA_BASE_INTERVAL))
1255-
anchor_stride = delta_base_interval_windows * int(WINDOW_LENGTH)
1256-
current_anchor = (current_window // anchor_stride) * anchor_stride
1257-
1258-
# Keep all windows from current anchor to now (the active chain)
1259-
w = current_anchor
1260-
while w <= current_window:
1261-
keep.add(w)
1262-
w += WINDOW_LENGTH
1263-
1264-
# Keep previous anchor for miners still catching up
1265-
prev_anchor = current_anchor - anchor_stride
1266-
if prev_anchor >= 0:
1267-
keep.add(prev_anchor)
1268-
# Also keep the chain from previous anchor to current anchor
1269-
# This allows miners who were on the old chain to transition
1270-
w = prev_anchor
1271-
while w < current_anchor:
1272-
keep.add(w)
1273-
w += WINDOW_LENGTH
1274-
1275-
# Keep milestones (every CHECKPOINT_MILESTONE_INTERVAL windows)
1276-
interval_blocks = CHECKPOINT_MILESTONE_INTERVAL * WINDOW_LENGTH
1277-
if interval_blocks > 0:
1278-
milestone = (current_window // interval_blocks) * interval_blocks
1279-
while milestone >= 0:
1280-
keep.add(milestone)
1281-
milestone -= interval_blocks
1282-
1283-
return keep
1251+
from grail.shared.retention_utils import compute_retention_windows
1252+
1253+
return compute_retention_windows(current_window)
12841254

12851255
async def get_latest_ready_checkpoint(self, before_window: int) -> int | None:
12861256
"""Find the latest checkpoint that became READY before the given window.

grail/shared/retention_utils.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Shared checkpoint retention policy utilities.
2+
3+
This module provides a unified retention policy for determining which checkpoint
4+
windows should be kept in both remote storage (publisher) and local cache (consumer).
5+
6+
For chained deltas, retention must keep entire chains from anchor (FULL) to tip.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from grail.shared.constants import (
12+
CHECKPOINT_MILESTONE_INTERVAL,
13+
DELTA_BASE_INTERVAL,
14+
WINDOW_LENGTH,
15+
)
16+
17+
18+
def compute_retention_windows(
    current_window: int,
    bootstrap_windows: int = 10,
) -> set[int]:
    """Return the set of checkpoint windows that must be kept.

    Chained deltas require whole chains to survive pruning: a miner
    reconstructs state by loading an anchor (FULL checkpoint) and then
    replaying every sequential delta up to the tip, so no link between
    an anchor and the current window may be dropped.

    Retained windows:
    - the active chain (current anchor through ``current_window``)
    - the previous anchor plus its entire chain, so lagging miners can
      still transition off the old chain
    - milestone checkpoints (every ``CHECKPOINT_MILESTONE_INTERVAL``)
    - the first ``bootstrap_windows`` window boundaries (initial state)

    Args:
        current_window: Current window number
        bootstrap_windows: Number of initial windows to always keep (default 10)

    Returns:
        Set of window numbers to retain
    """
    if current_window < 0:
        return set()

    # Bootstrap boundaries are kept unconditionally.
    keep: set[int] = {i * WINDOW_LENGTH for i in range(bootstrap_windows)}

    # Blocks between consecutive FULL (anchor) checkpoints.
    stride = max(1, int(DELTA_BASE_INTERVAL)) * int(WINDOW_LENGTH)

    # Most recent FULL boundary at or before the current window.
    anchor = (current_window // stride) * stride

    # Active chain: every window from the current anchor up to now.
    keep.update(range(anchor, current_window + 1, WINDOW_LENGTH))

    # Previous anchor and its chain, so catching-up miners can transition.
    prev_anchor = anchor - stride
    if prev_anchor >= 0:
        keep.update(range(prev_anchor, anchor, WINDOW_LENGTH))

    # Milestones for long-term preservation, walking back to window 0.
    if CHECKPOINT_MILESTONE_INTERVAL > 0:
        interval = CHECKPOINT_MILESTONE_INTERVAL * WINDOW_LENGTH
        if interval > 0:
            latest = (current_window // interval) * interval
            keep.update(range(0, latest + 1, interval))

    return keep
83+
84+
85+
def get_anchor_window(target_window: int) -> int:
    """Return the nearest preceding FULL-checkpoint (anchor) window.

    Args:
        target_window: The window to find the anchor for

    Returns:
        The anchor window number
    """
    # Anchor stride in blocks; floor target_window to that boundary.
    stride = max(1, int(DELTA_BASE_INTERVAL)) * int(WINDOW_LENGTH)
    return target_window - (target_window % stride)
97+
98+
99+
def is_anchor_window(window: int) -> bool:
    """Tell whether ``window`` falls exactly on a FULL-checkpoint boundary.

    Args:
        window: The window number to check

    Returns:
        True if this window is an anchor window
    """
    # A window is an anchor when it is a multiple of the anchor stride.
    stride = max(1, int(DELTA_BASE_INTERVAL)) * int(WINDOW_LENGTH)
    return not window % stride

0 commit comments

Comments
 (0)