
Commit 75d3944

Author: Dyllan McCreary
Merge pull request #3 from McCrearyD/distributed
Added Distributed Functionality for RigL
2 parents: 346c125 + 709a07e

File tree: 4 files changed (+212, -73 lines)
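As a quick reference for the per-layer sparsity rule this PR adds to RigLScheduler.__init__ (shown in the RigL.py diff below): under the 'uniform' distribution the first layer stays dense, linear layers stay dense when ignore_linear_layers=True, and every other layer gets 1 - dense_allocation sparsity. A standalone restatement with hypothetical values, not code from the commit:

    # hypothetical example: mask marking which collected layers are nn.Linear
    dense_allocation = 0.1
    ignore_linear_layers = True
    linear_layers_mask = [0, 0, 0, 1]   # three conv layers followed by one linear layer

    S = []
    for i, is_linear in enumerate(linear_layers_mask):
        if i == 0:
            S.append(0)                      # first layer is always kept 100% dense
        elif is_linear and ignore_linear_layers:
            S.append(0)                      # ignored linear layers are kept dense
        else:
            S.append(1 - dense_allocation)   # everything else gets the target sparsity

    print(S)   # [0, 0.9, 0.9, 0]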

rigl_torch/RigL.py

Lines changed: 93 additions & 27 deletions
@@ -2,6 +2,7 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from rigl_torch.util import get_W
 
@@ -33,22 +34,36 @@ def _wrapped_step():
 
 class RigLScheduler:
 
-    def __init__(self, model, optimizer, dense_allocation=1, T_end=None, ignore_linear_layers=True, is_already_sparsified=False, delta=100, alpha=0.3, static_topo=False):
+    def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_distribution='uniform', ignore_linear_layers=True, is_already_sparsified=False, delta=100, alpha=0.3, static_topo=False):
         if dense_allocation <= 0 or dense_allocation > 1:
             raise Exception('Dense allocation must be on the interval (0, 1]. Got: %f' % dense_allocation)
 
         self.model = model
         self.optimizer = optimizer
-        self.W = get_W(model, ignore_linear_layers=ignore_linear_layers)
-        self.backward_masks = None
+        self.sparsity_distribution = sparsity_distribution
         self.static_topo = static_topo
+        self.ignore_linear_layers = ignore_linear_layers
+        self.backward_masks = None
+
+        assert self.sparsity_distribution in ('uniform', )
+
+        self.W, self._linear_layers_mask = get_W(model, return_linear_layers_mask=True)
 
         # modify optimizer.step() function to call "reset_momentum" after
         _create_step_wrapper(self, optimizer)
 
         # define sparsity allocation
-        layers = len(self.W)
-        self.S = [1-dense_allocation] * layers # uniform sparsity
+        self.S = []
+        for i, (W, is_linear) in enumerate(zip(self.W, self._linear_layers_mask)):
+            if i == 0 and self.sparsity_distribution == 'uniform':
+                # when using uniform sparsity, the first layer is always 100% dense
+                self.S.append(0)
+            elif is_linear and self.ignore_linear_layers:
+                # if choosing to ignore linear layers, keep them 100% dense
+                self.S.append(0)
+            else:
+                self.S.append(1-dense_allocation)
+
         self.N = [torch.numel(w) for w in self.W]
 
         # randomly sparsify model according to S
@@ -69,6 +84,11 @@ def __init__(self, model, optimizer, dense_allocation=1, T_end=None, ignore_line
         # also, register backward hook so sparse elements cannot be recovered during normal training
         self.backward_hook_objects = []
         for i, w in enumerate(self.W):
+            # if sparsity is 0%, skip
+            if self.S[i] <= 0:
+                self.backward_hook_objects.append(None)
+                continue
+
             self.backward_hook_objects.append(IndexMaskHook(i, self))
             w.register_hook(self.backward_hook_objects[-1])
 
@@ -83,26 +103,28 @@ def __init__(self, model, optimizer, dense_allocation=1, T_end=None, ignore_line
 
     @torch.no_grad()
    def random_sparsify(self):
+        is_dist = dist.is_initialized()
         self.backward_masks = []
         for l, w in enumerate(self.W):
+            # if sparsity is 0%, skip
+            if self.S[l] <= 0:
+                self.backward_masks.append(None)
+                continue
+
             n = self.N[l]
             s = int(self.S[l] * n)
             perm = torch.randperm(n)
             perm = perm[:s]
-            flat_mask = torch.ones(n, dtype=torch.bool, device=w.device)
+            flat_mask = torch.ones(n, device=w.device)
             flat_mask[perm] = 0
             mask = torch.reshape(flat_mask, w.shape)
-            w *= mask
-            self.backward_masks.append(mask)
 
+            if is_dist:
+                dist.broadcast(mask, 0)
 
-    def __call__(self):
-        self.step += 1
-        if (self.step % self.delta_T) == 0 and self.step < self.T_end: # check schedule
-            self._rigl_step()
-            self.rigl_steps += 1
-            return False
-        return True
+            mask = mask.bool()
+            w *= mask
+            self.backward_masks.append(mask)
 
 
     def __str__(self):
@@ -114,36 +136,44 @@ def __str__(self):
         S_str = '['
         sparsity_percentages = []
         total_params = 0
+        total_conv_params = 0
         total_nonzero = 0
+        total_conv_nonzero = 0
 
-        for N, S, mask, W in zip(self.N, self.S, self.backward_masks, self.W):
+        for N, S, mask, W, is_linear in zip(self.N, self.S, self.backward_masks, self.W, self._linear_layers_mask):
             actual_S = torch.sum(W[mask == 0] == 0).item()
             N_str += ('%i/%i, ' % (N-actual_S, N))
             sp_p = float(N-actual_S) / float(N) * 100
             S_str += '%.2f%%, ' % sp_p
-
             sparsity_percentages.append(sp_p)
             total_params += N
             total_nonzero += N-actual_S
+            if not is_linear:
+                total_conv_nonzero += N-actual_S
+                total_conv_params += N
+
         N_str = N_str[:-2] + ']'
         S_str = S_str[:-2] + ']'
 
         s += 'nonzero_params=' + N_str + ',\n'
         s += 'nonzero_percentages=' + S_str + ',\n'
         s += 'total_nonzero_params=' + ('%i/%i (%.2f%%)' % (total_nonzero, total_params, float(total_nonzero)/float(total_params)*100)) + ',\n'
+        s += 'total_CONV_nonzero_params=' + ('%i/%i (%.2f%%)' % (total_conv_nonzero, total_conv_params, float(total_conv_nonzero)/float(total_conv_params)*100)) + ',\n'
         s += 'step=' + str(self.step) + ',\n'
         s += 'num_rigl_steps=' + str(self.rigl_steps) + ',\n'
+        s += 'ignoring_linear_layers=' + str(self.ignore_linear_layers) + ',\n'
+        s += 'sparsity_distribution=' + str(self.sparsity_distribution) + ',\n'
 
         return s + ')'
 
 
-    def cosine_annealing(self):
-        return self.alpha / 2 * (1 + np.cos((self.step * np.pi) / self.T_end))
-
-
     @torch.no_grad()
     def reset_momentum(self):
-        for w, mask in zip(self.W, self.backward_masks):
+        for w, mask, s in zip(self.W, self.backward_masks, self.S):
+            # if sparsity is 0%, skip
+            if s <= 0:
+                continue
+
             param_state = self.optimizer.state[w]
             if 'momentum_buffer' in param_state:
                 # mask the momentum matrix
@@ -153,30 +183,66 @@ def reset_momentum(self):
 
     @torch.no_grad()
     def apply_mask_to_weights(self):
-        for w, mask in zip(self.W, self.backward_masks):
+        for w, mask, s in zip(self.W, self.backward_masks, self.S):
+            # if sparsity is 0%, skip
+            if s <= 0:
+                continue
+
             w *= mask
 
 
     @torch.no_grad()
     def apply_mask_to_gradients(self):
-        for w, mask in zip(self.W, self.backward_masks):
+        for w, mask, s in zip(self.W, self.backward_masks, self.S):
+            # if sparsity is 0%, skip
+            if s <= 0:
+                continue
+
             w.grad *= mask
 
 
-    @torch.no_grad()
-    def _rigl_step(self):
+    def cosine_annealing(self):
+        return self.alpha / 2 * (1 + np.cos((self.step * np.pi) / self.T_end))
+
+
+    def __call__(self):
+        self.step += 1
         if self.static_topo:
-            return
+            return True
+        if (self.step % self.delta_T) == 0 and self.step < self.T_end: # check schedule
+            self._rigl_step()
+            self.rigl_steps += 1
+            return False
+        return True
+
 
+    @torch.no_grad()
+    def _rigl_step(self):
         drop_fraction = self.cosine_annealing()
 
+        # if distributed these values will be populated
+        is_dist = dist.is_initialized()
+        world_size = dist.get_world_size() if is_dist else None
+
         for l, w in enumerate(self.W):
+            # if sparsity is 0%, skip
+            if self.S[l] <= 0:
+                continue
+
             current_mask = self.backward_masks[l]
 
             # calculate raw scores
             score_drop = torch.abs(w)
             score_grow = torch.abs(self.backward_hook_objects[l].dense_grad)
 
+            # if is distributed, synchronize scores
+            if is_dist:
+                dist.all_reduce(score_drop) # get the sum of all drop scores
+                score_drop /= world_size # divide by world size (average the drop scores)
+
+                dist.all_reduce(score_grow) # get the sum of all grow scores
+                score_grow /= world_size # divide by world size (average the grow scores)
+
             # calculate drop/grow quantities
             n_total = self.N[l]
             n_ones = torch.sum(current_mask).item()
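For orientation, here is a hypothetical end-to-end sketch (not part of this commit) of how the new distributed code path could be driven. It assumes one process per GPU launched with torchrun, so that dist.is_initialized() is true and the scheduler's new broadcast/all-reduce calls take effect; build_model, total_steps, and train_loader are placeholders.

    import os
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch.nn.parallel import DistributedDataParallel as DDP
    from rigl_torch.RigL import RigLScheduler

    # torchrun sets LOCAL_RANK and the rendezvous env vars used by init_process_group
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')   # makes dist.is_initialized() True, so masks
                                              # are broadcast from rank 0 and drop/grow
                                              # scores are all-reduced across ranks

    model = build_model().cuda()              # build_model() is a placeholder
    ddp_model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    scheduler = RigLScheduler(model, optimizer,
                              dense_allocation=0.1,             # 90% sparsity on non-exempt layers
                              sparsity_distribution='uniform',  # only distribution the assert accepts
                              T_end=int(0.75 * total_steps),    # total_steps is a placeholder
                              delta=100, alpha=0.3)

    for x, y in train_loader:                 # train_loader is a placeholder
        optimizer.zero_grad()
        loss = F.cross_entropy(ddp_model(x.cuda()), y.cuda())
        loss.backward()
        if scheduler():                       # returns False on steps where the mask was just rewired
            optimizer.step()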

rigl_torch/util.py

Lines changed: 19 additions & 11 deletions
@@ -2,21 +2,23 @@
 import torchvision
 
 
-def get_conv_layers_with_activations(model, i=0, layers=None, identity_indices=None, ignore_linear_layers=False):
+def get_conv_layers_with_activations(model, i=0, layers=None, linear_layers_mask=None):
     if layers is None:
         layers = []
-    if identity_indices is None:
-        identity_indices = []
+    if linear_layers_mask is None:
+        linear_layers_mask = []
 
     for layer_name, p in model._modules.items():
         if isinstance(p, torch.nn.Conv2d):
             layers.append([p])
+            linear_layers_mask.append(0)
             i += 1
         # elif isinstance(p, torch.nn.AdaptiveAvgPool2d):
         #     layers.append([p])
         #     i += 1
-        elif isinstance(p, torch.nn.Linear) and not ignore_linear_layers:
+        elif isinstance(p, torch.nn.Linear):
             layers.append([p])
+            linear_layers_mask.append(1)
         elif isinstance(p, torch.nn.BatchNorm2d):
             layers[-1].append(p)
         elif isinstance(p, torch.nn.ReLU):
@@ -25,21 +27,27 @@ def get_conv_layers_with_activations(model, i=0, layers=None, identity_indices=N
             layers[-1].append(p)
         elif layer_name == 'downsample':
             layers.append(p)
+            linear_layers_mask.append(0)
         elif isinstance(p, torchvision.models.resnet.Bottleneck) or isinstance(p, torchvision.models.resnet.BasicBlock):
-            if hasattr(p, 'downsample') and p.downsample is not None:
-                identity_indices.append(i)
-            _, identity_indices, i = get_conv_layers_with_activations(p, i=i, layers=layers, identity_indices=identity_indices)
+            # if hasattr(p, 'downsample') and p.downsample is not None:
+            #     identity_indices.append(i)
+            _, linear_layers_mask, i = get_conv_layers_with_activations(p, i=i, layers=layers, linear_layers_mask=linear_layers_mask)
         else:
-            _, identity_indices, i = get_conv_layers_with_activations(p, i=i, layers=layers, identity_indices=identity_indices)
+            _, linear_layers_mask, i = get_conv_layers_with_activations(p, i=i, layers=layers, linear_layers_mask=linear_layers_mask)
 
-    return layers, identity_indices, i
+    return layers, linear_layers_mask, i
 
 
-def get_W(model, ignore_linear_layers=False):
-    layers, _, _ = get_conv_layers_with_activations(model, ignore_linear_layers=ignore_linear_layers)
+def get_W(model, return_linear_layers_mask=False):
+    layers, linear_layers_mask, _ = get_conv_layers_with_activations(model)
 
     W = []
     for layer in layers:
         idx = 0 if hasattr(layer[0], 'weight') else 1
         W.append(layer[idx].weight)
+
+    assert len(W) == len(linear_layers_mask)
+
+    if return_linear_layers_mask:
+        return W, linear_layers_mask
     return W
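A small hypothetical example (not from the commit) of the updated get_W contract: with return_linear_layers_mask=True it also returns a list parallel to W marking which weight tensors came from nn.Linear layers. A torchvision ResNet is assumed here, since that is the model family the helper special-cases.

    import torchvision
    from rigl_torch.util import get_W

    model = torchvision.models.resnet18()

    # old-style call still returns just the list of weight tensors
    W = get_W(model)

    # new flag additionally returns the 0/1 linear-layer mask (same length as W)
    W, linear_layers_mask = get_W(model, return_linear_layers_mask=True)
    assert len(W) == len(linear_layers_mask)

    for w, is_linear in zip(W, linear_layers_mask):
        print(tuple(w.shape), 'linear' if is_linear else 'conv')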
