@@ -19,7 +19,16 @@ def __name__(self):
     @torch.no_grad()
     def __call__(self, grad):
         mask = self.scheduler.backward_masks[self.layer]
-        self.dense_grad = grad.clone()
+
+        # only calculate dense_grads when necessary
+        if self.scheduler.check_if_backward_hook_should_accumulate_grad():
+            if self.dense_grad is None:
+                # initialize as all 0s so we can do a rolling average
+                self.dense_grad = torch.zeros_like(grad)
+            self.dense_grad += grad / self.scheduler.grad_accumulation_n
+        else:
+            self.dense_grad = None
+
         return grad * mask


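This hunk changes the backward hook so the dense gradient is averaged over the last `grad_accumulation_n` backward passes instead of being overwritten with a clone of the latest gradient. Below is a minimal standalone sketch of that accumulation pattern; the `_StubScheduler` class is a hypothetical stand-in for the real scheduler, mimicking only the gating check and the `grad_accumulation_n` attribute:

```python
import torch


class _StubScheduler:
    """Hypothetical stand-in for RigLScheduler; just enough state to drive the hook logic."""

    def __init__(self, grad_accumulation_n):
        self.grad_accumulation_n = grad_accumulation_n

    def check_if_backward_hook_should_accumulate_grad(self):
        return True  # pretend a RigL step is always within range


scheduler = _StubScheduler(grad_accumulation_n=4)
dense_grad = None

# simulate 4 backward passes; each `grad` stands in for the tensor passed to the hook
for _ in range(scheduler.grad_accumulation_n):
    grad = torch.ones(2, 2)
    if scheduler.check_if_backward_hook_should_accumulate_grad():
        if dense_grad is None:
            dense_grad = torch.zeros_like(grad)
        dense_grad += grad / scheduler.grad_accumulation_n
    else:
        dense_grad = None

print(dense_grad)  # every element is 1.0: the mean of four all-ones gradients
```

Dividing each contribution by `grad_accumulation_n` up front keeps the accumulator a mean rather than a sum, so no extra normalization is needed before the RigL update.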
@@ -34,17 +43,19 @@ def _wrapped_step():
 
 class RigLScheduler:
 
-    def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_distribution='uniform', ignore_linear_layers=True, is_already_sparsified=False, delta=100, alpha=0.3, static_topo=False):
+    def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_distribution='uniform', ignore_linear_layers=True, is_already_sparsified=False, delta=100, alpha=0.3, static_topo=False, grad_accumulation_n=1):
         if dense_allocation <= 0 or dense_allocation > 1:
             raise Exception('Dense allocation must be on the interval (0, 1]. Got: %f' % dense_allocation)
 
         self.model = model
         self.optimizer = optimizer
         self.sparsity_distribution = sparsity_distribution
         self.static_topo = static_topo
+        self.grad_accumulation_n = grad_accumulation_n
         self.ignore_linear_layers = ignore_linear_layers
         self.backward_masks = None
 
+        assert self.grad_accumulation_n > 0 and self.grad_accumulation_n < delta
         assert self.sparsity_distribution in ('uniform', )
 
         self.W, self._linear_layers_mask = get_W(model, return_linear_layers_mask=True)
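With this constructor change, `grad_accumulation_n` becomes a user-facing knob. A hypothetical construction example follows; the import path, toy model, and hyperparameter values are assumptions for illustration, not taken from this commit:

```python
import torch
import torch.nn as nn
from rigl_torch.RigL import RigLScheduler  # import path assumed; adjust to your install

# small conv model so there are prunable non-linear layers
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

scheduler = RigLScheduler(
    model,
    optimizer,
    dense_allocation=0.1,    # keep 10% of the weights
    T_end=7500,              # last optimizer step at which connections may be rewired
    delta=100,               # attempt a RigL step every 100 optimizer steps
    grad_accumulation_n=4,   # average dense grads over the 4 steps before each RigL step
)
```

The assert added in this hunk enforces 0 < grad_accumulation_n < delta, so the accumulation window can never span more than one RigL interval.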
@@ -200,6 +211,19 @@ def apply_mask_to_gradients(self):
 
             w.grad *= mask
 
+
+    def check_if_backward_hook_should_accumulate_grad(self):
+        """
+        Used by the backward hooks. Checks how far away the next RigL step is;
+        if it is within `self.grad_accumulation_n` steps, returns True.
+        """
+
+        if self.step >= self.T_end:
+            return False
+
+        steps_til_next_rigl_step = self.delta_T - (self.step % self.delta_T)
+        return steps_til_next_rigl_step <= self.grad_accumulation_n
+
 
     def cosine_annealing(self):
         return self.alpha / 2 * (1 + np.cos((self.step * np.pi) / self.T_end))
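To make the gating arithmetic in `check_if_backward_hook_should_accumulate_grad` concrete, here is a small self-contained recreation of its logic; the parameter values are illustrative only, and `should_accumulate` is a hypothetical helper, not part of the commit:

```python
def should_accumulate(step, delta_T=100, grad_accumulation_n=4, T_end=7500):
    """Mirror of check_if_backward_hook_should_accumulate_grad, for illustration only."""
    if step >= T_end:
        return False
    steps_til_next_rigl_step = delta_T - (step % delta_T)
    return steps_til_next_rigl_step <= grad_accumulation_n


# accumulation switches on only for the few steps just before a multiple of delta_T
print([s for s in range(90, 101) if should_accumulate(s)])  # -> [96, 97, 98, 99]
```

Once the step counter passes a multiple of `delta_T`, the check returns False again until the next window opens, which is what lets the hook reset `dense_grad` to None between RigL steps.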