@@ -32,6 +32,11 @@ def __init__(self, lr, use_sp):
         self.use_sp = use_sp
 
     def on_optimizer_end(self, args, state, control, **kwargs):
+        # Skip bias update when freeze_training is enabled
+        if getattr(args, "freeze_training", False):
+            logger.warning("freeze_training is enabled! MoE e_score_correction_bias will NOT be updated.")
+            return
+
         model = kwargs["model"]
 
         usages = {}
@@ -1300,6 +1300,18 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
         self._end_save_time = time.time()
 
     def create_scheduler(self, num_training_steps):
+        # When freeze_training is enabled, use constant scheduler with lr=0
+        if getattr(self.args, "freeze_training", False):
+            logger.warning(
+                "WARNING: freeze_training is enabled! "
+                "Learning rate is set to 0 and model parameters will NOT be updated. "
+                "This mode is intended for debugging/profiling only, NOT for actual training."
+            )
+            from paddleformers.trainer.trainer_utils import get_constant_schedule
+
+            self.lr_scheduler = get_constant_schedule(learning_rate=0.0)
+            return self.lr_scheduler
+
         if self.args.warmup_steps > 0:
             warmup = self.args.warmup_steps
         else:
@@ -34,6 +34,11 @@ def __init__(self, lr, use_sp):
         self.use_sp = use_sp
 
     def on_optimizer_end(self, args, state, control, **kwargs):
+        # Skip bias update when freeze_training is enabled
+        if getattr(args, "freeze_training", False):
+            logger.warning("freeze_training is enabled! MoE e_score_correction_bias will NOT be updated.")
+            return
+
         model = kwargs["model"]
 
         usages = {}
@@ -1314,6 +1314,18 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
         self._end_save_time = time.time()
 
     def create_scheduler(self, num_training_steps):
+        # When freeze_training is enabled, use constant scheduler with lr=0
+        if getattr(self.args, "freeze_training", False):
+            logger.warning(
+                "WARNING: freeze_training is enabled! "
+                "Learning rate is set to 0 and model parameters will NOT be updated. "
+                "This mode is intended for debugging/profiling only, NOT for actual training."
+            )
+            from paddleformers.trainer.trainer_utils import get_constant_schedule
+
+            self.lr_scheduler = get_constant_schedule(learning_rate=0.0)
+            return self.lr_scheduler
+
         if self.args.warmup_steps > 0:
             warmup = self.args.warmup_steps
         else:
42 changes: 32 additions & 10 deletions paddleformers/trainer/trainer.py

@@ -1826,6 +1826,16 @@ def split_dtensor_by_axis(dtensor, axis=0):
         return global_micro_batchs
 
     def optimizer_step(self, args, model, parameters_list=None):
+        # When freeze_training is enabled, skip optimizer step and lr scheduler step
+        # to keep both model parameters and optimizer state unchanged
+        if args.freeze_training:
+            logger.warning(
+                "freeze_training is enabled! Model parameters and optimizer state will NOT be updated. "
+                "This is intended for debugging/profiling only."
+            )
+            self.optimizer.clear_grad()
+            return
+
         if parameters_list is None:
             parameters_list = []
 
@@ -3220,16 +3230,28 @@ def create_scheduler(self, num_training_steps: int):
             decay_steps = self.args.decay_steps
 
         if self.lr_scheduler is None:
-            self.lr_scheduler = get_scheduler(
-                self.args.lr_scheduler_type,
-                learning_rate=self.args.learning_rate,
-                num_warmup_steps=warmup,
-                num_training_steps=decay_steps,
-                num_cycles=self.args.num_cycles,
-                lr_end=self.args.lr_end,
-                power=self.args.power,
-                min_lr=self.args.min_lr,
-            )
+            # When freeze_training is enabled, use constant scheduler with lr=0
+            # to ensure learning rate stays 0 throughout training
+            if self.args.freeze_training:
+                logger.warning(
+                    "WARNING: freeze_training is enabled! "
+                    "Learning rate is set to 0 and model parameters will NOT be updated. "
+                    "This mode is intended for debugging/profiling only, NOT for actual training."
+                )
+                from .trainer_utils import get_constant_schedule
+
+                self.lr_scheduler = get_constant_schedule(learning_rate=0.0)
+            else:
+                self.lr_scheduler = get_scheduler(
+                    self.args.lr_scheduler_type,
+                    learning_rate=self.args.learning_rate,
+                    num_warmup_steps=warmup,
+                    num_training_steps=decay_steps,
+                    num_cycles=self.args.num_cycles,
+                    lr_end=self.args.lr_end,
+                    power=self.args.power,
+                    min_lr=self.args.min_lr,
+                )
 
         return self.lr_scheduler
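Not part of the diff: a quick sanity-check sketch of why the constant schedule gives the stated behavior. It assumes the object returned by get_constant_schedule is a standard Paddle LRScheduler, so step() and get_lr() below come from Paddle's scheduler API rather than from this PR.

# Sanity check (assumption: get_constant_schedule returns a Paddle LRScheduler):
# with learning_rate=0.0 the reported learning rate stays 0.0 at every step,
# so even if an optimizer step were reached, the effective update would be zero.
from paddleformers.trainer.trainer_utils import get_constant_schedule

scheduler = get_constant_schedule(learning_rate=0.0)
for _ in range(5):
    scheduler.step()
    assert scheduler.get_lr() == 0.0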

5 changes: 5 additions & 0 deletions paddleformers/trainer/trainer_callback.py

@@ -779,6 +779,11 @@ def __init__(self, lr=0.001, use_mp=False):
         self.use_mp = use_mp
 
     def on_optimizer_end(self, args, state, control, **kwargs):
+        # Skip bias update when freeze_training is enabled
+        if getattr(args, "freeze_training", False):
+            logger.warning("freeze_training is enabled! MoE e_score_correction_bias will NOT be updated.")
+            return
+
         model = kwargs["model"]
 
         biases = []
14 changes: 14 additions & 0 deletions paddleformers/trainer/training_args.py

@@ -515,6 +515,20 @@ class TrainingArguments:
             and decreased for the experts with more assigned tokens."""
         },
     )
+    freeze_training: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "When set to True, the training process will be frozen: "
+                "1) Model parameters will not be updated (backward and optimizer step are skipped). "
+                "2) Optimizer state remains unchanged. "
+                "3) Learning rate scheduler is not updated. "
+                "4) MoE e_score_correction_bias is not updated. "
+                "This is useful for debugging, profiling, or running inference-only passes through the training loop. "
+                "Note: The learning rate will also be effectively set to 0 when this flag is enabled."
+            )
+        },
+    )
 
     log_on_each_node: bool = field(
         default=True,
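Not part of the diff: a minimal usage sketch of the new flag, assuming the TrainingArguments class edited above is constructed directly and that output_dir/max_steps are the usual arguments; the surrounding model, dataset, and Trainer wiring is left hypothetical.

# Hypothetical profiling run: freeze_training=True keeps model parameters,
# optimizer state, and the learning-rate schedule untouched while the training
# loop (data pipeline, forward/backward, logging) still executes.
from paddleformers.trainer.training_args import TrainingArguments

args = TrainingArguments(
    output_dir="./freeze_profile_run",  # assumed standard argument
    max_steps=20,                       # a short run is enough for profiling
    freeze_training=True,               # flag added by this PR
)
# trainer = Trainer(model=model, args=args, train_dataset=train_ds)  # wiring assumed
# trainer.train()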