14 changes: 14 additions & 0 deletions tools/train.py
@@ -88,6 +88,13 @@ def make_parser():
        Implemented loggers include `tensorboard`, `mlflow` and `wandb`.",
        default="tensorboard"
    )
    parser.add_argument(
        "--early-stopping",
        dest="early_stopping",
        default=False,
        action="store_true",
        help="Use early stopping to prevent overfitting.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
@@ -115,6 +122,13 @@ def main(exp: Exp, args):
    cudnn.benchmark = True

    trainer = exp.get_trainer(args)

    # configure early stopping parameters
    if args.early_stopping:
        # stop if AP50:95 fails to improve by at least 1% (relative) for 10 consecutive evaluations
        # available modes: "max", "min", "percentage"
        trainer.early_stopper = exp.get_early_stopping(patience=10, min_delta=0.01, mode="percentage")

    trainer.train()
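For reference, a minimal sketch of how the new flag behaves when parsed on its own (a standalone argparse parser, not the full make_parser() from tools/train.py):

# Minimal sketch: behaviour of the --early-stopping flag in isolation.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--early-stopping",
    dest="early_stopping",
    default=False,
    action="store_true",
    help="Use early stopping to prevent overfitting.",
)

print(parser.parse_args([]).early_stopping)                    # False: disabled by default
print(parser.parse_args(["--early-stopping"]).early_stopping)  # True: the flag switches it on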
1 change: 1 addition & 0 deletions yolox/core/__init__.py
@@ -4,3 +4,4 @@

from .launch import launch
from .trainer import Trainer
from .trainer import EarlyStopping
48 changes: 47 additions & 1 deletion yolox/core/trainer.py
@@ -33,13 +33,50 @@
    synchronize
)


class EarlyStopping:
    def __init__(self, patience: int, min_delta: float, mode="max"):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode  # "max", "min", "percentage"
        self.best = None
        self.counter = 0

    def step(self, value):
        # Initialize best value on first call
        if self.best is None:
            self.best = value
            return False

        # Compute improvement depending on mode
        if self.mode == "max":
            improvement = value - self.best
        elif self.mode == "min":
            improvement = self.best - value
        elif self.mode == "percentage":
            if self.best == 0:
                improvement = 0  # avoid division by zero
            else:
                improvement = (value - self.best) / abs(self.best)
        else:
            raise ValueError(f"Unknown mode: {self.mode}, supported modes are 'max', 'min', 'percentage'.")

        # Check if improvement is sufficient
        if improvement > self.min_delta:
            self.best = value
            self.counter = 0
        else:
            self.counter += 1

        return self.counter >= self.patience


class Trainer:
    def __init__(self, exp: Exp, args):
        # init function only defines some basic attr, other attrs like model, optimizer are built in
        # before_train methods.
        self.exp = exp
        self.args = args
        self.early_stopper = None

        # training related attr
        self.max_epoch = exp.max_epoch
@@ -234,7 +271,15 @@ def after_epoch(self):

        if (self.epoch + 1) % self.exp.eval_interval == 0:
            all_reduce_norm(self.model)
-           self.evaluate_and_save_model()
+           ap50_95 = self.evaluate_and_save_model()

            # Early stopping
            if self.early_stopper is not None:
                if self.early_stopper.step(ap50_95):
                    logger.info(f"Early stopping triggered at epoch {self.epoch}. "
                                f"Best AP: {self.early_stopper.best}")
                    # save best checkpoint before exiting
                    self.save_ckpt("best_ckpt")
                    raise SystemExit

    def before_iter(self):
        pass
@@ -395,6 +440,7 @@ def evaluate_and_save_model(self):
                }
                self.mlflow_logger.save_checkpoints(self.args, self.exp, self.file_name, self.epoch,
                                                    metadata, update_best_ckpt)
        return ap50_95

    def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
        if self.rank == 0:
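To illustrate the stopping rule added above, here is a small standalone sketch of EarlyStopping in "percentage" mode; the AP values are invented for demonstration and patience is shortened to 3:

# Standalone illustration of EarlyStopping in "percentage" mode (invented AP values).
from yolox.core import EarlyStopping

stopper = EarlyStopping(patience=3, min_delta=0.01, mode="percentage")

# 40.0 -> 41.0 is a 2.5% relative improvement, so `best` updates and the counter resets.
# None of the next three values beats 41.0 by more than 1%, so the counter reaches
# patience and step() returns True on the last call.
for ap in [40.0, 41.0, 41.2, 41.3, 41.1]:
    if stopper.step(ap):
        print(f"stop (best AP so far: {stopper.best})")
        break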
5 changes: 5 additions & 0 deletions yolox/exp/yolox_base.py
@@ -349,6 +349,11 @@ def get_trainer(self, args):
        # NOTE: trainer shouldn't be an attribute of exp object
        return trainer

    def get_early_stopping(self, patience, min_delta, mode):
        from yolox.core import EarlyStopping

        return EarlyStopping(patience=patience, min_delta=min_delta, mode=mode)

    def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
        return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)
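Since tools/train.py hard-codes patience=10, min_delta=0.01 and mode="percentage", one way to try a different criterion without touching the tool is to override get_early_stopping in a custom experiment. A hypothetical sketch (MyExp and its thresholds are illustrative, not part of this PR):

# Hypothetical custom experiment that swaps in a different stopping criterion.
from yolox.core import EarlyStopping
from yolox.exp import Exp as BaseExp


class MyExp(BaseExp):
    def get_early_stopping(self, patience, min_delta, mode):
        # Ignore the values hard-coded in tools/train.py and monitor the raw AP50:95
        # score instead: stop after 5 evaluations without a 0.5-point absolute gain.
        return EarlyStopping(patience=5, min_delta=0.5, mode="max")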