
Commit af99278

fix(pu): fix tb log when gpu_num<task_num, fix total_loss += bug, polish alpha_loss
1 parent 9f69f5a commit af99278

File tree

6 files changed: +2168 -47 lines changed


lzero/entry/train_unizero_multitask_segment_ddp.py

Lines changed: 18 additions & 3 deletions
@@ -437,12 +437,17 @@ def train_unizero_multitask_segment_ddp(
         tasks_per_rank = total_tasks // world_size
         remainder = total_tasks % world_size

+        # ==================== START: Critical fix ====================
+        # 1. Compute the exact number of tasks this rank is responsible for.
         if rank < remainder:
             start_idx = rank * (tasks_per_rank + 1)
             end_idx = start_idx + tasks_per_rank + 1
+            num_tasks_for_this_rank = tasks_per_rank + 1
         else:
             start_idx = rank * tasks_per_rank + remainder
             end_idx = start_idx + tasks_per_rank
+            num_tasks_for_this_rank = tasks_per_rank
+        # ==================== END: Critical fix ====================

         tasks_for_this_rank = input_cfg_list[start_idx:end_idx]
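For context, the split above can be read as a standalone helper. The following is a minimal sketch (not part of the commit; names mirror the diff) showing why num_tasks_for_this_rank must be computed per rank rather than taken as the uniform tasks_per_rank:

    def split_tasks_for_rank(total_tasks: int, world_size: int, rank: int):
        # The first `remainder` ranks each take one extra task.
        tasks_per_rank = total_tasks // world_size
        remainder = total_tasks % world_size
        if rank < remainder:
            start_idx = rank * (tasks_per_rank + 1)
            num_tasks_for_this_rank = tasks_per_rank + 1
        else:
            start_idx = rank * tasks_per_rank + remainder
            num_tasks_for_this_rank = tasks_per_rank
        return start_idx, start_idx + num_tasks_for_this_rank, num_tasks_for_this_rank

    # e.g. 8 tasks over 3 ranks: rank 0 -> [0, 3), rank 1 -> [3, 6), rank 2 -> [6, 8);
    # ranks 0 and 1 hold 3 tasks but rank 2 holds only 2, so one uniform task_num is wrong.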

@@ -465,8 +470,16 @@ def train_unizero_multitask_segment_ddp(
     # Use the config of the first task to create a shared policy.
     task_id, [cfg, create_cfg] = tasks_for_this_rank[0]

-    for config in tasks_for_this_rank:
-        config[1][0].policy.task_num = tasks_per_rank
+    # ==================== START: Critical fix ====================
+    # 2. Write the correct task count into *all* relevant configs.
+    #    The configs must be correct before the policy instance is created.
+    for config_tuple in tasks_for_this_rank:
+        # config_tuple is (task_id, [cfg_obj, create_cfg_obj])
+        config_tuple[1][0].policy.task_num = num_tasks_for_this_rank
+
+    # 3. Make sure the cfg object used to create the policy also carries the correct task_num.
+    cfg.policy.task_num = num_tasks_for_this_rank
+    # ==================== END: Critical fix ====================

     # Ensure the specified policy type is supported.
     assert create_cfg.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], \
@@ -602,7 +615,9 @@ def train_unizero_multitask_segment_ddp(
             collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

             # Check if it's time for evaluation.
-            if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0:
+            # if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0:
+            if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0:  # only for debug TODO
+
                 print('=' * 20)
                 print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')


lzero/policy/unizero.py

Lines changed: 4 additions & 4 deletions
@@ -68,14 +68,14 @@ def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type,
             'params': list(tokenizer_params.values()),
             'lr': learning_rate,  # The tokenizer uses the base learning rate, e.g. 1e-4.
             # 'lr': learning_rate * 0.1,  # Use a smaller learning rate for the encoder, e.g. 1e-5.
-            'weight_decay': weight_decay * 5.0  # <-- 5x weight decay for the encoder: a strong regularizer!
-
+            # 'weight_decay': weight_decay * 5.0  # <-- 5x weight decay for the encoder: a strong regularizer!
+            'weight_decay': weight_decay
         },
         {
             'params': list(head_params.values()),
             'lr': learning_rate,  # The heads also use the base learning rate, e.g. 1e-4.
-            'weight_decay': 0.0  # The head weights are usually not decayed.
-            # 'weight_decay': weight_decay
+            # 'weight_decay': 0.0  # The head weights are usually not decayed.
+            'weight_decay': weight_decay

         }
     ]
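For reference, a minimal sketch of how per-group weight decay like the above is consumed, assuming a torch.optim.AdamW-style optimizer (the actual optimizer construction in LightZero may differ):

    import torch

    def build_optimizer(tokenizer_params, head_params, learning_rate=1e-4, weight_decay=1e-2):
        # Each parameter group carries its own lr/weight_decay; after this commit
        # both groups use the same base weight_decay.
        param_groups = [
            {'params': tokenizer_params, 'lr': learning_rate, 'weight_decay': weight_decay},
            {'params': head_params, 'lr': learning_rate, 'weight_decay': weight_decay},
        ]
        return torch.optim.AdamW(param_groups)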

lzero/policy/unizero_multitask.py

Lines changed: 110 additions & 20 deletions
@@ -71,6 +71,31 @@ def generate_task_loss_dict(multi_task_losses: List[Union[torch.Tensor, float]],
         task_loss_dict[task_name] = task_loss
     return task_loss_dict

+# # Modified version of the function:
+# def generate_task_loss_dict(
+#         multi_task_losses: List[Union[torch.Tensor, float]],
+#         task_name_template: str,
+#         global_task_ids: List[int]
+# ) -> Dict[str, float]:
+#     """
+#     Overview:
+#         Generates a dictionary for the losses of each task using their explicit global IDs.
+#     Arguments:
+#         - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task.
+#         - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'.
+#         - global_task_ids (:obj:`List[int]`): A list of global task IDs corresponding to each loss in multi_task_losses.
+#     Returns:
+#         - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses.
+#     """
+#     task_loss_dict = {}
+#     # Use zip to pair each loss with its correct global ID.
+#     for task_loss, global_id in zip(multi_task_losses, global_task_ids):
+#         task_name = task_name_template.format(global_id)
+#         try:
+#             task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss
+#         except Exception as e:
+#             task_loss_dict[task_name] = task_loss
+#     return task_loss_dict


 class WrappedModel:
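The commented-out variant above is motivated by the TB-log fix in the commit title: when gpu_num < task_num, formatting task names from a local loop index plus an offset can mislabel losses, whereas pairing each loss with its explicit global ID cannot. A small hypothetical illustration (values and IDs are made up):

    # A rank holding global tasks 0 and 4 logs under the true global IDs.
    losses = [0.5, 0.7]
    global_task_ids = [0, 4]
    task_loss_dict = {
        'noreduce_obs_loss_task{}'.format(gid): loss
        for loss, gid in zip(losses, global_task_ids)
    }
    # -> {'noreduce_obs_loss_task0': 0.5, 'noreduce_obs_loss_task4': 0.7}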
@@ -277,12 +302,15 @@ def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type,
         {
             'params': tokenizer_params,
             'lr': learning_rate,  # The tokenizer uses the base learning rate, e.g. 1e-4.
-            'weight_decay': weight_decay * 5.0  # <-- 5x weight decay for the encoder: a strong regularizer!
+            # 'weight_decay': weight_decay * 5.0  # <-- 5x weight decay for the encoder: a strong regularizer!
+            'weight_decay': weight_decay
         },
         {
             'params': head_params,
             'lr': learning_rate,  # The heads also use the base learning rate, e.g. 1e-4.
-            'weight_decay': 0.0  # The head weights are usually not decayed.
+            # 'weight_decay': 0.0  # The head weights are usually not decayed.
+            'weight_decay': weight_decay
+
         }
     ]

@@ -845,6 +873,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
         orig_policy_loss_multi_task = []
         policy_entropy_multi_task = []
         weighted_total_loss = 0.0  # Initialize to 0.0 to avoid in-place operations.
+        total_alpha_loss = 0.0

         latent_state_l2_norms_multi_task = []
         average_target_policy_entropy_multi_task = []
@@ -869,12 +898,27 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
         # current_policy_label_eps = 0.0
         current_policy_label_eps = 0.01

+        # New list to collect the true global IDs of all tasks in the current batch.
+        global_task_ids_in_batch = []
+        alpha_loss = None
+

+        # New lists for alpha logging.
+        alpha_loss_multi_task = []
+        target_entropy_multi_task = []
+
+        # When adaptive alpha is enabled, fetch the current alpha up front so it is consistent across all tasks in one iteration.
+        current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight
+        if self.use_adaptive_entropy_weight:
+            current_alpha = self.log_alpha.exp().detach()

         losses_list = []  # Used to store the loss tensor for each task, required by gradient correction methods.
         for task_id, data_one_task in enumerate(data):
-            current_batch, target_batch, task_id = data_one_task
+            current_batch, target_batch, task_id = data_one_task  # task_id is the true global ID.

+            # Append the true global ID to the list.
+            global_task_ids_in_batch.append(task_id)
+
             # TODO: Adapt RoPE for multitask settings (using timestep_batch).
             obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch
             target_reward, target_value, target_policy = target_batch
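Note the rebinding above: task_id from enumerate(data) is immediately overwritten by the global ID unpacked from data_one_task, which is exactly how the batch can carry non-local IDs. A tiny illustration (made-up data):

    data = [('batchA', 'targetA', 0), ('batchB', 'targetB', 4)]  # hypothetical (current_batch, target_batch, global_id) tuples
    for task_id, data_one_task in enumerate(data):
        current_batch, target_batch, task_id = data_one_task     # task_id: local index -> true global ID
        print(task_id)                                           # prints 0, then 4 (not 0, 1)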
@@ -948,7 +992,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
             # )

             losses = self._learn_model.world_model.compute_loss(
-                batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, current_policy_label_eps=current_policy_label_eps,task_id=task_id
+                batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, current_policy_label_eps=current_policy_label_eps, task_id=task_id
             )

             # ==================== START MODIFICATION 2 ====================
@@ -960,7 +1004,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite


             # TODO: Accumulate the weighted total loss. This assumes the loss from `compute_loss` is already weighted.
-            weighted_total_loss += losses.loss_total
+            weighted_total_loss += losses.loss_total  # NOTE: +=

             # TODO: Add assertions to check for NaN or Inf values in the loss if needed for debugging.
             # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
@@ -986,9 +1030,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite

             # Extract the policy entropy from the losses object.
             # ==================== START: target-entropy regularization update logic ====================
-            alpha_loss = None
             current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight  # Use the fixed value by default.
             if self.use_adaptive_entropy_weight:
+
                 # --- Dynamically compute the target entropy (this logic is correct and is kept). ---
                 progress = min(1.0, train_iter / self.target_entropy_decay_steps)
                 current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress
@@ -999,12 +1043,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
                 # --- Compute alpha_loss (sign already corrected). ---
                 # This is the core correction: the leading minus sign has been removed.
                 # detach() remains essential: it ensures the gradient of alpha_loss flows only into log_alpha.
-                alpha_loss = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean()
+                alpha_loss_task = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean()  # NOTE: =

                 # # --- Update log_alpha ---
-                self.alpha_optimizer.zero_grad()
-                alpha_loss.backward()
-                self.alpha_optimizer.step()
+                # self.alpha_optimizer.zero_grad()
+                # alpha_loss.backward()
+                # self.alpha_optimizer.step()
+
+                # Accumulate alpha_loss.
+                total_alpha_loss += alpha_loss_task
+                # Collect each task's alpha_loss and target entropy for logging.
+                alpha_loss_multi_task.append(alpha_loss_task)
+                target_entropy_multi_task.append(current_target_entropy)
+
                 # --- [Suggested improvement] Clamp log_alpha as a safety measure. ---
                 with torch.no_grad():
                     # Clamp alpha to a range such as [1e-4, 10.0].
@@ -1030,7 +1081,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
                 self.obs_loss_weight * obs_loss  # Assumes ssl_loss_weight is the weight for obs_loss.
                 # ... add any other loss terms here as well ...
             )
-            weighted_total_loss = (weights * total_loss).mean()
+            weighted_total_loss += (weights * total_loss).mean()  # NOTE: +=
             # ===================== END: target-entropy regularization update logic =====================

             # ============ For value-based priority calculation ============
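The NOTE: += marks the core of the "fix total_loss += bug" from the commit title: inside the per-task loop, plain assignment silently discarded the contribution of every task except the last one. A minimal sketch of the difference (illustrative tensors, not the real losses):

    import torch

    task_losses = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]

    buggy = 0.0
    for loss in task_losses:
        buggy = loss          # '=' keeps only the last task: buggy == 3.0

    fixed = 0.0
    for loss in task_losses:
        fixed += loss         # '+=' accumulates all tasks: fixed == 6.0, so backward() sees every task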
@@ -1098,24 +1149,52 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
         # Core learn model update step.
         self._optimizer_world_model.zero_grad()

+        if self.use_adaptive_entropy_weight:
+            self.alpha_optimizer.zero_grad()
+        # 2. Compute the final alpha loss (average after accumulation).
+        final_alpha_loss = None
+        if self.use_adaptive_entropy_weight:
+            if len(data) > 0:
+                final_alpha_loss = total_alpha_loss / len(data)
+            else:  # Defensive programming: avoid division by zero.
+                final_alpha_loss = torch.tensor(0.0, device=self._cfg.device)
+
         # Assuming losses_list is a list of tensors with gradients, e.g., [loss1, loss2, ...].
         if self._cfg.use_moco:
             # Call MoCo's backward method, which handles gradient correction internally.
             if self._cfg.moco_version=="v0":
                 lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
             elif self._cfg.moco_version=="v1":
                 lambd, stats = self.grad_correct.backward(losses_list)
+
+            # Backpropagate the alpha loss separately.
+            if self.use_adaptive_entropy_weight:
+                final_alpha_loss.backward()

         elif self._cfg.only_use_moco_stats:
             # Only compute MoCo stats without applying gradient correction.
             lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+
             # Each rank performs its own backpropagation.
-            weighted_total_loss.backward()
+            # weighted_total_loss.backward()
+
+            # If adaptive alpha is enabled, add the alpha loss to the main loss and backpropagate them together.
+            if self.use_adaptive_entropy_weight:
+                (weighted_total_loss + final_alpha_loss).backward()
+            elif weighted_total_loss != 0.0:  # Make sure there is a loss to backpropagate.
+                weighted_total_loss.backward()
+
         else:
             # If not using gradient correction, each rank performs standard backpropagation.
             lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device)
-            weighted_total_loss.backward()

+            # weighted_total_loss.backward()
+
+            # If adaptive alpha is enabled, add the alpha loss to the main loss and backpropagate them together.
+            if self.use_adaptive_entropy_weight:
+                (weighted_total_loss + final_alpha_loss).backward()
+            elif weighted_total_loss != 0.0:  # Make sure there is a loss to backpropagate.
+                weighted_total_loss.backward()

         # -----------------------------------------------------------------
         # Still executed inside the torch.no_grad() context.
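Because alpha_loss_task is built from policy_entropy.detach(), the alpha branch of the graph touches only log_alpha, and the main loss does not touch log_alpha at all; a single combined backward therefore routes each gradient to the right place. A minimal standalone sketch of that separation (illustrative values only):

    import torch

    log_alpha = torch.zeros(1, requires_grad=True)
    theta = torch.tensor([2.0], requires_grad=True)        # stands in for the model parameters

    policy_entropy = theta * 2.0                           # depends on the model
    main_loss = (theta - 1.0).pow(2).sum()
    alpha_loss = (log_alpha * (policy_entropy.detach() - 0.5)).mean()

    (main_loss + alpha_loss).backward()
    # theta.grad comes only from main_loss; log_alpha.grad comes only from alpha_loss,
    # because detach() cut the path from alpha_loss back into theta.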
@@ -1150,9 +1229,6 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
             scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor)


-
-
-
         # For debugging purposes.
         # for name, param in self._learn_model.world_model.tokenizer.encoder.named_parameters():
         #     print('name, param.mean(), param.std():', name, param.mean(), param.std())
@@ -1179,6 +1255,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite

         self._optimizer_world_model.step()

+        # 4. Step the alpha optimizer.
+        if self.use_adaptive_entropy_weight:
+            self.alpha_optimizer.step()
+            # Clamp log_alpha for stability.
+            with torch.no_grad():
+                self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+
         if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler:
             self.lr_scheduler.step()
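Taken together, the alpha-related pieces of this commit implement: zero both optimizers, accumulate per-task alpha losses, average them, backpropagate them with the main loss, then step and clamp. A self-contained sketch of that order, assuming a SAC-style learnable log_alpha (names mirror the diff; the loss values are placeholders):

    import numpy as np
    import torch

    log_alpha = torch.zeros(1, requires_grad=True)              # stands in for self.log_alpha
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=1e-4)    # stands in for self.alpha_optimizer

    # Per-task alpha losses collected inside the task loop (placeholder values).
    per_task_alpha_losses = [log_alpha * 0.1, log_alpha * (-0.2)]
    total_alpha_loss = sum(per_task_alpha_losses)
    final_alpha_loss = total_alpha_loss / len(per_task_alpha_losses)

    alpha_optimizer.zero_grad()
    final_alpha_loss.backward()     # in the commit this backward is fused with the main loss
    alpha_optimizer.step()

    # Clamp log_alpha for numerical stability, as the diff does after the step.
    with torch.no_grad():
        log_alpha.clamp_(np.log(1e-4), np.log(10.0))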

@@ -1210,12 +1293,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
         if self.use_adaptive_entropy_weight:
             return_log_dict['adaptive_alpha'] = current_alpha.item()
             return_log_dict['adaptive_target_entropy_ratio'] = current_ratio
-            return_log_dict['alpha_loss'] = alpha_loss.item()
+            return_log_dict['final_alpha_loss'] = final_alpha_loss.item()
         # ==================== START: add new log entries ====================

         # Generate task-related loss dictionaries and prefix each task-related loss with "noreduce_".
         multi_task_loss_dicts = {
-            **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id),  # global_task_ids=global_task_ids_in_batch),  # task_id=self.task_id),
             **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id),
             **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id),
             **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id),
@@ -1230,6 +1313,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
             **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id),
             **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id),
             **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id),
+
+            # New alpha-related log entries.
+            **generate_task_loss_dict(alpha_loss_multi_task, 'noreduce_alpha_loss_task{}', self.task_id),
+            **generate_task_loss_dict(target_entropy_multi_task, 'noreduce_target_entropy_task{}', self.task_id),
         }
         return_log_dict.update(multi_task_loss_dicts)

@@ -1319,7 +1406,7 @@ def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]:
             # 'value_priority',
             'adaptive_alpha',
             "adaptive_target_entropy_ratio",
-            'alpha_loss',
+            'final_alpha_loss',
         ]


@@ -1346,7 +1433,10 @@ def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]:
             'noreduce_avg_weight_mag_transformer',
             'noreduce_avg_weight_mag_head',
             'noreduce_e_rank_last_linear',
-            'noreduce_e_rank_sim_norm'
+            'noreduce_e_rank_sim_norm',
+            "noreduce_alpha_loss",
+            "noreduce_target_entropy",
+
         ]

         # Use self.task_num_for_current_rank as the number of tasks for the current rank.
