Skip to content

Commit d64e808

Browse files
authored
Support FSDP Checkpoint Saving & Loading (#633)
1 parent 2e27587 commit d64e808

File tree

3 files changed

+191
-15
lines changed

3 files changed

+191
-15
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
emoji
2+
immutabledict
23
nltk
4+
numpy==1.26.4
35
spacy==3.7.4
46
syllapy
5-
numpy==1.26.4
6-
immutabledict

slime/backends/fsdp_utils/actor.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from slime.utils.timer import Timer, timer
2323
from slime.utils.wandb_utils import init_wandb_secondary
2424

25+
from . import checkpoint
2526
from .data_packing import pack_sequences, unpack_sequences
2627
from .fsdp_cpu_adam_wrapper import FSDPCPUAdamWrapper
2728
from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor
@@ -67,6 +68,9 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
6768
self.args = args
6869
torch.manual_seed(args.seed)
6970

71+
if getattr(self.args, "start_rollout_id", None) is None:
72+
self.args.start_rollout_id = 0
73+
7074
if args.record_memory_history:
7175
profile_utils.attach_oom_dump_memory_history(profile_utils.get_memory_snapshot_full_path(args))
7276

@@ -91,6 +95,11 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
9195
if args.gradient_checkpointing:
9296
model.gradient_checkpointing_enable()
9397

98+
checkpoint_payload = checkpoint.load(self)
99+
if checkpoint_payload is not None and checkpoint_payload.get("model") is not None:
100+
model.load_state_dict(checkpoint_payload["model"], strict=True)
101+
checkpoint_payload["model"] = None
102+
94103
# Create FSDP v2 model using FSDP
95104
self.model = apply_fsdp2(model)
96105

@@ -120,8 +129,9 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
120129
f"Unsupported optimizer: {args.optimizer}. Supported options: 'adam', 'deepspeed_cpu_adam'"
121130
)
122131

123-
# TODO: load
124-
132+
self.global_step = 0
133+
self.micro_step = 0
134+
self._latest_checkpoint_iteration: int | None = None
125135
self.weights = {"actor": {}}
126136

127137
self.ref_model = None
@@ -136,16 +146,16 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
136146
else UpdateWeightFromDistributed(self.args, self.model, self.weights)
137147
)
138148

149+
checkpoint.finalize_load(self, checkpoint_payload)
150+
139151
# Initialize data packing parameters
140152
self.max_tokens_per_gpu = args.max_tokens_per_gpu # From main arguments
141153

142154
if self.args.offload_train:
143155
self.sleep()
144156

145157
Timer().start("train_wait")
146-
self.global_step = 0
147-
self.micro_step = 0
148-
return 0
158+
return int(getattr(self.args, "start_rollout_id", 0))
149159

150160
def sleep(self) -> None:
151161
"""Pause CUDA memory for all tracked tensors."""
@@ -204,16 +214,11 @@ def wake_up(self) -> None:
204214
print_memory("after wake_up model")
205215

206216
def save_model(self, iteration: int) -> None:
207-
"""Save model state and optimizer state for the given iteration.
208-
209-
Parameters:
210-
iteration: Global training step to associate with the checkpoint.
211-
212-
"""
213-
if self.args.debug_rollout_only:
217+
"""Delegate checkpoint saving to the shared checkpoint utilities."""
218+
if self.args.debug_rollout_only or self.args.save is None:
214219
return
215220

216-
raise NotImplementedError()
221+
checkpoint.save(self, iteration)
217222

