
Commit ab597c8

Remove deprecated `evaluate_during_training` (#8852)
* Remove deprecated `evaluate_during_training`
* Update src/transformers/training_args_tf.py

Co-authored-by: Lysandre Debut <[email protected]>
1 parent e72b4fa commit ab597c8

9 files changed (+24 −13 lines)
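For users migrating, the boolean flag is replaced by a string (or `EvaluationStrategy` enum) argument. A minimal before/after sketch in Python, using only argument names that appear in the diffs below (the `output_dir` value is illustrative):

from transformers import TrainingArguments

# Before this commit (argument now removed):
# args = TrainingArguments(output_dir="out", do_eval=True, evaluate_during_training=True)

# After this commit:
args = TrainingArguments(
    output_dir="out",
    do_eval=True,
    evaluation_strategy="steps",  # evaluate (and log) every `eval_steps` steps
    eval_steps=500,
)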

examples/seq2seq/builtin_trainer/finetune.sh

+2 −1

@@ -3,7 +3,8 @@
 python finetune_trainer.py \
     --learning_rate=3e-5 \
     --fp16 \
-    --do_train --do_eval --do_predict --evaluate_during_training \
+    --do_train --do_eval --do_predict \
+    --evaluation_strategy steps \
     --predict_with_generate \
     --n_val 1000 \
     "$@"

examples/seq2seq/builtin_trainer/finetune_tpu.sh

+2 −1

@@ -5,7 +5,8 @@ export TPU_NUM_CORES=8
 python xla_spawn.py --num_cores $TPU_NUM_CORES \
     finetune_trainer.py \
     --learning_rate=3e-5 \
-    --do_train --do_eval --evaluate_during_training \
+    --do_train --do_eval \
+    --evaluation_strategy steps \
     --prediction_loss_only \
     --n_val 1000 \
     "$@"

examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh

+2 −1

@@ -16,7 +16,8 @@ python finetune_trainer.py \
     --num_train_epochs=6 \
     --save_steps 3000 --eval_steps 3000 \
     --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
-    --do_train --do_eval --do_predict --evaluate_during_training\
+    --do_train --do_eval --do_predict \
+    --evaluation_strategy steps \
     --predict_with_generate --logging_first_step \
     --task translation --label_smoothing 0.1 \
     "$@"

examples/seq2seq/builtin_trainer/train_distil_marian_enro_tpu.sh

+2 −1

@@ -17,7 +17,8 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \
     --save_steps 500 --eval_steps 500 \
     --logging_first_step --logging_steps 200 \
     --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
-    --do_train --do_eval --evaluate_during_training \
+    --do_train --do_eval \
+    --evaluation_strategy steps \
     --prediction_loss_only \
     --task translation --label_smoothing 0.1 \
     "$@"

examples/seq2seq/builtin_trainer/train_distilbart_cnn.sh

+2 −1

@@ -19,6 +19,7 @@ python finetune_trainer.py \
     --save_steps 3000 --eval_steps 3000 \
     --logging_first_step \
     --max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
-    --do_train --do_eval --do_predict --evaluate_during_training \
+    --do_train --do_eval --do_predict \
+    --evaluation_strategy steps \
     --predict_with_generate --sortish_sampler \
     "$@"

examples/seq2seq/builtin_trainer/train_mbart_cc25_enro.sh

+3 −2

@@ -15,7 +15,8 @@ python finetune_trainer.py \
     --sortish_sampler \
     --num_train_epochs 6 \
     --save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
-    --do_train --do_eval --do_predict --evaluate_during_training \
-    --predict_with_generate --logging_first_step
+    --do_train --do_eval --do_predict \
+    --evaluation_strategy steps \
+    --predict_with_generate --logging_first_step \
     --task translation \
     "$@"

src/transformers/integrations.py

+3 −2

@@ -2,6 +2,7 @@
 import math
 import os
 
+from .trainer_utils import EvaluationStrategy
 from .utils import logging
 
 
@@ -212,13 +213,13 @@ def _objective(trial, checkpoint_dir=None):
     # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
     if isinstance(
         kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
-    ) and (not trainer.args.do_eval or not trainer.args.evaluate_during_training):
+    ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == EvaluationStrategy.NO):
         raise RuntimeError(
             "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
             "This means your trials will not report intermediate results to Ray Tune, and "
             "can thus not be stopped early or used to exploit other trials parameters. "
             "If this is what you want, do not use {cls}. If you would like to use {cls}, "
-            "make sure you pass `do_eval=True` and `evaluate_during_training=True` in the "
+            "make sure you pass `do_eval=True` and `evaluation_strategy='steps'` in the "
             "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
         )
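The guard above means Ray Tune schedulers that need intermediate results only work when step-wise evaluation is enabled. A hedged usage sketch; `ASHAScheduler` and `hyperparameter_search` are real APIs, while the trainer setup is elided and assumed:

from ray.tune.schedulers import ASHAScheduler
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    do_eval=True,
    evaluation_strategy="steps",  # anything but "no", or the RuntimeError above is raised
    eval_steps=200,
)
# Assumed setup: a Trainer built with `model_init` so each trial can re-instantiate the model.
# trainer = Trainer(model_init=model_init, args=args, train_dataset=..., eval_dataset=...)
# trainer.hyperparameter_search(backend="ray", n_trials=8, scheduler=ASHAScheduler())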

src/transformers/trainer_tf.py

+2 −2

@@ -19,7 +19,7 @@
 
 from .modeling_tf_utils import TFPreTrainedModel
 from .optimization_tf import GradientAccumulator, create_optimizer
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, set_seed
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, EvaluationStrategy, PredictionOutput, set_seed
 from .training_args_tf import TFTrainingArguments
 from .utils import logging
 
@@ -561,7 +561,7 @@ def train(self) -> None:
 
                 if (
                     self.args.eval_steps > 0
-                    and self.args.evaluate_during_training
+                    and self.args.evaluation_strategy == EvaluationStrategy.STEPS
                     and self.global_step % self.args.eval_steps == 0
                 ):
                     self.evaluate()
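Restated outside the training loop, the new TF trigger condition looks like this; `should_evaluate` is a hypothetical helper, not part of the commit (note that in later releases `EvaluationStrategy` was renamed `IntervalStrategy`):

from transformers.trainer_utils import EvaluationStrategy

def should_evaluate(args, global_step: int) -> bool:
    # Mirrors the condition in TFTrainer.train(): a positive eval interval,
    # the step-based strategy, and a global step on an interval boundary.
    return (
        args.eval_steps > 0
        and args.evaluation_strategy == EvaluationStrategy.STEPS
        and global_step % args.eval_steps == 0
    )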

src/transformers/training_args_tf.py

+6 −2

@@ -34,8 +34,12 @@ class TFTrainingArguments(TrainingArguments):
             Whether to run evaluation on the dev set or not.
         do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to run predictions on the test set or not.
-        evaluate_during_training (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether to run evaluation during training at each logging step or not.
+        evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
+            The evaluation strategy to adopt during training. Possible values are:
+
+            * :obj:`"no"`: No evaluation is done during training.
+            * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
+
         per_device_train_batch_size (:obj:`int`, `optional`, defaults to 8):
             The batch size per GPU/TPU core/CPU for training.
         per_device_eval_batch_size (:obj:`int`, `optional`, defaults to 8):
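Per the updated docstring, the TF arguments accept the same values. A minimal sketch (requires TensorFlow installed; values mirror the docstring above):

from transformers import TFTrainingArguments

tf_args = TFTrainingArguments(
    output_dir="out",
    do_eval=True,
    evaluation_strategy="steps",  # "no" (the default) disables in-training evaluation
    eval_steps=1000,
)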
