Skip to content

Commit 3bac959

Browse files
committed
delete unnecessary logs
1 parent 4edc173 commit 3bac959

File tree

1 file changed

+2
-11
lines changed

1 file changed

+2
-11
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,12 @@ def init(self, args, role, wandb_run_id, with_ref: bool = False): # type: ignor
4040
self.args = args
4141
torch.manual_seed(args.seed)
4242

43-
# Serialize tokenizer/config loading across ranks to avoid HF cache race
4443
for i in range(dist.get_world_size()):
4544
if i == dist.get_rank():
4645
self.hf_config = AutoConfig.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)
4746
self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)
4847
dist.barrier(group=get_gloo_group())
4948

50-
# Load model
5149
with torch.device(f"cuda:{torch.cuda.current_device()}"):
5250
model = AutoModelForCausalLM.from_pretrained(
5351
self.args.hf_checkpoint,
@@ -141,7 +139,6 @@ def compute_log_prob(
141139
padded_batches: Input batches
142140
store_prefix: Prefix for storing results (e.g., "ref_")
143141
"""
144-
# Save current model parameters if switching to different model
145142
current_params = None
146143
if model_tag != "actor" and model_tag in self.weights:
147144
current_params = {}
@@ -150,9 +147,8 @@ def compute_log_prob(
150147
for name, param in current_state_dict.items():
151148
current_params[name] = param.clone()
152149

153-
# Load the specified model parameters
154150
self.update_gpu_params_dict(self.weights[model_tag])
155-
self.model.eval() # Set to eval mode for ref model
151+
self.model.eval()
156152

157153
try:
158154
rollout_data = {f"{store_prefix}log_probs": []}
@@ -164,11 +160,10 @@ def compute_log_prob(
164160
return rollout_data
165161

166162
finally:
167-
# Restore original model parameters if we switched
168163
if current_params is not None:
169164
with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
170165
self.model.load_state_dict(current_params, strict=True)
171-
self.model.train() # Restore training mode
166+
self.model.train()
172167
torch.cuda.synchronize()
173168

174169
def pad_and_move_to_device(self, rollout_data):
@@ -324,11 +319,9 @@ def train(self, rollout_id, rollout_data_ref): # type: ignore[override]
324319
reported["kl_loss"] = kl_loss.detach()
325320
reported["kl_loss_coef"] = torch.tensor(self.args.kl_loss_coef, device=kl_loss.device)
326321

327-
# Scale loss for gradient accumulation
328322
loss = loss / grad_accum
329323
loss.backward()
330324

331-
# Accumulate reported metrics (store tensors for later mean)
332325
for k, v in reported.items():
333326
reported_accum.setdefault(k, []).append(v)
334327

@@ -337,12 +330,10 @@ def train(self, rollout_id, rollout_data_ref): # type: ignore[override]
337330
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad)
338331
self.optimizer.step()
339332
self.optimizer.zero_grad(set_to_none=True)
340-
# Aggregate logs
341333
aggregated = {k: torch.stack(v).mean().item() for k, v in reported_accum.items()}
342334
# TODO: change this, this is slow.
343335
reduced_aggregated = [None] * world_size
344336
dist.all_gather_object(reduced_aggregated, aggregated)
345-
# Mean across dp ranks
346337
aggregated = {}
347338
for k in reported_accum.keys():
348339
aggregated[k] = sum([r[k] for r in reduced_aggregated]) / world_size

0 commit comments

Comments
 (0)