218223
def compute_log_prob(
219224
self,
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import time
5+
from pathlib import Path
6+
from typing import Any
7+
8+
import torch
9+
import torch.distributed as dist
10+
11+
12+
def _read_checkpoint_metadata(path: Path) -> dict[str, Any]:
13+
if not path.exists():
14+
return {}
15+
try:
16+
return json.loads(path.read_text())
17+
except json.JSONDecodeError:
18+
print(f"Warning: failed to parse checkpoint metadata at {path}")
19+
return {}
20+
21+
22+
def _write_checkpoint_metadata(path: Path, metadata: dict[str, Any]) -> None:
23+
tmp_path = path.with_suffix(path.suffix + ".tmp")
24+
tmp_path.write_text(json.dumps(metadata, indent=2, sort_keys=True))
25+
tmp_path.replace(path)
26+
27+
28+
def load(actor: Any) -> dict[str, Any] | None:
29+
"""Prepare checkpoint payload for a training actor."""
30+
load_root = getattr(actor.args, "load", None)
31+
if load_root is None:
32+
return None
33+
34+
root_path = Path(load_root).expanduser()
35+
if not root_path.exists():
36+
print(f"[FSDP] Checkpoint directory {root_path} not found; skipping load.")
37+
return None
38+
39+
target_step = getattr(actor.args, "ckpt_step", None)
40+
if target_step is None:
41+
tracker_file = root_path / "latest_checkpointed_iteration.txt"
42+
if not tracker_file.exists():
43+
print(f"[FSDP] No tracker file at {tracker_file}; skipping load.")
44+
return None
45+
tracker_text = tracker_file.read_text().strip()
46+
target_step = int(tracker_text)
47+
48+
checkpoint_dir = root_path / f"iter_{target_step:07d}"
49+
model_ckpt = checkpoint_dir / "model.pt"
50+
if not model_ckpt.exists():
51+
print(f"[FSDP] Checkpoint {model_ckpt} not found; skipping load.")
52+
return None
53+
54+
model_payload = torch.load(model_ckpt, map_location="cpu")
55+
if isinstance(model_payload, dict) and any(isinstance(v, torch.Tensor) for v in model_payload.values()):
56+
model_state = model_payload
57+
else:
58+
model_state = model_payload.get("model", {})
59+
if not model_state:
60+
raise RuntimeError(f"Invalid model checkpoint payload at {model_ckpt}")
61+
62+
optimizer_state = None
63+
optimizer_path = checkpoint_dir / "optimizer.pt"
64+
if optimizer_path.exists():
65+
optimizer_state = torch.load(optimizer_path, map_location="cpu")
66+
67+
rng_state = None
68+
rng_path = checkpoint_dir / "rng.pt"
69+
if rng_path.exists():
70+
rng_state = torch.load(rng_path, map_location="cpu")
71+
72+
metadata = _read_checkpoint_metadata(checkpoint_dir / "meta.json")
73+
74+
return {
75+
"model": model_state,
76+
"optimizer": optimizer_state,
77+
"rng": rng_state,
78+
"metadata": metadata,
79+
"iteration": target_step,
80+
}
81+
82+
83+
def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None:
    """Restore optimizer, RNG, and bookkeeping state from a loaded payload.

    Every rank must call this collectively: the function always ends in a
    ``dist.barrier()`` whether or not a payload was loaded. Restored
    entries are nulled out in the payload to free the CPU copies.
    """
    if checkpoint_payload is None:
        # Nothing to restore, but keep all ranks in lockstep.
        dist.barrier()
        return

    opt_state = checkpoint_payload.get("optimizer")
    if opt_state is not None and not getattr(actor.args, "no_load_optim", False):
        # deepspeed_cpu_adam wraps the real optimizer; unwrap before loading.
        if actor.args.optimizer == "deepspeed_cpu_adam":
            actor.optimizer.cpu_optimizer.load_state_dict(opt_state)
        else:
            actor.optimizer.load_state_dict(opt_state)
        checkpoint_payload["optimizer"] = None

    rng_snapshot = checkpoint_payload.get("rng")
    if rng_snapshot is not None and not getattr(actor.args, "no_load_rng", False):
        if "torch" in rng_snapshot:
            torch.set_rng_state(rng_snapshot["torch"])
        if torch.cuda.is_available() and "cuda" in rng_snapshot:
            torch.cuda.set_rng_state_all(rng_snapshot["cuda"])
        checkpoint_payload["rng"] = None

    meta = checkpoint_payload.get("metadata") or {}
    iteration = checkpoint_payload.get("iteration")
    if meta:
        # Metadata, when present, is authoritative for the step counters.
        actor.global_step = int(meta.get("global_step", actor.global_step))
        actor.micro_step = int(meta.get("micro_step", actor.micro_step))
        actor._latest_checkpoint_iteration = int(meta.get("iteration", iteration))
        next_rollout = meta.get("next_rollout_id")
        if next_rollout is not None:
            actor.args.start_rollout_id = next_rollout
    elif iteration is not None:
        # Fall back to the directory-derived iteration number.
        actor._latest_checkpoint_iteration = iteration
        if getattr(actor.args, "start_rollout_id", None) is None:
            actor.args.start_rollout_id = iteration
    checkpoint_payload["metadata"] = None

    torch.cuda.synchronize()
    dist.barrier()
122+
123+
124+
def save(actor: Any, iteration: int) -> None:
    """Persist model, optimizer, RNG, and metadata for the given iteration.

    Must be called collectively by every rank: all ranks gather parameters
    and build an optimizer state dict, but only rank 0 writes files. The
    on-disk step id is ``iteration + 1``, matching the tracker-file scheme
    that ``load`` reads back.

    Parameters:
        iteration: Rollout id being checkpointed (0-based).
    """
    # Ensure all in-flight CUDA work touching the weights has finished.
    torch.cuda.synchronize()

    base_dir = Path(actor.args.save).expanduser()
    step_id = iteration + 1
    checkpoint_dir = base_dir / f"iter_{step_id:07d}"

    # Rank 0 creates the directory; the barrier keeps other ranks from
    # proceeding before it exists.
    if dist.get_rank() == 0:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
    dist.barrier()

    # Copy current parameters into the CPU-side mirror that is serialized.
    actor.update_cpu_params_dict(actor.weights["actor"])

    # deepspeed_cpu_adam wraps the real optimizer; unwrap to get its state.
    if actor.args.optimizer == "deepspeed_cpu_adam":
        optimizer_state = actor.optimizer.cpu_optimizer.state_dict()
    else:
        optimizer_state = actor.optimizer.state_dict()

    # NOTE(review): only rank 0's optimizer_state is written — confirm the
    # state dict is full (not FSDP-sharded) or resuming on a different
    # world size will be wrong.
    if dist.get_rank() == 0:
        model_payload = {
            "format_version": 1,
            "model": {name: tensor for name, tensor in actor.weights["actor"].items()},
        }
        torch.save(model_payload, checkpoint_dir / "model.pt")
        torch.save(optimizer_state, checkpoint_dir / "optimizer.pt")

        rng_state = {"torch": torch.get_rng_state()}
        rng_state["cuda"] = torch.cuda.get_rng_state_all()
        torch.save(rng_state, checkpoint_dir / "rng.pt")

        metadata = {
            "iteration": step_id,
            "rollout_id": iteration,
            "next_rollout_id": iteration + 1,
            "global_step": actor.global_step,
            "micro_step": actor.micro_step,
            "world_size": dist.get_world_size(),
            "timestamp": time.time(),
        }
        _write_checkpoint_metadata(checkpoint_dir / "meta.json", metadata)

        # The tracker file is what load() consults when no explicit
        # ckpt_step is supplied.
        tracker_file = base_dir / "latest_checkpointed_iteration.txt"
        tracker_file.write_text(str(step_id))
        print(f"[FSDP] Saved checkpoint to {checkpoint_dir}")
    # NOTE(review): reconstructed from an indentation-stripped diff — the
    # original may update this on rank 0 only; tracked on all ranks here.
    actor._latest_checkpoint_iteration = step_id

    # Keep non-zero ranks from leaving before the checkpoint is on disk.
    dist.barrier()

0 commit comments

Comments
 (0)