Skip to content

Commit 9d54f08

Browse files
committed
fix(pu): fix mask_padding, obs_shape, polish cot prompts
1 parent 53803a5 commit 9d54f08

File tree

2 files changed

+517
-36
lines changed

2 files changed

+517
-36
lines changed

zoo/jericho/priorzero/priorzero_policy.py

Lines changed: 118 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -401,17 +401,21 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
401401

402402
# Unpack data
403403
# NOTE: game_segments is our custom GameSegment with mcts_policy_segment
404-
# [FIX] Handle case where game_segments might not be included
404+
# [FIX] Handle both 3-element (from buffer) and 4-element (with explicit train_iter) formats
405405
if len(data) == 4:
406+
# Format: [current_batch, target_batch, train_iter, game_segments]
407+
# This is when learner explicitly adds train_iter
406408
current_batch, target_batch, train_iter, game_segments = data
407409
elif len(data) == 3:
408-
current_batch, target_batch, train_iter = data
409-
game_segments = None
410+
# Format: [current_batch, target_batch, game_segments]
411+
# This is the standard format from PriorZeroGameBuffer.sample()
412+
current_batch, target_batch, game_segments = data
413+
train_iter = self._train_iteration # Get from instance variable
410414
import logging
411415
logger = logging.getLogger(__name__)
412-
logger.warning(
413-
"[PRIORZERO] game_segments not included in training data. "
414-
"SFT/RFT training will be skipped."
416+
logger.debug(
417+
f"[PRIORZERO] Using 3-element format. game_segments: "
418+
f"{type(game_segments)}, count: {len(game_segments) if game_segments else 0}"
415419
)
416420
else:
417421
raise ValueError(f"Unexpected data format: expected 3 or 4 elements, got {len(data)}")
@@ -540,20 +544,17 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
540544
}
541545

542546
# [FIX] Following unizero.py lines 673-675 exactly:
543-
# TODO======= 是否需要[:,-1:]
544-
# Convert mask_batch to boolean, then truncate observations and mask_padding
545-
batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data. Shape is now (B, T), e.g., (2, 4)
547+
# Convert mask_batch to boolean, then truncate to align with observations/rewards
548+
batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data. Shape: (B, T)
546549

547-
# =================================================================================
548-
# [!!! FIX !!!] REMOVE OR COMMENT OUT THE LINE BELOW.
549-
# This line is the source of the bug. It incorrectly truncates the mask from shape
550-
# (B, T) to (B, T-1), causing a mismatch with the rewards tensor.
551-
# The mask_batch from the replay buffer already has the correct length (T)
552-
# corresponding to the number of unroll steps.
553-
# =================================================================================
554-
# batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] # <--- REMOVE THIS
550+
# [CRITICAL] Truncate mask_padding to align with observations and rewards
551+
# This is REQUIRED because:
552+
# - observations are truncated to [:, :-1] → shape (B, T-1)
553+
# - rewards are already (B, T-1) from MCTS
554+
# - mask_padding MUST match to indicate valid data positions
555+
# batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] # Shape: (B, T-1)
# NOTE(review): the comment above says the truncation is REQUIRED, yet the assignment is left commented out, so mask_padding keeps shape (B, T) — confirm which behavior is intended.
555556

556-
batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1]
557+
batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # Shape: (B, T-1, obs_dim)
557558

