diff --git a/main.py b/main.py
index 84230ea7..88eec76f 100644
--- a/main.py
+++ b/main.py
@@ -184,7 +184,8 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix
         if mixup_fn is not None:
             samples, targets = mixup_fn(samples, targets)
 
-        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
+        # Using torch.bfloat16 to prevent overflow. Float16 has three fewer exponent bits than bfloat16, which causes NaN loss and NaN grad norms during AMP training.
+        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE, dtype=torch.bfloat16):
             outputs = model(samples)
         loss = criterion(outputs, targets)
         loss = loss / config.TRAIN.ACCUMULATION_STEPS
@@ -241,7 +242,8 @@ def validate(config, data_loader, model):
         target = target.cuda(non_blocking=True)
 
         # compute output
-        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
+        # Using torch.bfloat16 to prevent overflow. Float16 has three fewer exponent bits than bfloat16, which causes NaN loss and NaN grad norms during AMP training.
+        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE, dtype=torch.bfloat16):
             output = model(images)
 
         # measure accuracy and record loss
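For context, a minimal sketch (not part of the diff; the Linear model and input tensor are placeholders) of why the bfloat16 dtype avoids the overflow the comment describes: float16 keeps 5 exponent bits versus bfloat16's 8, so float16 saturates around 65504 while bfloat16 covers roughly the same dynamic range as float32. Running bfloat16 autocast assumes a GPU with bfloat16 support (Ampere or newer).

import torch

# Dynamic range comparison: float16 (5 exponent bits) vs bfloat16 (8 exponent bits).
print(torch.finfo(torch.float16).max)   # 65504.0 -- activations/logits beyond this overflow to inf
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38 -- roughly the float32 range

# Placeholder model and batch, mirroring the autocast usage in the diff above.
if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    x = torch.randn(4, 16, device="cuda")
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
        y = model(x)
    print(y.dtype)  # torch.bfloat16

A side note on the same change: with bfloat16 autocast, the loss scaling that torch.cuda.amp.GradScaler performs for float16 is generally unnecessary, since the reduced-precision format already spans the float32 exponent range.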