
import numpy as np
import torch
+import torch.distributed as dist

from rigl_torch.util import get_W

@@ -33,22 +34,36 @@ def _wrapped_step():

class RigLScheduler:

-    def __init__(self, model, optimizer, dense_allocation=1, T_end=None, ignore_linear_layers=True, is_already_sparsified=False, delta=100, alpha=0.3, static_topo=False):
+    def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_distribution='uniform', ignore_linear_layers=True, is_already_sparsified=False, delta=100, alpha=0.3, static_topo=False):
        if dense_allocation <= 0 or dense_allocation > 1:
            raise Exception('Dense allocation must be on the interval (0, 1]. Got: %f' % dense_allocation)

        self.model = model
        self.optimizer = optimizer
-        self.W = get_W(model, ignore_linear_layers=ignore_linear_layers)
-        self.backward_masks = None
+        self.sparsity_distribution = sparsity_distribution
        self.static_topo = static_topo
+        self.ignore_linear_layers = ignore_linear_layers
+        self.backward_masks = None
+
+        assert self.sparsity_distribution in ('uniform', )
+
+        self.W, self._linear_layers_mask = get_W(model, return_linear_layers_mask=True)

        # modify optimizer.step() function to call "reset_momentum" after
        _create_step_wrapper(self, optimizer)

        # define sparsity allocation
-        layers = len(self.W)
-        self.S = [1 - dense_allocation] * layers  # uniform sparsity
+        self.S = []
+        for i, (W, is_linear) in enumerate(zip(self.W, self._linear_layers_mask)):
+            if i == 0 and self.sparsity_distribution == 'uniform':
+                # when using uniform sparsity, the first layer is always 100% dense
+                self.S.append(0)
+            elif is_linear and self.ignore_linear_layers:
+                # if choosing to ignore linear layers, keep them 100% dense
+                self.S.append(0)
+            else:
+                self.S.append(1 - dense_allocation)
+
        self.N = [torch.numel(w) for w in self.W]

        # randomly sparsify model according to S
@@ -69,6 +84,11 @@ def __init__(self, model, optimizer, dense_allocation=1, T_end=None, ignore_line
        # also, register backward hook so sparse elements cannot be recovered during normal training
        self.backward_hook_objects = []
        for i, w in enumerate(self.W):
+            # if sparsity is 0%, skip
+            if self.S[i] <= 0:
+                self.backward_hook_objects.append(None)
+                continue
+
            self.backward_hook_objects.append(IndexMaskHook(i, self))
            w.register_hook(self.backward_hook_objects[-1])

@@ -83,26 +103,28 @@ def __init__(self, model, optimizer, dense_allocation=1, T_end=None, ignore_line

    @torch.no_grad()
    def random_sparsify(self):
+        is_dist = dist.is_initialized()
        self.backward_masks = []
        for l, w in enumerate(self.W):
+            # if sparsity is 0%, skip
+            if self.S[l] <= 0:
+                self.backward_masks.append(None)
+                continue
+
            n = self.N[l]
            s = int(self.S[l] * n)
            perm = torch.randperm(n)
            perm = perm[:s]
-            flat_mask = torch.ones(n, dtype=torch.bool, device=w.device)
+            flat_mask = torch.ones(n, device=w.device)
            flat_mask[perm] = 0
            mask = torch.reshape(flat_mask, w.shape)
-            w *= mask
-            self.backward_masks.append(mask)

+            if is_dist:
+                dist.broadcast(mask, 0)

-    def __call__(self):
-        self.step += 1
-        if (self.step % self.delta_T) == 0 and self.step < self.T_end:  # check schedule
-            self._rigl_step()
-            self.rigl_steps += 1
-            return False
-        return True
+            mask = mask.bool()
+            w *= mask
+            self.backward_masks.append(mask)


    def __str__(self):
@@ -114,36 +136,44 @@ def __str__(self):
        S_str = '['
        sparsity_percentages = []
        total_params = 0
+        total_conv_params = 0
        total_nonzero = 0
+        total_conv_nonzero = 0

-        for N, S, mask, W in zip(self.N, self.S, self.backward_masks, self.W):
+        for N, S, mask, W, is_linear in zip(self.N, self.S, self.backward_masks, self.W, self._linear_layers_mask):
            actual_S = torch.sum(W[mask == 0] == 0).item()
            N_str += ('%i/%i, ' % (N - actual_S, N))
            sp_p = float(N - actual_S) / float(N) * 100
            S_str += '%.2f%%, ' % sp_p
-
            sparsity_percentages.append(sp_p)
            total_params += N
            total_nonzero += N - actual_S
+            if not is_linear:
+                total_conv_nonzero += N - actual_S
+                total_conv_params += N
+
        N_str = N_str[:-2] + ']'
        S_str = S_str[:-2] + ']'

        s += 'nonzero_params=' + N_str + ',\n'
        s += 'nonzero_percentages=' + S_str + ',\n'
        s += 'total_nonzero_params=' + ('%i/%i (%.2f%%)' % (total_nonzero, total_params, float(total_nonzero)/float(total_params)*100)) + ',\n'
+        s += 'total_CONV_nonzero_params=' + ('%i/%i (%.2f%%)' % (total_conv_nonzero, total_conv_params, float(total_conv_nonzero)/float(total_conv_params)*100)) + ',\n'
        s += 'step=' + str(self.step) + ',\n'
        s += 'num_rigl_steps=' + str(self.rigl_steps) + ',\n'
+        s += 'ignoring_linear_layers=' + str(self.ignore_linear_layers) + ',\n'
+        s += 'sparsity_distribution=' + str(self.sparsity_distribution) + ',\n'

        return s + ')'


-    def cosine_annealing(self):
-        return self.alpha / 2 * (1 + np.cos((self.step * np.pi) / self.T_end))
-
-
    @torch.no_grad()
    def reset_momentum(self):
-        for w, mask in zip(self.W, self.backward_masks):
+        for w, mask, s in zip(self.W, self.backward_masks, self.S):
+            # if sparsity is 0%, skip
+            if s <= 0:
+                continue
+
            param_state = self.optimizer.state[w]
            if 'momentum_buffer' in param_state:
                # mask the momentum matrix
@@ -153,30 +183,66 @@ def reset_momentum(self):

    @torch.no_grad()
    def apply_mask_to_weights(self):
-        for w, mask in zip(self.W, self.backward_masks):
+        for w, mask, s in zip(self.W, self.backward_masks, self.S):
+            # if sparsity is 0%, skip
+            if s <= 0:
+                continue
+
            w *= mask


    @torch.no_grad()
    def apply_mask_to_gradients(self):
-        for w, mask in zip(self.W, self.backward_masks):
+        for w, mask, s in zip(self.W, self.backward_masks, self.S):
+            # if sparsity is 0%, skip
+            if s <= 0:
+                continue
+
            w.grad *= mask


-    @torch.no_grad()
-    def _rigl_step(self):
+    def cosine_annealing(self):
+        return self.alpha / 2 * (1 + np.cos((self.step * np.pi) / self.T_end))
+
+
+    def __call__(self):
+        self.step += 1
        if self.static_topo:
-            return
+            return True
+        if (self.step % self.delta_T) == 0 and self.step < self.T_end:  # check schedule
+            self._rigl_step()
+            self.rigl_steps += 1
+            return False
+        return True
+

+    @torch.no_grad()
+    def _rigl_step(self):
        drop_fraction = self.cosine_annealing()

+        # if distributed these values will be populated
+        is_dist = dist.is_initialized()
+        world_size = dist.get_world_size() if is_dist else None
+
        for l, w in enumerate(self.W):
+            # if sparsity is 0%, skip
+            if self.S[l] <= 0:
+                continue
+
            current_mask = self.backward_masks[l]

            # calculate raw scores
            score_drop = torch.abs(w)
            score_grow = torch.abs(self.backward_hook_objects[l].dense_grad)

+            # if is distributed, synchronize scores
+            if is_dist:
+                dist.all_reduce(score_drop)  # get the sum of all drop scores
+                score_drop /= world_size  # divide by world size (average the drop scores)
+
+                dist.all_reduce(score_grow)  # get the sum of all grow scores
+                score_grow /= world_size  # divide by world size (average the grow scores)
+
            # calculate drop/grow quantities
            n_total = self.N[l]
            n_ones = torch.sum(current_mask).item()
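
Note: `IndexMaskHook` is referenced above but not defined in this diff. Based only on how it is used here (registered via `w.register_hook(...)`, exposing a `dense_grad` attribute read as the grow score, and the comment that pruned elements must not be recovered during normal training), a simplified sketch of such a hook could look like the following; the real class and its gradient-accumulation details may differ.

```python
import torch


class IndexMaskHook:
    """Hypothetical, simplified sketch of the per-layer gradient hook used above."""

    def __init__(self, layer_idx, scheduler):
        self.layer_idx = layer_idx
        self.scheduler = scheduler
        self.dense_grad = None  # read by _rigl_step as torch.abs(...) for grow scoring

    @torch.no_grad()
    def __call__(self, grad):
        mask = self.scheduler.backward_masks[self.layer_idx]
        # keep a copy of the dense gradient so dropped connections can be scored for regrowth
        self.dense_grad = grad.clone()
        # zero the gradient of pruned weights so they cannot be recovered by the optimizer
        return grad * mask
```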
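For context, a minimal usage sketch of the scheduler follows. The import path, the toy model, and the random data are assumptions; the pattern of gating `optimizer.step()` on the scheduler's return value follows from `__call__` above, which returns `False` right after a topology update.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from rigl_torch.RigL import RigLScheduler  # import path is an assumption

# toy model: first conv stays dense, second conv is sparsified, linear is ignored by default
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
    nn.Flatten(), nn.Linear(16 * 8 * 8, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

total_steps = 1000
pruner = RigLScheduler(model, optimizer,
                       dense_allocation=0.1,           # 90% sparsity on eligible layers
                       T_end=int(0.75 * total_steps),  # freeze topology at 75% of training
                       delta=100, alpha=0.3)

for step in range(total_steps):
    x, y = torch.randn(4, 3, 8, 8), torch.randint(0, 10, (4,))
    optimizer.zero_grad()
    F.cross_entropy(model(x), y).backward()
    if pruner():          # returns False on steps where RigL just updated the masks
        optimizer.step()
```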