Skip to content

Commit 5fd2584

Browse files
committed
polish(pu): optimize logging structure
1 parent f39e860 commit 5fd2584

File tree

7 files changed

+285
-58
lines changed

7 files changed

+285
-58
lines changed

zoo/jericho/priorzero/models/actor.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,23 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
263263

264264
self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
265265

266+
# Calculate response length statistics
267+
response_lengths = micro_batch['action_mask'].sum(dim=1).float()
268+
avg_response_length = response_lengths.mean().item()
269+
max_response_length = response_lengths.max().item()
270+
min_response_length = response_lengths.min().item()
271+
272+
# Calculate log_probs statistics
273+
valid_log_probs = action_log_probs[micro_batch['action_mask'] > 0]
274+
avg_log_prob = valid_log_probs.mean().item() if valid_log_probs.numel() > 0 else 0.0
275+
276+
# Calculate ratio statistics
277+
log_ratio = action_log_probs - micro_batch['old_action_logprob']
278+
ratio = log_ratio.exp()
279+
valid_ratio = ratio[micro_batch['action_mask'] > 0]
280+
avg_ratio = valid_ratio.mean().item() if valid_ratio.numel() > 0 else 1.0
281+
max_ratio = valid_ratio.max().item() if valid_ratio.numel() > 0 else 1.0
282+
266283
status = {
267284
"policy_loss": actor_loss.detach().float().mean().item(),
268285
"lr": self.actor_scheduler.get_last_lr()[0],
@@ -271,6 +288,14 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
271288
# "approx_kl": approx_kl.detach().float().mean().item(),
272289
"cur_old_kl": approx_kl.detach().float().mean().item(),
273290
"iter": self.train_iter,
291+
# Response length statistics
292+
"response_length_avg": avg_response_length,
293+
"response_length_max": max_response_length,
294+
"response_length_min": min_response_length,
295+
# Log prob and ratio statistics
296+
"log_prob_avg": avg_log_prob,
297+
"ratio_avg": avg_ratio,
298+
"ratio_max": max_ratio,
274299
}
275300
log_status = micro_batch["log_status"]
276301
other_status = {k: [item[k] for item in log_status] for k in log_status[0].keys()}
@@ -298,6 +323,10 @@ def train_batch(self, batch_data: Dict[str, torch.Tensor], kl_ctl: float, step_i
298323
return status_list
299324

300325
def _deepspeed_broadcast(self):
326+
# FIX: Add barrier before vLLM weight update to prevent NCCL deadlock with tp>1
327+
if torch.distributed.is_initialized():
328+
torch.distributed.barrier()
329+
301330
use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)
302331
if use_prefix_cache:
303332
self.vllm_engine.reset_prefix_cache()
@@ -310,7 +339,11 @@ def _deepspeed_broadcast(self):
310339
# For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
311340
with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
312341
shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
313-
self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params))
342+
self.vllm_engine.update_weight(name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params))
343+
344+
# FIX: Add barrier after vLLM weight update to ensure all ranks complete
345+
if torch.distributed.is_initialized():
346+
torch.distributed.barrier()
314347

315348
def _broadcast_to_vllm(self):
316349
use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False)

