@@ -45,34 +45,59 @@ def normalize_weights(model, eps=1e-5):
     return
 
 
-def clip_grads(model, max_grad_norm):
-
+def _compute_total_grad_norm(model, norm_type=2.0):
     # iterate over parameters
-    with torch.no_grad():
-        for param in model.parameters():
+    gnorms = []
+    for param in model.parameters():
 
-            if param.grad is None:
-                continue
+        if param.grad is None:
+            continue
 
-            # compute local norm: compute abs first to support complex grads
+        # compute local norm: compute abs first to support complex grads
+        if norm_type == 2.0:
             gnorm = torch.sum(torch.square(torch.abs(param.grad)))
+        else:
+            gnorm = torch.sum(torch.abs(param.grad))
+
+        # compute global norm
+        if hasattr(param, "sharded_dims_mp"):
+
+            for group in param.sharded_dims_mp:
+                # continue if there is nothing to do
+                if (group is None) or (comm.get_size(group) == 1):
+                    continue
 
-            # compute global norm
-            if hasattr(param, "sharded_dims_mp"):
+                dist.all_reduce(gnorm, group=comm.get_group(group))
 
-                for d, group in enumerate(param.sharded_dims_mp):
-                    # continue if there is nothing to do
-                    if (group is None) or (comm.get_size(group) == 1):
-                        continue
+        gnorms.append(gnorm)
 
-                    dist.all_reduce(gnorm, group=comm.get_group(group))
+    # compute total norm
+    if gnorms:
+        total_gnorm = torch.sum(torch.stack(gnorms))
+    else:
+        total_gnorm = torch.tensor(0.0, device=model.device)
 
-            # compute square root
-            gnorm = torch.sqrt(gnorm)
+    # post-process norm
+    if norm_type == 2.0:
+        total_gnorm = torch.sqrt(total_gnorm)
+
+    return total_gnorm
+
+
+def clip_grads(model, max_grad_norm, norm_type=2.0):
+
+    # iterate over parameters
+    with torch.no_grad():
+        total_gnorm = _compute_total_grad_norm(model, norm_type)
+
+        clip_factor = max_grad_norm / total_gnorm
+        clip_factor = torch.clamp(clip_factor, max=1.0)
+
+        for param in model.parameters():
+            if param.grad is None:
+                continue
 
-            # update grads
-            if gnorm > max_grad_norm:
-                param.grad.mul_(max_grad_norm / gnorm)
+            param.grad.mul_(clip_factor)
 
     return
 
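For reference, a minimal usage sketch of the refactored `clip_grads` in a training step, which applies global-norm clipping: all gradients are scaled in place by min(1, max_grad_norm / total_norm). The `train_step`, `model`, `optimizer`, and `batch` names below are hypothetical and not part of this commit; only `torch` and the `clip_grads` defined above are assumed to be in scope.

```python
import torch

def train_step(model, optimizer, batch, max_grad_norm=1.0):
    # hypothetical training step; batch is an (inputs, targets) pair
    inputs, targets = batch
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()

    # rescale all gradients in place so their combined L2 norm is at most max_grad_norm
    clip_grads(model, max_grad_norm, norm_type=2.0)

    optimizer.step()
    return loss.detach()
```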