
Commit 8369d44

fixing grad norm (#37)
* fixing grad norm
* adding epsilon in norm division
1 parent d43b94b commit 8369d44
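The fix replaces per-parameter clipping with clipping by the total gradient norm, and guards the norm division with a small epsilon so an all-zero gradient cannot cause a division by zero. A minimal single-process sketch of that clip-factor logic (illustrative only; the actual Makani change is in the training_helpers.py diff below):

import torch

def compute_clip_factor(total_gnorm: torch.Tensor, max_grad_norm: float, eps: float = 1e-6) -> torch.Tensor:
    # epsilon keeps the division finite when the total gradient norm is zero
    factor = max_grad_norm / (total_gnorm + eps)
    # only ever scale gradients down, never up
    return torch.clamp(factor, max=1.0)

Every gradient is then multiplied in place by this single factor, so all parameters are rescaled consistently.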

File tree

3 files changed: +45, -24 lines


makani/models/networks/fourcastnet3.py

Lines changed: 0 additions & 4 deletions
@@ -38,10 +38,6 @@
 # more distributed stuff
 from makani.utils import comm
 
-# layer normalization
-from physicsnemo.distributed.mappings import scatter_to_parallel_region, gather_from_parallel_region
-#from makani.mpu.layer_norm import DistributedInstanceNorm2d, DistributedLayerNorm
-
 # for annotation of models
 from dataclasses import dataclass
 import physicsnemo

makani/utils/training/deterministic_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@
 from makani.utils import visualize
 
 from makani.mpu.mappings import init_gradient_reduction_hooks
-from makani.mpu.helpers import sync_params, gather_uneven
+from makani.mpu.helpers import sync_params
 
 # for counting model parameters
 from makani.models.helpers import count_parameters

makani/utils/training/training_helpers.py

Lines changed: 44 additions & 19 deletions
@@ -45,34 +45,59 @@ def normalize_weights(model, eps=1e-5):
     return
 
 
-def clip_grads(model, max_grad_norm):
-
+def _compute_total_grad_norm(model, norm_type=2.0):
     # iterate over parameters
-    with torch.no_grad():
-        for param in model.parameters():
+    gnorms = []
+    for param in model.parameters():
 
-            if param.grad is None:
-                continue
+        if param.grad is None:
+            continue
 
-            # compute local norm: compute abs first to support complex grads
+        # compute local norm: compute abs first to support complex grads
+        if norm_type == 2.0:
             gnorm = torch.sum(torch.square(torch.abs(param.grad)))
+        else:
+            gnorm = torch.sum(torch.abs(param.grad))
+
+        # compute global norm
+        if hasattr(param, "sharded_dims_mp"):
+
+            for group in param.sharded_dims_mp:
+                # continue if there is nothing to do
+                if (group is None) or (comm.get_size(group) == 1):
+                    continue
 
-            # compute global norm
-            if hasattr(param, "sharded_dims_mp"):
+                dist.all_reduce(gnorm, group=comm.get_group(group))
 
-                for d, group in enumerate(param.sharded_dims_mp):
-                    # continue if there is nothing to do
-                    if (group is None) or (comm.get_size(group) == 1):
-                        continue
+        gnorms.append(gnorm)
 
-                    dist.all_reduce(gnorm, group=comm.get_group(group))
+    # compute total norm
+    if gnorms:
+        total_gnorm = torch.sum(torch.stack(gnorms))
+    else:
+        total_gnorm = torch.tensor(0.0, device=model.device)
 
-            # compute square root
-            gnorm = torch.sqrt(gnorm)
+    # post-process norm
+    if norm_type == 2.0:
+        total_gnorm = torch.sqrt(total_gnorm)
+
+    return total_gnorm
+
+
+def clip_grads(model, max_grad_norm, norm_type=2.0):
+
+    # iterate over parameters
+    with torch.no_grad():
+        total_gnorm = _compute_total_grad_norm(model, norm_type)
+
+        clip_factor = max_grad_norm / (total_gnorm + 1e-6)  # add small epsilon to avoid division by zero
+        clip_factor = torch.clamp(clip_factor, max=1.0)
+
+        for param in model.parameters():
+            if param.grad is None:
+                continue
 
-            # update grads
-            if gnorm > max_grad_norm:
-                param.grad.mul_(max_grad_norm / gnorm)
+            param.grad.mul_(clip_factor)
 
     return
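For context, the updated clip_grads now mirrors the behavior of torch.nn.utils.clip_grad_norm_ (global-norm clipping with an epsilon-guarded division), while additionally all-reducing per-parameter norms over the model-parallel groups recorded in sharded_dims_mp. Below is a hedged usage sketch of how a training step might call it; the surrounding loop, loss, and optimizer are assumptions, not code from the Makani trainer:

import torch
from makani.utils.training.training_helpers import clip_grads

def train_step(model, optimizer, loss_fn, inputs, targets, max_grad_norm=1.0):
    # standard optimizer step with global-norm gradient clipping
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()

    # rescale all gradients by one factor so the total norm stays <= max_grad_norm
    clip_grads(model, max_grad_norm, norm_type=2.0)

    optimizer.step()
    return loss.detach()

In the unsharded, single-process case this reduces to ordinary global-norm clipping; the sharded_dims_mp branch only changes how the total norm is accumulated, not how gradients are rescaled.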
0 commit comments
