Skip to content

Commit bf3cd12

Browse files
committed
polish(pu): polish scale_factor in DPS
1 parent b18f892 commit bf3cd12

File tree

5 files changed

+290
-117
lines changed

5 files changed

+290
-117
lines changed

lzero/entry/train_unizero_multitask_balance_segment_ddp.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,45 @@ def train_unizero_multitask_balance_segment_ddp(
437437
tb_logger.add_scalar('Curriculum/Stage', curriculum_controller.stage, learner.train_iter)
438438
tb_logger.add_scalar('Curriculum/GlobalSolvedTasks', global_solved_count, learner.train_iter)
439439

440+
# TODO: iterate over all submodules of the transformer and locate CurriculumLoRALinear modules by name
441+
# transformer = policy._learn_model.world_model.transformer
442+
# for module_name, module in transformer.named_modules():
443+
# if isinstance(module, CurriculumLoRALinear) and module.adapters is not None:
444+
# for adapter_idx, scale_param in enumerate(module.adapter_scales):
445+
# tb_logger.add_scalar(
446+
# f'Curriculum/adapter_scales/{module_name}/adapter_{adapter_idx}',
447+
# scale_param().item(),
448+
# global_step=learner.train_iter
449+
# )
450+
451+
# Newly added logging of the alpha scale factors
452+
try:
453+
transformer = policy._learn_model.world_model.transformer
454+
for module_name, module in transformer.named_modules():
455+
if isinstance(module, CurriculumLoRALinear):
456+
# Check whether the module has a base_weight_scale attribute
457+
if hasattr(module, 'base_weight_scale') and module.base_weight_scale is not None:
458+
# 1. Log the scale factor of the base weights (alpha_0)
459+
tb_logger.add_scalar(
460+
f'Curriculum/alpha_scales/{module_name}/alpha_0_base_weight',
461+
module.base_weight_scale().item(),
462+
global_step=learner.train_iter
463+
)
464+
465+
# Check whether the module has an adapter_scales attribute
466+
if hasattr(module, 'adapter_scales') and module.adapter_scales is not None:
467+
# 2. Iterate over and log the scale factors of all adapters (alpha_1, alpha_2, ...)
468+
for adapter_idx, scale_param in enumerate(module.adapter_scales):
469+
# adapter_idx is 0-based, so it corresponds to alpha_{idx+1}
470+
tb_logger.add_scalar(
471+
f'Curriculum/alpha_scales/{module_name}/alpha_{adapter_idx + 1}',
472+
scale_param().item(),
473+
global_step=learner.train_iter
474+
)
475+
except Exception as e:
476+
logging.warning(f"Failed to log alpha scales: {e}")
477+
478+
440479
# Ensure all processes are aware of a potential stage switch
441480
dist.barrier()
442481

0 commit comments

Comments
 (0)