We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 15c40c4 commit fc489c3Copy full SHA for fc489c3
makani/utils/training/training_helpers.py
@@ -90,7 +90,7 @@ def clip_grads(model, max_grad_norm, norm_type=2.0):
90
with torch.no_grad():
91
total_gnorm = _compute_total_grad_norm(model, norm_type)
92
93
- clip_factor = max_grad_norm / total_gnorm
+ clip_factor = max_grad_norm / (total_gnorm + 1e-6) # add small epsilon to avoid division by zero
94
clip_factor = torch.clamp(clip_factor, max=1.0)
95
96
for param in model.parameters():
0 commit comments