diff --git a/examples/experiments/ernie_pretrain/ernie/src/callbacks/moe_correction_bias_adjust_callback.py b/examples/experiments/ernie_pretrain/ernie/src/callbacks/moe_correction_bias_adjust_callback.py
index 2bb8aa28b39..d52f2116464 100644
--- a/examples/experiments/ernie_pretrain/ernie/src/callbacks/moe_correction_bias_adjust_callback.py
+++ b/examples/experiments/ernie_pretrain/ernie/src/callbacks/moe_correction_bias_adjust_callback.py
@@ -32,6 +32,11 @@ def __init__(self, lr, use_sp):
         self.use_sp = use_sp

     def on_optimizer_end(self, args, state, control, **kwargs):
+        # Skip bias update when freeze_training is enabled
+        if getattr(args, "freeze_training", False):
+            logger.warning("freeze_training is enabled! MoE e_score_correction_bias will NOT be updated.")
+            return
+
         model = kwargs["model"]

         usages = {}
diff --git a/examples/experiments/ernie_pretrain/ernie/src/trainers/pretraining_trainer.py b/examples/experiments/ernie_pretrain/ernie/src/trainers/pretraining_trainer.py
index 84cc77ff5ef..c4de8418c0f 100644
--- a/examples/experiments/ernie_pretrain/ernie/src/trainers/pretraining_trainer.py
+++ b/examples/experiments/ernie_pretrain/ernie/src/trainers/pretraining_trainer.py
@@ -1300,6 +1300,18 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
         self._end_save_time = time.time()

     def create_scheduler(self, num_training_steps):
+        # When freeze_training is enabled, use constant scheduler with lr=0
+        if getattr(self.args, "freeze_training", False):
+            logger.warning(
+                "WARNING: freeze_training is enabled! "
+                "Learning rate is set to 0 and model parameters will NOT be updated. "
+                "This mode is intended for debugging/profiling only, NOT for actual training."
+            )
+            from paddleformers.trainer.trainer_utils import get_constant_schedule
+
+            self.lr_scheduler = get_constant_schedule(learning_rate=0.0)
+            return self.lr_scheduler
+
         if self.args.warmup_steps > 0:
             warmup = self.args.warmup_steps
         else:
diff --git a/paddleformers/cli/train/ernie_pretrain/src/callbacks/moe_correction_bias_adjust_callback.py b/paddleformers/cli/train/ernie_pretrain/src/callbacks/moe_correction_bias_adjust_callback.py
index d7131a72f78..21cacad25a8 100644
--- a/paddleformers/cli/train/ernie_pretrain/src/callbacks/moe_correction_bias_adjust_callback.py
+++ b/paddleformers/cli/train/ernie_pretrain/src/callbacks/moe_correction_bias_adjust_callback.py
@@ -34,6 +34,11 @@ def __init__(self, lr, use_sp):
         self.use_sp = use_sp

     def on_optimizer_end(self, args, state, control, **kwargs):
+        # Skip bias update when freeze_training is enabled
+        if getattr(args, "freeze_training", False):
+            logger.warning("freeze_training is enabled! MoE e_score_correction_bias will NOT be updated.")
+            return
+
         model = kwargs["model"]

         usages = {}
diff --git a/paddleformers/cli/train/ernie_pretrain/src/trainers/pretraining_trainer.py b/paddleformers/cli/train/ernie_pretrain/src/trainers/pretraining_trainer.py
index bff46e5dd8c..260a8697c8e 100644
--- a/paddleformers/cli/train/ernie_pretrain/src/trainers/pretraining_trainer.py
+++ b/paddleformers/cli/train/ernie_pretrain/src/trainers/pretraining_trainer.py
@@ -1314,6 +1314,18 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
         self._end_save_time = time.time()

     def create_scheduler(self, num_training_steps):
+        # When freeze_training is enabled, use constant scheduler with lr=0
+        if getattr(self.args, "freeze_training", False):
+            logger.warning(
+                "WARNING: freeze_training is enabled! "
+                "Learning rate is set to 0 and model parameters will NOT be updated. "
+                "This mode is intended for debugging/profiling only, NOT for actual training."
+            )
+            from paddleformers.trainer.trainer_utils import get_constant_schedule
+
+            self.lr_scheduler = get_constant_schedule(learning_rate=0.0)
+            return self.lr_scheduler
+
         if self.args.warmup_steps > 0:
             warmup = self.args.warmup_steps
         else:
diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py
index 338e048cd4a..c84c4acc6a2 100644
--- a/paddleformers/trainer/trainer.py
+++ b/paddleformers/trainer/trainer.py
@@ -1826,6 +1826,16 @@ def split_dtensor_by_axis(dtensor, axis=0):
         return global_micro_batchs

     def optimizer_step(self, args, model, parameters_list=None):
+        # When freeze_training is enabled, skip optimizer step and lr scheduler step
+        # to keep both model parameters and optimizer state unchanged
+        if args.freeze_training:
+            logger.warning(
+                "freeze_training is enabled! Model parameters and optimizer state will NOT be updated. "
+                "This is intended for debugging/profiling only."
+            )
+            self.optimizer.clear_grad()
+            return
+
         if parameters_list is None:
             parameters_list = []

@@ -3220,16 +3230,28 @@ def create_scheduler(self, num_training_steps: int):
             decay_steps = self.args.decay_steps

         if self.lr_scheduler is None:
-            self.lr_scheduler = get_scheduler(
-                self.args.lr_scheduler_type,
-                learning_rate=self.args.learning_rate,
-                num_warmup_steps=warmup,
-                num_training_steps=decay_steps,
-                num_cycles=self.args.num_cycles,
-                lr_end=self.args.lr_end,
-                power=self.args.power,
-                min_lr=self.args.min_lr,
-            )
+            # When freeze_training is enabled, use constant scheduler with lr=0
+            # to ensure learning rate stays 0 throughout training
+            if self.args.freeze_training:
+                logger.warning(
+                    "WARNING: freeze_training is enabled! "
+                    "Learning rate is set to 0 and model parameters will NOT be updated. "
+                    "This mode is intended for debugging/profiling only, NOT for actual training."
+                )
+                from .trainer_utils import get_constant_schedule
+
+                self.lr_scheduler = get_constant_schedule(learning_rate=0.0)
+            else:
+                self.lr_scheduler = get_scheduler(
+                    self.args.lr_scheduler_type,
+                    learning_rate=self.args.learning_rate,
+                    num_warmup_steps=warmup,
+                    num_training_steps=decay_steps,
+                    num_cycles=self.args.num_cycles,
+                    lr_end=self.args.lr_end,
+                    power=self.args.power,
+                    min_lr=self.args.min_lr,
+                )

         return self.lr_scheduler
diff --git a/paddleformers/trainer/trainer_callback.py b/paddleformers/trainer/trainer_callback.py
index 19a67f4baf8..a9ac345d175 100644
--- a/paddleformers/trainer/trainer_callback.py
+++ b/paddleformers/trainer/trainer_callback.py
@@ -779,6 +779,11 @@ def __init__(self, lr=0.001, use_mp=False):
         self.use_mp = use_mp

     def on_optimizer_end(self, args, state, control, **kwargs):
+        # Skip bias update when freeze_training is enabled
+        if getattr(args, "freeze_training", False):
+            logger.warning("freeze_training is enabled! MoE e_score_correction_bias will NOT be updated.")
+            return
+
         model = kwargs["model"]

         biases = []
diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py
index d3f00a61666..f1397161cfe 100644
--- a/paddleformers/trainer/training_args.py
+++ b/paddleformers/trainer/training_args.py
@@ -515,6 +515,20 @@ class TrainingArguments:
             and decreased for the experts with more assigned tokens."""
         },
     )
+    freeze_training: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "When set to True, the training process will be frozen: "
+                "1) Model parameters will not be updated (backward and optimizer step are skipped). "
+                "2) Optimizer state remains unchanged. "
+                "3) Learning rate scheduler is not updated. "
+                "4) MoE e_score_correction_bias is not updated. "
+                "This is useful for debugging, profiling, or running inference-only passes through the training loop. "
+                "Note: The learning rate will also be effectively set to 0 when this flag is enabled."
+            )
+        },
+    )
     log_on_each_node: bool = field(
         default=True,
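For context, a minimal usage sketch of the new flag follows. The `TrainingArguments` import path and the constructor arguments other than `freeze_training` are assumptions inferred from the files this diff touches, not part of the change itself.

```python
# Minimal sketch (assumed API surface): run the training loop with the new
# freeze_training flag so that parameters, optimizer state, the LR schedule, and
# the MoE e_score_correction_bias all stay fixed, e.g. for debugging or profiling.
from paddleformers.trainer import TrainingArguments  # assumed import path

args = TrainingArguments(
    output_dir="./freeze_debug_run",  # hypothetical output directory
    max_steps=20,                     # a few steps are enough for profiling
    freeze_training=True,             # flag introduced by this diff
)

# With freeze_training=True, Trainer.optimizer_step() clears gradients and returns
# early, and create_scheduler() installs a constant lr=0 schedule, so the model
# and optimizer state are left unchanged after the run.
```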