Commit ab91112

Improve infinity-check (k2-fsa#1862)
1. Attach the inf-check hooks when the grad scale gets too small.
2. Add a try/except to avoid OOM inside the inf-check hooks.
3. Set warmup_start=0.1 to reduce the chance of divergence.
1 parent 8d60280 commit ab91112

File tree: 2 files changed, +36 −15 lines

egs/librispeech/ASR/zipformer/train.py

Lines changed: 18 additions & 7 deletions
@@ -1165,23 +1165,34 @@ def save_bad_model(suffix: str = ""):
                 rank=rank,
             )
 
-        if batch_idx % 100 == 0 and params.use_autocast:
-            # If the grad scale was less than 1, try increasing it. The _growth_interval
-            # of the grad scaler is configurable, but we can't configure it to have different
-            # behavior depending on the current grad scale.
+        if params.use_autocast:
             cur_grad_scale = scaler._scale.item()
 
-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
-                scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
                     save_bad_model(suffix="-first-warning")
                     saved_bad_model = True
+                if not params.inf_check:
+                    register_inf_check_hooks(model)
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
+
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
                     raise_grad_scale_is_too_small_error(cur_grad_scale)
 
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            if (
+                batch_idx % 25 == 0
+                and cur_grad_scale < 2.0
+                or batch_idx % 100 == 0
+                and cur_grad_scale < 8.0
+                or batch_idx % 400 == 0
+                and cur_grad_scale < 32.0
+            ):
+                scaler.update(cur_grad_scale * 2.0)
+
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
             cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0
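Note the shape of the new growth condition: Python's `and` binds tighter than `or`, so the expression is three tiers, each pairing a check interval with a scale threshold, and smaller scales are retried more often. A minimal standalone sketch of that predicate (the function name and the asserts are illustrative, not part of icefall):

def should_grow_grad_scale(batch_idx: int, cur_grad_scale: float) -> bool:
    # Tiered growth: `and` binds tighter than `or`, so this reads as
    # (every 25 batches if scale < 2) or (every 100 if < 8) or (every 400 if < 32).
    return (
        batch_idx % 25 == 0
        and cur_grad_scale < 2.0
        or batch_idx % 100 == 0
        and cur_grad_scale < 8.0
        or batch_idx % 400 == 0
        and cur_grad_scale < 32.0
    )

# A scale of 1.5 is retried every 25 batches, 4.0 only every 100:
assert should_grow_grad_scale(25, 1.5)
assert not should_grow_grad_scale(25, 4.0)
assert should_grow_grad_scale(100, 4.0)

On a hit, `scaler.update(cur_grad_scale * 2.0)` passes the doubled value as the `new_scale` argument of GradScaler.update, which is part of the public torch.cuda.amp API.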
@@ -1335,7 +1346,7 @@ def run(rank, world_size, args):
         clipping_scale=2.0,
     )
 
-    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=0.1)
 
     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
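In icefall's Eden scheduler, `warmup_start` is the learning-rate factor applied at the very first batch, ramping up to 1.0 over Eden's warmup period; 0.1 therefore starts training at a tenth of the scheduled rate. A rough sketch of that ramp (the linear form and the 500-batch warmup length are assumptions about Eden's implementation, not shown in this diff):

def warmup_factor(
    batch: int, warmup_batches: float = 500.0, warmup_start: float = 0.1
) -> float:
    # Linear ramp from warmup_start at batch 0 to 1.0 at warmup_batches, then flat;
    # Eden multiplies a factor like this into its batch- and epoch-based decay.
    if batch >= warmup_batches:
        return 1.0
    return warmup_start + (1.0 - warmup_start) * (batch / warmup_batches)

print(warmup_factor(0))    # 0.1
print(warmup_factor(250))  # 0.55
print(warmup_factor(500))  # 1.0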

icefall/hooks.py

Lines changed: 18 additions & 8 deletions
@@ -39,24 +39,34 @@ def register_inf_check_hooks(model: nn.Module) -> None:
         # default param _name is a way to capture the current value of the variable "name".
         def forward_hook(_module, _input, _output, _name=name):
             if isinstance(_output, Tensor):
-                if not torch.isfinite(_output.to(torch.float32).sum()):
-                    logging.warning(f"The sum of {_name}.output is not finite")
+                try:
+                    if not torch.isfinite(_output.to(torch.float32).sum()):
+                        logging.warning(f"The sum of {_name}.output is not finite")
+                except RuntimeError:  # e.g. CUDA out of memory
+                    pass
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
                     if isinstance(o, tuple):
                         o = o[0]
                     if not isinstance(o, Tensor):
                         continue
-                    if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(f"The sum of {_name}.output[{i}] is not finite")
+                    try:
+                        if not torch.isfinite(o.to(torch.float32).sum()):
+                            logging.warning(
+                                f"The sum of {_name}.output[{i}] is not finite"
+                            )
+                    except RuntimeError:  # e.g. CUDA out of memory
+                        pass
 
         # default param _name is a way to capture the current value of the variable "name".
         def backward_hook(_module, _input, _output, _name=name):
             if isinstance(_output, Tensor):
-                if not torch.isfinite(_output.to(torch.float32).sum()):
-                    logging.warning(
-                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
-                    )
+                try:
+                    if not torch.isfinite(_output.to(torch.float32).sum()):
+                        logging.warning(f"The sum of {_name}.grad is not finite")
+                except RuntimeError:  # e.g. CUDA out of memory
+                    pass
+
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
                     if isinstance(o, tuple):
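Each hook computes a float32 `.sum()` of the module's output, which allocates a temporary tensor on the same device and can therefore itself hit CUDA OOM exactly when training is already memory-tight; catching `RuntimeError` keeps the diagnostic hooks from crashing the run they are meant to debug. A minimal usage sketch (the toy model and the non-finite input are illustrative; `register_inf_check_hooks` is the helper patched above):

import logging

import torch
from torch import nn

from icefall.hooks import register_inf_check_hooks

logging.basicConfig(level=logging.WARNING)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
register_inf_check_hooks(model)

x = torch.full((2, 4), float("inf"))  # deliberately non-finite input
loss = model(x).sum()
loss.backward()  # forward hooks log a warning for each module whose output is not finite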
