@@ -401,17 +401,21 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
401401
402402 # Unpack data
403403 # NOTE: game_segments is our custom GameSegment with mcts_policy_segment
404- # [FIX] Handle case where game_segments might not be included
404+ # [FIX] Handle both 3-element (from buffer) and 4-element (with explicit train_iter) formats
405405 if len (data ) == 4 :
406+ # Format: [current_batch, target_batch, train_iter, game_segments]
407+ # This is when learner explicitly adds train_iter
406408 current_batch , target_batch , train_iter , game_segments = data
407409 elif len (data ) == 3 :
408- current_batch , target_batch , train_iter = data
409- game_segments = None
410+ # Format: [current_batch, target_batch, game_segments]
411+ # This is the standard format from PriorZeroGameBuffer.sample()
412+ current_batch , target_batch , game_segments = data
413+ train_iter = self ._train_iteration # Get from instance variable
410414 import logging
411415 logger = logging .getLogger (__name__ )
412- logger .warning (
413- "[PRIORZERO] game_segments not included in training data. "
414- "SFT/RFT training will be skipped. "
416+ logger .debug (
417+ f "[PRIORZERO] Using 3-element format. game_segments: "
418+ f" { type ( game_segments ) } , count: { len ( game_segments ) if game_segments else 0 } "
415419 )
416420 else :
417421 raise ValueError (f"Unexpected data format: expected 3 or 4 elements, got { len (data )} " )
@@ -540,20 +544,17 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
540544 }
541545
542546 # [FIX] Following unizero.py lines 673-675 exactly:
543- # TODO======= 是否需要[:,-1:]
544- # Convert mask_batch to boolean, then truncate observations and mask_padding
545- batch_for_gpt ['mask_padding' ] = mask_batch == 1.0 # 0 means invalid padding data. Shape is now (B, T), e.g., (2, 4)
547+ # Convert mask_batch to boolean, then truncate to align with observations/rewards
548+ batch_for_gpt ['mask_padding' ] = mask_batch == 1.0 # 0 means invalid padding data. Shape: (B, T)
546549
547- # =================================================================================
548- # [!!! FIX !!!] REMOVE OR COMMENT OUT THE LINE BELOW.
549- # This line is the source of the bug. It incorrectly truncates the mask from shape
550- # (B, T) to (B, T-1), causing a mismatch with the rewards tensor.
551- # The mask_batch from the replay buffer already has the correct length (T)
552- # corresponding to the number of unroll steps.
553- # =================================================================================
554- # batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] # <--- REMOVE THIS
550+ # NOTE: mask_padding is deliberately NOT truncated here. The mask_batch from
551+ # the replay buffer already has length T matching the rewards tensor; truncating
552+ # it to (B, T-1) previously caused a shape mismatch with rewards. Only the
553+ # observations are truncated to [:, :-1] below.
554+ # The buggy line is kept commented out as a record — do not re-enable it:
555+ # batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1]
555556
556- batch_for_gpt ['observations' ] = batch_for_gpt ['observations' ][:, :- 1 ]
557+ batch_for_gpt ['observations' ] = batch_for_gpt ['observations' ][:, :- 1 ] # Shape: (B, T-1, obs_dim)
557558
558559 # [FIX] Add missing 'ends' field (following unizero.py line 676)
559560 # 'ends' marks terminal states in the trajectory (0 = not terminal)
@@ -1030,12 +1031,9 @@ def _monitor_vars_learn(self) -> List[str]:
10301031 Returns:
10311032 List of variable names that should be logged to TensorBoard/WandB
10321033 """
1033- # Start with all UniZero monitoring variables
1034- monitor_vars = super ()._monitor_vars_learn ()
10351034
1036- # Add PriorZero-specific LLM monitoring variables
1037- priorzero_vars = [
1038- # ============ LLM Loss Metrics ============
1035+ return [
1036+ # ============ LLM Loss Metrics ============
10391037 'llm_sft_loss' , # Supervised fine-tuning loss
10401038 'llm_rft_loss' , # Reinforcement fine-tuning loss
10411039 'llm_total_loss' , # Combined LLM loss
@@ -1057,20 +1055,104 @@ def _monitor_vars_learn(self) -> List[str]:
10571055 'wm_policy_loss' ,
10581056 'wm_reward_loss' ,
10591057 'wm_obs_loss' ,
1060- ]
1061-
1062- # Combine and deduplicate
1063- all_vars = monitor_vars + priorzero_vars
1064-
1065- # Remove duplicates while preserving order
1066- seen = set ()
1067- unique_vars = []
1068- for var in all_vars :
1069- if var not in seen :
1070- seen .add (var )
1071- unique_vars .append (var )
10721058
1073- return unique_vars
1059+ 'analysis/dormant_ratio_encoder' ,
1060+ 'analysis/dormant_ratio_transformer' ,
1061+ 'analysis/dormant_ratio_head' ,
1062+
1063+ 'analysis/avg_weight_mag_encoder' ,
1064+ 'analysis/avg_weight_mag_transformer' ,
1065+ 'analysis/avg_weight_mag_head' ,
1066+ 'analysis/e_rank_last_linear' ,
1067+ 'analysis/e_rank_sim_norm' ,
1068+
1069+ 'analysis/latent_state_l2_norms' ,
1070+ 'analysis/l2_norm_before' ,
1071+ 'analysis/l2_norm_after' ,
1072+ 'analysis/grad_norm_before' ,
1073+ 'analysis/grad_norm_after' ,
1074+
1075+ 'analysis/first_step_loss_value' ,
1076+ 'analysis/first_step_loss_policy' ,
1077+ 'analysis/first_step_loss_rewards' ,
1078+ 'analysis/first_step_loss_obs' ,
1079+
1080+ 'analysis/middle_step_loss_value' ,
1081+ 'analysis/middle_step_loss_policy' ,
1082+ 'analysis/middle_step_loss_rewards' ,
1083+ 'analysis/middle_step_loss_obs' ,
1084+
1085+ 'analysis/last_step_loss_value' ,
1086+ 'analysis/last_step_loss_policy' ,
1087+ 'analysis/last_step_loss_rewards' ,
1088+ 'analysis/last_step_loss_obs' ,
1089+
1090+ 'adaptive_alpha' ,
1091+ "adaptive_target_entropy_ratio" ,
1092+ 'alpha_loss' ,
1093+
1094+ 'Current_GPU' ,
1095+ 'Max_GPU' ,
1096+ 'collect_epsilon' ,
1097+ 'collect_mcts_temperature' ,
1098+ 'cur_lr_world_model' ,
1099+ 'cur_lr_tokenizer' ,
1100+
1101+ 'weighted_total_loss' ,
1102+ 'obs_loss' ,
1103+ 'policy_loss' ,
1104+ 'orig_policy_loss' ,
1105+ 'policy_entropy' ,
1106+ 'latent_recon_loss' ,
1107+ 'target_policy_entropy' ,
1108+ 'reward_loss' ,
1109+ 'value_loss' ,
1110+ 'consistency_loss' ,
1111+ 'value_priority' ,
1112+ 'target_reward' ,
1113+ 'target_value' ,
1114+ 'total_grad_norm_before_clip_wm' ,
1115+ # tokenizer
1116+ 'commitment_loss' ,
1117+ 'reconstruction_loss' ,
1118+ 'perceptual_loss' ,
1119+
1120+
1121+ "logits_value_mean" ,
1122+ "logits_value_max" ,
1123+ "logits_value_min" ,
1124+ "logits_policy_mean" ,
1125+ "logits_policy_max" ,
1126+ "logits_policy_min" ,
1127+
1128+ "temperature_value" ,
1129+ "temperature_reward" ,
1130+ "temperature_policy" ,
1131+ "current_policy_label_eps" ,
1132+ 'adaptive_alpha' ,
1133+ "adaptive_target_entropy_ratio" ,
1134+ 'alpha_loss' ,
1135+ "current_encoder_clip_value" ,
1136+
1137+ # ==================== [NEW] Norm and intermediate-tensor monitoring variables ====================
1138+ # Per-module total norms
1139+ 'norm/encoder/_total_norm' ,
1140+ 'norm/transformer/_total_norm' ,
1141+ 'norm/head_value/_total_norm' ,
1142+ 'norm/head_reward/_total_norm' ,
1143+ 'norm/head_policy/_total_norm' ,
1144+ # Statistics of the intermediate tensor x
1145+ 'norm/x_token/mean' ,
1146+ 'norm/x_token/std' ,
1147+ 'norm/x_token/max' ,
1148+ 'norm/x_token/min' ,
1149+ ]
1150+ # NOTE: per-layer norms are intentionally not listed here — logging every layer
1151+ # would clutter the dashboards. In practice, if the total norms reveal a problem,
1152+ # temporarily search for the specific layer's norm in TensorBoard, or print
1153+ # `norm_log_dict` locally for a detailed analysis. Tools such as wandb handle
1154+ # large numbers of dynamic metrics better.
1155+ # ========================================================================
1155+
10741156
10751157 def _forward_collect (
10761158 self ,
0 commit comments