diff --git a/ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_base.yaml b/ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_base.yaml index 27c90819d..32d829e2e 100644 --- a/ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_base.yaml +++ b/ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_base.yaml @@ -24,6 +24,11 @@ Engine: ckpt_dir: +FusedPasses: + enable: False + fused_passes_list: [] + + Model: module: "GPTModuleAuto" name: "GPT" diff --git a/ppfleetx/utils/config.py b/ppfleetx/utils/config.py index c51529633..93b0e4086 100644 --- a/ppfleetx/utils/config.py +++ b/ppfleetx/utils/config.py @@ -579,6 +579,11 @@ def process_auto_strategy(config): tuning.run_after_tuning = tuning_cfg.get('run_after_tuning', True) tuning.debug = tuning_cfg.get('debug', True) + fused_passes_cfg = config.get('FusedPasses', {}) + fused_passes = strategy.fused_passes + fused_passes.enable = fused_passes_cfg.get('enable', False) + fused_passes.fused_passes_list = fused_passes_cfg.get('fused_passes_list', []) + engine_cfg = config['Engine'] engine_cfg['strategy'] = strategy