zoo/jericho/priorzero/priorzero_collector.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -481,14 +481,22 @@ def collect(
481481
# Episode Done
482482
# ==============================================================
483483
if episode_timestep.done:
484-
self._logger.info(f'======== Env {env_id} episode finished! ========')
485484
self._total_episode_count += 1
486485
# Logging
487486
info_log = {
488487
'reward': episode_timestep.info['eval_episode_return'],
489488
'time': self._env_info[env_id]['time'],
490489
'step': self._env_info[env_id]['step'],
491490
'llm_prior_entropy': sum(llm_prior_entropy[env_id])/len(llm_prior_entropy[env_id])}
491+
492+
# Structured episode completion log
493+
self._logger.info(
494+
f"[Episode Complete] Env={env_id} | "
495+
f"Reward={info_log['reward']:.2f} | "
496+
f"Steps={info_log['step']} | "
497+
f"Time={info_log['time']:.2f}s | "
498+
f"LLM_Entropy={info_log['llm_prior_entropy']:.3f}"
499+
)
492500
if not collect_with_pure_policy:
493501
info_log['visit_entropy'] = (
494502
visit_entropies_lst[env_id] / eps_steps_lst[env_id]
@@ -540,8 +548,7 @@ def collect(
540548
# ==================================================================
541549
if len(self.game_segment_pool) >= self._default_num_segments:
542550
self._logger.info(
543-
f'✓ Collected {len(self.game_segment_pool)} segments '
544-
f'(target: {self._default_num_segments})'
551+
f"[Collection Complete] Segments={len(self.game_segment_pool)}/{self._default_num_segments}"
545552
)
546553

547554
# Format return data
@@ -565,13 +572,17 @@ def collect(
565572
collected_duration = sum([d['time'] for d in self._episode_info])
566573

567574
if self._world_size > 1:
568-
# Before allreduce
569-
self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, collected_episode={collected_episode}")
575+
# Aggregate data across ranks
576+
local_step, local_episode = collected_step, collected_episode
570577
collected_step = allreduce_data(collected_step, 'sum')
571578
collected_episode = allreduce_data(collected_episode, 'sum')
572579
collected_duration = allreduce_data(collected_duration, 'sum')
573-
# After allreduce
574-
self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}")
580+
581+
self._logger.info(
582+
f"[Rank {self._rank} Aggregation] "
583+
f"Local: steps={local_step}, episodes={local_episode} | "
584+
f"Global: steps={collected_step}, episodes={collected_episode}"
585+
)
575586

576587

577588
self._total_envstep_count += collected_step
@@ -625,9 +636,26 @@ def _output_log(self, train_iter: int) -> None:
625636
info['completed_value_mean'] = np.mean(completed_value)
626637

627638
self._episode_info.clear()
628-
629-
# Log to console
630-
self._logger.info("Collector Training Summary:\n{}".format('\n'.join([f' {k}: {v}' for k, v in info.items()])))
639+
640+
# Structured summary log
641+
self._logger.info(
642+
f"\n{'='*80}\n"
643+
f"[Collector Summary] Train Iter: {train_iter}\n"
644+
f"{'-'*80}\n"
645+
f"Episodes: {info['episode_count']} (Total: {info['total_episode_count']})\n"
646+
f"Steps: {info['envstep_count']} (Total: {info['total_envstep_count']})\n"
647+
f"Avg Steps/Ep: {info['avg_envstep_per_episode']:.1f}\n"
648+
f"Throughput: {info['avg_envstep_per_sec']:.2f} steps/s, {info['avg_episode_per_sec']:.3f} eps/s\n"
649+
f"Duration: {info['collect_time']:.2f}s (Total: {info['total_duration']:.2f}s)\n"
650+
f"{'-'*80}\n"
651+
f"Reward: mean={info['reward_mean']:.2f}, std={info['reward_std']:.2f}, "
652+
f"min={info['reward_min']:.2f}, max={info['reward_max']:.2f}\n"
653+
f"LLM Entropy: mean={info['llm_prior_entropy_mean']:.3f}, "
654+
f"min={info['llm_prior_entropy_min']:.3f}, max={info['llm_prior_entropy_max']:.3f}\n"
655+
+ (f"Visit Entropy: {info.get('visit_entropy_mean', 0):.3f}\n" if not self.collect_with_pure_policy else "")
656+
+ (f"Completed Val: {info.get('completed_value_mean', 0):.3f}\n" if self.policy_config.gumbel_algo else "")
657+
+ f"{'='*80}"
658+
)
631659

632660
# Log to TensorBoard and WandB
633661
for k, v in info.items():

zoo/jericho/priorzero/priorzero_config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"qwen2.5-7b": {
3030
"model_name_or_path": "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-7B-Instruct",
3131
"vllm_tensor_parallel_size": 1,
32-
# "vllm_tensor_parallel_size": 2,
32+
# "vllm_tensor_parallel_size": 2, # TODO
3333
"gpu_memory_utilization": 0.35,
3434
"description": "Qwen2.5-7B-Instruct (high quality, needs 2+ GPUs)",
3535
},
@@ -115,7 +115,8 @@ class PriorZeroLLMConfig:
115115

116116
# 需要注意的是,buffer中取一条经验是 10个样本,因为包含10次交互; num_unroll_steps = 10
117117
train_batch_size: int = 640 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
118-
micro_train_batch_size: int = 16 # 一次micro_train_batch_size 用来计算梯度;只有一次 train_batch_size 才会更新参数
118+
# micro_train_batch_size: int = 16 # 一次micro_train_batch_size 用来计算梯度;只有一次 train_batch_size 才会更新参数
119+
micro_train_batch_size: int = 4 # 一次micro_train_batch_size 用来计算梯度;只有一次 train_batch_size 才会更新参数
119120
broadcast_every: int = 1 # 每次训练多少次 train_batch_size 才同步 vllm 参数;也就是说 vllm 中的模型 off 多少次参数更新
120121

121122
learning_rate: float = 1e-6

0 commit comments

Comments
 (0)