@@ -437,6 +437,45 @@ def train_unizero_multitask_balance_segment_ddp(
437437 tb_logger .add_scalar ('Curriculum/Stage' , curriculum_controller .stage , learner .train_iter )
438438 tb_logger .add_scalar ('Curriculum/GlobalSolvedTasks' , global_solved_count , learner .train_iter )
439439
440+ # TODO 遍历 transformer 中所有子模块,根据其名称查找 CurriculumLoRALinear 模块
441+ # transformer = policy._learn_model.world_model.transformer
442+ # for module_name, module in transformer.named_modules():
443+ # if isinstance(module, CurriculumLoRALinear) and module.adapters is not None:
444+ # for adapter_idx, scale_param in enumerate(module.adapter_scales):
445+ # tb_logger.add_scalar(
446+ # f'Curriculum/adapter_scales/{module_name}/adapter_{adapter_idx}',
447+ # scale_param().item(),
448+ # global_step=learner.train_iter
449+ # )
450+
451+ # 新增的 alpha 缩放因子日志记录
452+ try :
453+ transformer = policy ._learn_model .world_model .transformer
454+ for module_name , module in transformer .named_modules ():
455+ if isinstance (module , CurriculumLoRALinear ):
456+ # 检查模块是否有 base_weight_scale 属性
457+ if hasattr (module , 'base_weight_scale' ) and module .base_weight_scale is not None :
458+ # 1. 记录基座权重的缩放因子 (alpha_0)
459+ tb_logger .add_scalar (
460+ f'Curriculum/alpha_scales/{ module_name } /alpha_0_base_weight' ,
461+ module .base_weight_scale ().item (),
462+ global_step = learner .train_iter
463+ )
464+
465+ # 检查模块是否有 adapter_scales 属性
466+ if hasattr (module , 'adapter_scales' ) and module .adapter_scales is not None :
467+ # 2. 遍历并记录所有适配器的缩放因子 (alpha_1, alpha_2, ...)
468+ for adapter_idx , scale_param in enumerate (module .adapter_scales ):
469+ # adapter_idx 是从 0 开始的,对应 alpha_{idx+1}
470+ tb_logger .add_scalar (
471+ f'Curriculum/alpha_scales/{ module_name } /alpha_{ adapter_idx + 1 } ' ,
472+ scale_param ().item (),
473+ global_step = learner .train_iter
474+ )
475+ except Exception as e :
476+ logging .warning (f"Failed to log alpha scales: { e } " )
477+
478+
440479 # Ensure all processes are aware of a potential stage switch
441480 dist .barrier ()
442481
0 commit comments