558559
# [FIX] Add missing 'ends' field (following unizero.py line 676)
559560
# 'ends' marks terminal states in the trajectory (0 = not terminal)
@@ -1030,12 +1031,9 @@ def _monitor_vars_learn(self) -> List[str]:
10301031
Returns:
10311032
List of variable names that should be logged to TensorBoard/WandB
10321033
"""
1033-
# Start with all UniZero monitoring variables
1034-
monitor_vars = super()._monitor_vars_learn()
10351034

1036-
# Add PriorZero-specific LLM monitoring variables
1037-
priorzero_vars = [
1038-
# ============ LLM Loss Metrics ============
1035+
return [
1036+
# ============ LLM Loss Metrics ============
10391037
'llm_sft_loss', # Supervised fine-tuning loss
10401038
'llm_rft_loss', # Reinforcement fine-tuning loss
10411039
'llm_total_loss', # Combined LLM loss
@@ -1057,20 +1055,104 @@ def _monitor_vars_learn(self) -> List[str]:
10571055
'wm_policy_loss',
10581056
'wm_reward_loss',
10591057
'wm_obs_loss',
1060-
]
1061-
1062-
# Combine and deduplicate
1063-
all_vars = monitor_vars + priorzero_vars
1064-
1065-
# Remove duplicates while preserving order
1066-
seen = set()
1067-
unique_vars = []
1068-
for var in all_vars:
1069-
if var not in seen:
1070-
seen.add(var)
1071-
unique_vars.append(var)
10721058

1073-
return unique_vars
1059+
'analysis/dormant_ratio_encoder',
1060+
'analysis/dormant_ratio_transformer',
1061+
'analysis/dormant_ratio_head',
1062+
1063+
'analysis/avg_weight_mag_encoder',
1064+
'analysis/avg_weight_mag_transformer',
1065+
'analysis/avg_weight_mag_head',
1066+
'analysis/e_rank_last_linear',
1067+
'analysis/e_rank_sim_norm',
1068+
1069+
'analysis/latent_state_l2_norms',
1070+
'analysis/l2_norm_before',
1071+
'analysis/l2_norm_after',
1072+
'analysis/grad_norm_before',
1073+
'analysis/grad_norm_after',
1074+
1075+
'analysis/first_step_loss_value',
1076+
'analysis/first_step_loss_policy',
1077+
'analysis/first_step_loss_rewards',
1078+
'analysis/first_step_loss_obs',
1079+
1080+
'analysis/middle_step_loss_value',
1081+
'analysis/middle_step_loss_policy',
1082+
'analysis/middle_step_loss_rewards',
1083+
'analysis/middle_step_loss_obs',
1084+
1085+
'analysis/last_step_loss_value',
1086+
'analysis/last_step_loss_policy',
1087+
'analysis/last_step_loss_rewards',
1088+
'analysis/last_step_loss_obs',
1089+
1090+
'adaptive_alpha',
1091+
"adaptive_target_entropy_ratio",
1092+
'alpha_loss',
1093+
1094+
'Current_GPU',
1095+
'Max_GPU',
1096+
'collect_epsilon',
1097+
'collect_mcts_temperature',
1098+
'cur_lr_world_model',
1099+
'cur_lr_tokenizer',
1100+
1101+
'weighted_total_loss',
1102+
'obs_loss',
1103+
'policy_loss',
1104+
'orig_policy_loss',
1105+
'policy_entropy',
1106+
'latent_recon_loss',
1107+
'target_policy_entropy',
1108+
'reward_loss',
1109+
'value_loss',
1110+
'consistency_loss',
1111+
'value_priority',
1112+
'target_reward',
1113+
'target_value',
1114+
'total_grad_norm_before_clip_wm',
1115+
# tokenizer
1116+
'commitment_loss',
1117+
'reconstruction_loss',
1118+
'perceptual_loss',
1119+
1120+
1121+
"logits_value_mean",
1122+
"logits_value_max",
1123+
"logits_value_min",
1124+
"logits_policy_mean",
1125+
"logits_policy_max",
1126+
"logits_policy_min",
1127+
1128+
"temperature_value",
1129+
"temperature_reward",
1130+
"temperature_policy",
1131+
"current_policy_label_eps",
1132+
'adaptive_alpha',
1133+
"adaptive_target_entropy_ratio",
1134+
'alpha_loss',
1135+
"current_encoder_clip_value",
1136+
1137+
# ==================== [NEW] Norm and intermediate-tensor monitoring variables ====================
1138+
# Total norm of each module
1139+
'norm/encoder/_total_norm',
1140+
'norm/transformer/_total_norm',
1141+
'norm/head_value/_total_norm',
1142+
'norm/head_reward/_total_norm',
1143+
'norm/head_policy/_total_norm',
1144+
# Statistics of the intermediate tensor x
1145+
'norm/x_token/mean',
1146+
'norm/x_token/std',
1147+
'norm/x_token/max',
1148+
'norm/x_token/min',
1149+
]
1150+
# NOTE: we deliberately do not add every layer's norm here — too many metrics would clutter the logs.
1151+
# In practice, if the total norms reveal a problem, you can temporarily search TensorBoard for a specific layer's norm,
1152+
# or print `norm_log_dict` locally for a detailed analysis.
1153+
# Tools such as wandb are better suited to handling large numbers of dynamic metrics.
1154+
# ========================================================================
1155+
10741156

10751157
def _forward_collect(
10761158
self,

0 commit comments

Comments
 (0)