 
 from ..shared.constants import (
     BASE_CHECKPOINT_RETENTION_LIMIT,
-    CHECKPOINT_MILESTONE_INTERVAL,
-    DELTA_BASE_INTERVAL,
     GRAIL_CHECKPOINT_MOD10,
-    WINDOW_LENGTH,
 )
 from . import comms
 from .delta_checkpoint import apply_sparse_delta, compute_weights_hash
@@ -304,8 +301,6 @@ async def apply_delta_in_place(
     Returns:
         True if delta was applied successfully, False if fallback to full load needed
     """
-    import torch
-
     # Validate inputs
     if target_window <= current_window:
         logger.debug(
@@ -358,7 +353,7 @@ async def apply_delta_in_place(
     # Get model's current state dict (on device)
     current_state = model.state_dict()
 
-    # Apply delta in float32, cast back to bf16
+    # Apply delta - dtype is inferred from current_state
     logger.debug(
         "Applying delta: %.2f%% sparse, %d params changed",
         delta_info.get("sparsity_ratio", 0) * 100,
@@ -369,7 +364,7 @@ async def apply_delta_in_place(
         current_state,
         sparse_tensors,
         shapes,
-        target_dtype=torch.bfloat16,
+        target_dtype=None,  # Infer from current_state
     )
 
     # Verify hash if available
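
Passing `target_dtype=None` only works if `apply_sparse_delta` falls back to the dtype of the state dict it receives. A minimal sketch of that fallback, assuming a hypothetical `(indices, values)` layout for the sparse tensors; the real helper in `.delta_checkpoint` is not shown in this diff and may differ:

```python
import torch


def apply_sparse_delta_sketch(
    current_state: dict[str, torch.Tensor],
    sparse_tensors: dict[str, tuple[torch.Tensor, torch.Tensor]],  # hypothetical (indices, values)
    shapes: dict[str, torch.Size],
    target_dtype: torch.dtype | None = None,
) -> dict[str, torch.Tensor]:
    out: dict[str, torch.Tensor] = {}
    for name, base in current_state.items():
        # target_dtype=None keeps the base tensor's dtype (bf16 stays bf16, fp32 stays fp32).
        out_dtype = base.dtype if target_dtype is None else target_dtype
        if name not in sparse_tensors:
            out[name] = base.to(out_dtype)
            continue
        indices, values = sparse_tensors[name]
        # Accumulate in float32 for precision, then cast back to the output dtype.
        updated = base.to(torch.float32, copy=True)
        updated.view(-1)[indices] += values.to(torch.float32)
        out[name] = updated.view(shapes[name]).to(out_dtype)
    return out
```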
@@ -851,13 +846,24 @@ async def _handle_delta_checkpoint(
 
         base_path, delta_chain = chain
 
-        logger.info(
-            "Built delta chain: base=%s (cached=%s), chain_length=%d, target=%s",
-            delta_chain[0].prev_window if delta_chain else "N/A",
-            base_path is not None,
-            len(delta_chain),
-            metadata.window,
-        )
+        # Log recovery mode when catching up multiple missed windows
+        if len(delta_chain) > 1:
+            first_delta = delta_chain[0]
+            last_delta = delta_chain[-1]
+            logger.info(
+                "🔄 Recovery mode: catching up %d missed windows (%s -> %s)",
+                len(delta_chain),
+                first_delta.prev_window,
+                last_delta.window,
+            )
+        else:
+            logger.info(
+                "Built delta chain: base=%s (cached=%s), chain_length=%d, target=%s",
+                delta_chain[0].prev_window if delta_chain else "N/A",
+                base_path is not None,
+                len(delta_chain),
+                metadata.window,
+            )
 
         if base_path is None:
             logger.error("Cannot find base checkpoint for chain reconstruction")
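
For context on what the two branches report: each chain entry only needs `prev_window` and `window` attributes here, and a chain longer than one delta means the consumer fell behind. A toy illustration with a hypothetical metadata class and made-up window numbers:

```python
from dataclasses import dataclass


@dataclass
class DeltaMeta:  # hypothetical stand-in for the real delta metadata objects
    prev_window: int
    window: int


# Up to date: one delta from the previous window -> the normal "Built delta chain" line.
delta_chain = [DeltaMeta(prev_window=860, window=865)]

# Behind by several windows: the chain covers every missed window in order, and the
# recovery branch logs its bounds, e.g. "catching up 3 missed windows (850 -> 865)".
delta_chain = [
    DeltaMeta(prev_window=850, window=855),
    DeltaMeta(prev_window=855, window=860),
    DeltaMeta(prev_window=860, window=865),
]
```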
@@ -882,8 +888,6 @@ async def _apply_single_delta(
         Returns:
             Path to reconstructed checkpoint, or None on failure
         """
-        import torch
-
         try:
             # Download and load delta
             delta_data = await self._download_and_load_delta(delta_metadata)
@@ -906,12 +910,12 @@ async def _apply_single_delta(
                 delta_info.get("sparsity_ratio", 0) * 100,
             )
 
-            # Apply delta (float32 computation, bf16 output)
+            # Apply delta - dtype is inferred from prev_state
             reconstructed = apply_sparse_delta(
                 prev_state,
                 sparse_tensors,
                 shapes,
-                target_dtype=torch.bfloat16,
+                target_dtype=None,  # Infer from prev_state
             )
 
             # Verify hash
@@ -1014,8 +1018,6 @@ async def _apply_delta_chain(
         Returns:
             Path to reconstructed checkpoint directory, or None on failure
         """
-        import torch
-
         try:
             # Load anchor weights
             current_state = load_model_state_dict(anchor_path)
@@ -1040,12 +1042,12 @@ async def _apply_delta_chain(
 
                 sparse_tensors, shapes, delta_info = delta_data
 
-                # Apply sparse delta and cast to bf16 (bit-exact as analyzed)
+                # Apply sparse delta - dtype is inferred from current_state
                 current_state = apply_sparse_delta(
                     current_state,
                     sparse_tensors,
                     shapes,
-                    target_dtype=torch.bfloat16,
+                    target_dtype=None,  # Infer from current_state
                 )
 
                 logger.debug(
@@ -1229,6 +1231,9 @@ async def _write_reconstructed_checkpoint(
     def _compute_keep_windows(self, current_window: int) -> set[int]:
         """Calculate which checkpoint windows should be retained.
 
+        Delegates to the shared retention utility for consistent behavior
+        between publisher (remote cleanup) and consumer (local cache cleanup).
+
         For chained deltas, we must keep the entire chain from the current
         anchor (FULL) to now, plus the previous anchor for miners catching up.
 
@@ -1243,44 +1248,9 @@ def _compute_keep_windows(self, current_window: int) -> set[int]:
         Returns:
             Set of window numbers to retain
         """
-        keep: set[int] = set()
-        if current_window < 0:
-            return keep
-
-        # Always keep windows 0-9 (bootstrap)
-        keep.update(range(10))
-
-        # Calculate current anchor (last FULL boundary)
-        delta_base_interval_windows = max(1, int(DELTA_BASE_INTERVAL))
-        anchor_stride = delta_base_interval_windows * int(WINDOW_LENGTH)
-        current_anchor = (current_window // anchor_stride) * anchor_stride
-
-        # Keep all windows from current anchor to now (the active chain)
-        w = current_anchor
-        while w <= current_window:
-            keep.add(w)
-            w += WINDOW_LENGTH
-
-        # Keep previous anchor for miners still catching up
-        prev_anchor = current_anchor - anchor_stride
-        if prev_anchor >= 0:
-            keep.add(prev_anchor)
-            # Also keep the chain from previous anchor to current anchor
-            # This allows miners who were on the old chain to transition
-            w = prev_anchor
-            while w < current_anchor:
-                keep.add(w)
-                w += WINDOW_LENGTH
-
-        # Keep milestones (every CHECKPOINT_MILESTONE_INTERVAL windows)
-        interval_blocks = CHECKPOINT_MILESTONE_INTERVAL * WINDOW_LENGTH
-        if interval_blocks > 0:
-            milestone = (current_window // interval_blocks) * interval_blocks
-            while milestone >= 0:
-                keep.add(milestone)
-                milestone -= interval_blocks
-
-        return keep
+        from grail.shared.retention_utils import compute_retention_windows
+
+        return compute_retention_windows(current_window)
 
     async def get_latest_ready_checkpoint(self, before_window: int) -> int | None:
         """Find the latest checkpoint that became READY before the given window.
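
The retention policy now lives in `grail.shared.retention_utils.compute_retention_windows`, which is not shown in this diff. A sketch of what that shared utility presumably contains, reconstructed from the deleted in-place logic above (an assumption, not the actual module):

```python
from grail.shared.constants import (
    CHECKPOINT_MILESTONE_INTERVAL,
    DELTA_BASE_INTERVAL,
    WINDOW_LENGTH,
)


def compute_retention_windows(current_window: int) -> set[int]:
    """Windows to keep: bootstrap, active chain, previous anchor chain, milestones."""
    keep: set[int] = set()
    if current_window < 0:
        return keep

    window_len = int(WINDOW_LENGTH)

    # Bootstrap windows 0-9 are always retained.
    keep.update(range(10))

    # Current anchor = last FULL checkpoint boundary.
    anchor_stride = max(1, int(DELTA_BASE_INTERVAL)) * window_len
    current_anchor = (current_window // anchor_stride) * anchor_stride

    # Active chain: every window from the current anchor up to now.
    keep.update(range(current_anchor, current_window + 1, window_len))

    # Previous anchor and its chain, so miners on the old chain can still catch up.
    prev_anchor = current_anchor - anchor_stride
    if prev_anchor >= 0:
        keep.update(range(prev_anchor, current_anchor, window_len))

    # Periodic milestones every CHECKPOINT_MILESTONE_INTERVAL windows.
    interval_blocks = CHECKPOINT_MILESTONE_INTERVAL * window_len
    if interval_blocks > 0:
        keep.update(range(0, current_window + 1, interval_blocks))

    return keep
```

Keeping this in one shared function is what lets the publisher's remote cleanup and the consumer's local cache cleanup agree on which windows survive.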