Skip to content

Commit 4edc173

Browse files
committed
move ref calc to compute_log_prob
1 parent 263bc0d commit 4edc173

File tree

1 file changed

+38
-75
lines changed

1 file changed

+38
-75
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 38 additions & 75 deletions
Original file line numberDiff line numberDiff line change
def compute_log_prob(
    self,
    model_tag,
    padded_batches,
    store_prefix="",
):
    """
    Compute log probabilities using the specified model.

    Args:
        model_tag: "actor" for current model, "ref" for reference model.
        padded_batches: Input batches; each batch dict must contain "tokens".
        store_prefix: Prefix for storing results (e.g., "ref_").

    Returns:
        dict with key f"{store_prefix}log_probs". NOTE(review): this list is
        never populated -- per-batch results are written into each batch dict
        instead. Confirm callers rely on the per-batch storage.
    """
    # Save current model parameters if switching to a different model.
    current_params = None
    if model_tag != "actor" and model_tag in self.weights:
        current_params = {}
        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
            current_state_dict = self.model.state_dict()
            for name, param in current_state_dict.items():
                current_params[name] = param.clone()

        # Load the specified model parameters and switch to eval mode.
        # NOTE(review): placed inside the branch so a plain "actor" call does
        # not reload weights or leave the model stuck in eval mode -- confirm
        # against the original indentation.
        self.update_gpu_params_dict(self.weights[model_tag])
        self.model.eval()  # Set to eval mode for ref model

    try:
        rollout_data = {f"{store_prefix}log_probs": []}
        # BUG FIX: the previous `with timer(...) and torch.no_grad():` entered
        # only one of the two context managers -- `and` evaluates to a single
        # operand, so the other manager was created but never entered (the
        # timer never ran, or gradients were silently tracked). Enter both.
        with timer(f"{store_prefix}log_probs"), torch.no_grad():
            for batch in padded_batches:
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    logits = self.model(input_ids=batch["tokens"]).logits
                batch[f"{store_prefix}log_probs"] = gather_log_probs(logits, batch["tokens"])
        return rollout_data
    finally:
        # Restore original model parameters if we switched.
        if current_params is not None:
            with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
                self.model.load_state_dict(current_params, strict=True)
            self.model.train()  # Restore training mode
            torch.cuda.synchronize()
143173

144174
def pad_and_move_to_device(self, rollout_data):
145175
tokens = rollout_data["tokens"]
@@ -188,7 +218,7 @@ def train(self, rollout_id, rollout_data_ref): # type: ignore[override]
188218
), f"Invalid grad_accum {grad_accum} for micro_batch_size {self.args.micro_batch_size} and global_batch_size {self.args.global_batch_size}"
189219

190220
if "ref" in self.weights:
191-
self.compute_ref_log_probs(padded_batches)
221+
self.compute_log_prob("ref", padded_batches, store_prefix="ref_")
192222

193223
self.compute_log_prob("actor", padded_batches)
194224

@@ -347,8 +377,6 @@ def train(self, rollout_id, rollout_data_ref): # type: ignore[override]
347377

348378
self.update_cpu_params_dict(self.weights["actor"])
349379

350-
self._save_debug_train_data(rollout_id, rollout_data, padded_batches)
351-
352380
Timer().start("train_wait")
353381
return
354382

@@ -430,71 +458,6 @@ def load_ref_model(self, ref_load_path):
430458
self.model.load_state_dict(current_weights, strict=True)
431459
torch.cuda.synchronize()
432460

433-
def compute_ref_log_probs(self, padded_batches):
    """
    Compute log probabilities under the reference model's parameters.

    Temporarily swaps the reference weights (kept in CPU memory since
    initialization) onto the GPU model, runs the forward passes, then puts
    the original parameters back. No disk I/O is involved.
    """
    if "ref" not in self.weights:
        raise RuntimeError("Reference model weights not loaded")

    # Snapshot the live parameters so they can be restored afterwards.
    with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
        saved_params = {key: tensor.clone() for key, tensor in self.model.state_dict().items()}

    try:
        self.update_gpu_params_dict(self.weights["ref"])
        self.model.eval()
        for batch in padded_batches:
            with torch.no_grad():
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    logits = self.model(input_ids=batch["tokens"]).logits
                batch["ref_log_probs"] = gather_log_probs(logits, batch["tokens"])
    finally:
        # Always restore the snapshot and return to training mode, even if
        # a forward pass raised.
        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
            self.model.load_state_dict(saved_params, strict=True)
        self.model.train()
        torch.cuda.synchronize()
464-
465-
def _log_debug_rollout_data(self, rollout_id, rollout_data):
    # Debug hook mirroring the Megatron backend's rollout logging.
    message = f"Debug rollout {rollout_id}: logging rollout data"
    print(message)
469-
def _save_debug_train_data(self, rollout_id, rollout_data, padded_batches):
    """Save debug train data if requested.

    When `self.args.save_debug_train_data` is set, it is treated as a path
    template formatted with `rollout_id` and the distributed `rank`, and a
    dict of {rollout_id, rank, rollout_data, batch_info} is written there
    via `torch.save`. Tensors are detached and copied to CPU first.
    """
    from pathlib import Path

    if (path_template := getattr(self.args, 'save_debug_train_data', None)) is not None:
        rank = dist.get_rank()
        path = Path(path_template.format(rollout_id=rollout_id, rank=rank))
        print(f"Save debug train data to {path}")
        path.parent.mkdir(parents=True, exist_ok=True)

        debug_data = {
            'rollout_id': rollout_id,
            'rank': rank,
            'rollout_data': rollout_data,
            'batch_info': [],
        }

        # Detach before moving to CPU so we never copy a grad-tracking
        # tensor across devices; non-tensor values pass through unchanged.
        # (Original used enumerate() with an unused index and a manual
        # inner loop -- replaced with a dict comprehension.)
        for batch in padded_batches:
            batch_info = {
                key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
                for key, value in batch.items()
            }
            debug_data['batch_info'].append(batch_info)

        torch.save(debug_data, path)
497-
498461
def gather_log_probs(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
499462
# log_probs: [B, T-1, V]; input_ids: [B, T]
500463
pred_logits = logits[:, :-1]

0 commit comments

Comments
 (0)