@@ -71,6 +71,31 @@ def generate_task_loss_dict(multi_task_losses: List[Union[torch.Tensor, float]],
         task_loss_dict[task_name] = task_loss
     return task_loss_dict

+# # Modified function:
+# def generate_task_loss_dict(
+#     multi_task_losses: List[Union[torch.Tensor, float]],
+#     task_name_template: str,
+#     global_task_ids: List[int]
+# ) -> Dict[str, float]:
+#     """
+#     Overview:
+#         Generates a dictionary for the losses of each task using their explicit global IDs.
+#     Arguments:
+#         - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task.
+#         - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'.
+#         - global_task_ids (:obj:`List[int]`): A list of global task IDs corresponding to each loss in multi_task_losses.
+#     Returns:
+#         - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses.
+#     """
+#     task_loss_dict = {}
+#     # Use zip to pair each loss with its correct global ID.
+#     for task_loss, global_id in zip(multi_task_losses, global_task_ids):
+#         task_name = task_name_template.format(global_id)
+#         try:
+#             task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss
+#         except Exception as e:
+#             task_loss_dict[task_name] = task_loss
+#     return task_loss_dict


 class WrappedModel:
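The commented-out variant above keys each loss by an explicit global task ID instead of deriving names from a single local `task_id`. A minimal usage sketch of that ID-aware behavior (the loss values and IDs below are hypothetical, not taken from a real run):

```python
import torch

# Hypothetical per-task losses on this rank and their true global task IDs.
losses = [torch.tensor(0.42), 0.17]
global_ids = [3, 7]

task_loss_dict = {}
for task_loss, global_id in zip(losses, global_ids):
    name = 'obs_loss_task{}'.format(global_id)
    # .item() converts a tensor loss to a plain float; non-tensor losses pass through.
    task_loss_dict[name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss

print(task_loss_dict)  # ~ {'obs_loss_task3': 0.42, 'obs_loss_task7': 0.17}
```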
@@ -277,12 +302,15 @@ def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type,
         {
             'params': tokenizer_params,
             'lr': learning_rate,  # The tokenizer uses the base learning rate, e.g., 1e-4.
-            'weight_decay': weight_decay * 5.0  # <-- Apply 5x weight decay to the encoder! A strong regularizer.
+            # 'weight_decay': weight_decay * 5.0  # <-- Apply 5x weight decay to the encoder! A strong regularizer.
+            'weight_decay': weight_decay  # <-- The encoder now uses the base weight decay.
         },
         {
             'params': head_params,
             'lr': learning_rate,  # The heads also use the base learning rate, e.g., 1e-4.
-            'weight_decay': 0.0  # Typically the heads' weights are not decayed.
+            # 'weight_decay': 0.0  # Typically the heads' weights are not decayed.
+            'weight_decay': weight_decay
+
         }
     ]

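For context on the change above: in `torch.optim.AdamW`, each parameter group's `weight_decay` is applied independently at `step()` time, so this commit simply moves both the tokenizer and the heads onto the same base decay. A minimal sketch with stand-in modules (not the actual UniZero encoder or heads):

```python
import torch
import torch.nn as nn

# Stand-ins for tokenizer_params and head_params.
encoder, head = nn.Linear(8, 8), nn.Linear(8, 2)
learning_rate, weight_decay = 1e-4, 1e-2

optimizer = torch.optim.AdamW([
    {'params': encoder.parameters(), 'lr': learning_rate, 'weight_decay': weight_decay},
    {'params': head.parameters(), 'lr': learning_rate, 'weight_decay': weight_decay},
])
# Decoupled weight decay shrinks each group's weights by lr * weight_decay * w
# per step, independently of the gradient-based update for that group.
```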
@@ -845,6 +873,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
         orig_policy_loss_multi_task = []
         policy_entropy_multi_task = []
         weighted_total_loss = 0.0  # Initialize to 0.0 to avoid in-place operations.
+        total_alpha_loss = 0.0

         latent_state_l2_norms_multi_task = []
         average_target_policy_entropy_multi_task = []
@@ -869,12 +898,27 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
         # current_policy_label_eps = 0.0
         current_policy_label_eps = 0.01

+        # New list to collect the true global IDs of all tasks in the current batch.
+        global_task_ids_in_batch = []
+        alpha_loss = None
+

+        # New lists for alpha-related logging.
+        alpha_loss_multi_task = []
+        target_entropy_multi_task = []
+
+        # When adaptive alpha is enabled, fetch the current alpha value up front so it stays consistent across all tasks within a single iteration.
+        current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight
+        if self.use_adaptive_entropy_weight:
+            current_alpha = self.log_alpha.exp().detach()

         losses_list = []  # Used to store the loss tensor for each task, required by gradient correction methods.
         for task_id, data_one_task in enumerate(data):
-            current_batch, target_batch, task_id = data_one_task
+            current_batch, target_batch, task_id = data_one_task  # task_id is the true global ID.

+            # Append the true global ID to the list.
+            global_task_ids_in_batch.append(task_id)
+
             # TODO: Adapt RoPE for multitask settings (using timestep_batch).
             obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch
             target_reward, target_value, target_policy = target_batch
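Fetching alpha once before the task loop, detached, follows the usual SAC-style temperature pattern: every task's entropy bonus uses the same constant alpha within an iteration, and gradients reach `log_alpha` only through the dedicated alpha loss. A minimal sketch with illustrative names:

```python
import torch

log_alpha = torch.zeros(1, requires_grad=True)  # learnable temperature parameter
current_alpha = log_alpha.exp().detach()        # one consistent value for this iteration

# Hypothetical per-task policy entropies.
per_task_entropies = [torch.tensor(1.2), torch.tensor(0.8)]
entropy_bonus = sum(current_alpha * h for h in per_task_entropies)
# Because of detach(), entropy_bonus carries no gradient into log_alpha.
```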
@@ -948,7 +992,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
             # )

             losses = self._learn_model.world_model.compute_loss(
-                batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, current_policy_label_eps=current_policy_label_eps,task_id=task_id
+                batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, current_policy_label_eps=current_policy_label_eps, task_id=task_id
             )

             # ==================== START MODIFICATION 2 ====================
@@ -960,7 +1004,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite


             # TODO: Accumulate the weighted total loss. This assumes the loss from `compute_loss` is already weighted.
-            weighted_total_loss += losses.loss_total
+            weighted_total_loss += losses.loss_total  # NOTE: +=

             # TODO: Add assertions to check for NaN or Inf values in the loss if needed for debugging.
             # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
@@ -986,9 +1030,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite

             # Extract the policy entropy from the losses object.
             # ==================== START: target-entropy regularization update logic ====================
-            alpha_loss = None
             current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight  # Use the fixed value by default.
             if self.use_adaptive_entropy_weight:
+
                 # --- Dynamically compute the target entropy (this logic is correct and is kept) ---
                 progress = min(1.0, train_iter / self.target_entropy_decay_steps)
                 current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress
@@ -999,12 +1043,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
                 # --- Compute alpha_loss (sign corrected) ---
                 # This is the core fix: the leading minus sign was removed.
                 # detach() remains essential so that the alpha_loss gradient flows only into log_alpha.
-                alpha_loss = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean()
+                alpha_loss_task = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean()  # NOTE: =

                 # # --- Update log_alpha ---
-                self.alpha_optimizer.zero_grad()
-                alpha_loss.backward()
-                self.alpha_optimizer.step()
+                # self.alpha_optimizer.zero_grad()
+                # alpha_loss.backward()
+                # self.alpha_optimizer.step()
+
+                # Accumulate the alpha loss.
+                total_alpha_loss += alpha_loss_task
+                # Collect each task's alpha_loss and target entropy for logging.
+                alpha_loss_multi_task.append(alpha_loss_task)
+                target_entropy_multi_task.append(current_target_entropy)
+
                 # --- [Suggested optimization] Clamp log_alpha as a safety measure ---
                 with torch.no_grad():
                     # Constrain alpha to a range such as [1e-4, 10.0].
@@ -1030,7 +1081,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
                 self.obs_loss_weight * obs_loss  # Assumes ssl_loss_weight is the weight for obs_loss.
                 # ... add any other loss terms here as well ...
             )
-            weighted_total_loss = (weights * total_loss).mean()
+            weighted_total_loss += (weights * total_loss).mean()  # NOTE: +=
             # ===================== END: target-entropy regularization update logic =====================

             # ============ For value-based priority calculation ============
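The commit replaces the per-task alpha update with accumulate-then-average: each task contributes one SAC-style alpha loss term, and a single optimizer step is taken per training iteration. A minimal sketch of that pattern under illustrative names and values:

```python
import torch

log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=1e-4)

entropies = [torch.tensor(1.1), torch.tensor(0.6)]  # hypothetical per-task policy entropies
target_entropy = 0.9

total_alpha_loss = 0.0
for h in entropies:
    # detach() keeps the alpha gradient from flowing back into the policy network.
    total_alpha_loss += (log_alpha * (h.detach() - target_entropy)).mean()

final_alpha_loss = total_alpha_loss / len(entropies)  # one averaged update per iteration
alpha_optimizer.zero_grad()
final_alpha_loss.backward()
alpha_optimizer.step()
```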
@@ -1098,24 +1149,52 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
         # Core learn model update step.
         self._optimizer_world_model.zero_grad()

+        if self.use_adaptive_entropy_weight:
+            self.alpha_optimizer.zero_grad()
+        # 2. Compute the final alpha loss (averaged after accumulation).
+        final_alpha_loss = None
+        if self.use_adaptive_entropy_weight:
+            if len(data) > 0:
+                final_alpha_loss = total_alpha_loss / len(data)
+            else:  # Defensive programming: avoid division by zero.
+                final_alpha_loss = torch.tensor(0.0, device=self._cfg.device)
+
         # Assuming losses_list is a list of tensors with gradients, e.g., [loss1, loss2, ...].
         if self._cfg.use_moco:
             # Call MoCo's backward method, which handles gradient correction internally.
             if self._cfg.moco_version == "v0":
                 lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
             elif self._cfg.moco_version == "v1":
                 lambd, stats = self.grad_correct.backward(losses_list)
+
+            # Backpropagate the alpha loss separately.
+            if self.use_adaptive_entropy_weight:
+                final_alpha_loss.backward()

         elif self._cfg.only_use_moco_stats:
             # Only compute MoCo stats without applying gradient correction.
             lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+
             # Each rank performs its own backpropagation.
-            weighted_total_loss.backward()
+            # weighted_total_loss.backward()
+
+            # If adaptive alpha is enabled, add the alpha loss to the main loss and backpropagate them together.
+            if self.use_adaptive_entropy_weight:
+                (weighted_total_loss + final_alpha_loss).backward()
+            elif weighted_total_loss != 0.0:  # Make sure there is a loss to backpropagate.
+                weighted_total_loss.backward()
+
         else:
             # If not using gradient correction, each rank performs standard backpropagation.
             lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device)
-            weighted_total_loss.backward()

+            # weighted_total_loss.backward()
+
+            # If adaptive alpha is enabled, add the alpha loss to the main loss and backpropagate them together.
+            if self.use_adaptive_entropy_weight:
+                (weighted_total_loss + final_alpha_loss).backward()
+            elif weighted_total_loss != 0.0:  # Make sure there is a loss to backpropagate.
+                weighted_total_loss.backward()

         # -----------------------------------------------------------------
         # Still executed within the torch.no_grad() context.
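Summing the alpha loss into the main loss before a single `backward()` works because one backward pass populates `.grad` on every leaf tensor in the graph, while each optimizer only steps its own parameter group. A minimal sketch with stand-in modules:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)                         # stand-in for the world model
log_alpha = torch.zeros(1, requires_grad=True)
opt_model = torch.optim.AdamW(model.parameters(), lr=1e-4)
opt_alpha = torch.optim.Adam([log_alpha], lr=1e-4)

main_loss = model(torch.randn(2, 4)).pow(2).mean()
alpha_loss = (log_alpha * 0.5).mean()           # placeholder alpha loss

opt_model.zero_grad()
opt_alpha.zero_grad()
(main_loss + alpha_loss).backward()             # one traversal fills both sets of gradients
opt_model.step()
opt_alpha.step()
```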
@@ -1150,9 +1229,6 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
         scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor)


-
-
-
         # For debugging purposes.
         # for name, param in self._learn_model.world_model.tokenizer.encoder.named_parameters():
         #     print('name, param.mean(), param.std():', name, param.mean(), param.std())
@@ -1179,6 +1255,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite

         self._optimizer_world_model.step()

+        # 4. Step the alpha optimizer.
+        if self.use_adaptive_entropy_weight:
+            self.alpha_optimizer.step()
+            # Clamp log_alpha to keep training stable.
+            with torch.no_grad():
+                self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+
         if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler:
             self.lr_scheduler.step()

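Clamping in log space bounds alpha itself: with `log_alpha` held in [log(1e-4), log(10.0)], `alpha = exp(log_alpha)` stays in [1e-4, 10.0]. A quick sketch of the in-place clamp:

```python
import numpy as np
import torch

log_alpha = torch.tensor([5.0], requires_grad=True)  # deliberately out of range
with torch.no_grad():
    log_alpha.clamp_(np.log(1e-4), np.log(10.0))
print(log_alpha.exp())  # ~ tensor([10.]): alpha is capped at the upper bound
```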
@@ -1210,12 +1293,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
         if self.use_adaptive_entropy_weight:
             return_log_dict['adaptive_alpha'] = current_alpha.item()
             return_log_dict['adaptive_target_entropy_ratio'] = current_ratio
-            return_log_dict['alpha_loss'] = alpha_loss.item()
+            return_log_dict['final_alpha_loss'] = final_alpha_loss.item()
         # ==================== START: add new log entries ====================

         # Generate task-related loss dictionaries and prefix each task-related loss with "noreduce_".
         multi_task_loss_dicts = {
-            **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id),  # TODO: pass global_task_ids=global_task_ids_in_batch instead.
             **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id),
             **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id),
             **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id),
@@ -1230,6 +1313,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
             **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id),
             **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id),
             **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id),
+
+            # New alpha-related log entries.
+            **generate_task_loss_dict(alpha_loss_multi_task, 'noreduce_alpha_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(target_entropy_multi_task, 'noreduce_target_entropy_task{}', task_id=self.task_id),
         }
         return_log_dict.update(multi_task_loss_dicts)

@@ -1319,7 +1406,7 @@ def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]:
             # 'value_priority',
             'adaptive_alpha',
             "adaptive_target_entropy_ratio",
-            'alpha_loss',
+            'final_alpha_loss',
         ]


@@ -1346,7 +1433,10 @@ def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]:
             'noreduce_avg_weight_mag_transformer',
             'noreduce_avg_weight_mag_head',
             'noreduce_e_rank_last_linear',
-            'noreduce_e_rank_sim_norm'
+            'noreduce_e_rank_sim_norm',
+            "noreduce_alpha_loss",
+            "noreduce_target_entropy",
+
         ]

         # Use self.task_num_for_current_rank as the number of tasks for the current rank.
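The two new entries are base names for the per-task keys emitted above; a sketch of how they would expand into the `noreduce_..._task{id}` variants (the expansion scheme shown here is an assumption for illustration):

```python
# Hypothetical expansion of base monitored-variable names into per-task names.
base_vars = ["noreduce_alpha_loss", "noreduce_target_entropy"]
task_ids = [3, 7]  # global task IDs handled by this rank

monitored = [f"{var}_task{tid}" for var in base_vars for tid in task_ids]
# ['noreduce_alpha_loss_task3', 'noreduce_alpha_loss_task7',
#  'noreduce_target_entropy_task3', 'noreduce_target_entropy_task7']
```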