
Commit 159f1e5

Author: Dyllan McCreary

Merge pull request #4 from McCrearyD/accumulated-grad-scoring

Accumulated Gradient Scoring

2 parents 0751532 + 774f194, commit 159f1e5

File tree: 6 files changed (+36, -9 lines)

  rigl_torch/RigL.py
  sagemaker/rigl.ipynb
  setup.py
  tests/test_rigl.py
  train_imagenet_rigl.py
  train_mnist_rigl.py

rigl_torch/RigL.py (26 additions & 2 deletions)

@@ -19,7 +19,16 @@ def __name__(self):
     @torch.no_grad()
     def __call__(self, grad):
         mask = self.scheduler.backward_masks[self.layer]
-        self.dense_grad = grad.clone()
+
+        # only calculate dense_grads when necessary
+        if self.scheduler.check_if_backward_hook_should_accumulate_grad():
+            if self.dense_grad is None:
+                # initialize as all 0s so we can do a rolling average
+                self.dense_grad = torch.zeros_like(grad)
+            self.dense_grad += grad / self.scheduler.grad_accumulation_n
+        else:
+            self.dense_grad = None
+
         return grad * mask


@@ -34,17 +43,19 @@ def _wrapped_step():
 
 class RigLScheduler:
 
-    def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_distribution='uniform', ignore_linear_layers=True, is_already_sparsified=False, delta=100, alpha=0.3, static_topo=False):
+    def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_distribution='uniform', ignore_linear_layers=True, is_already_sparsified=False, delta=100, alpha=0.3, static_topo=False, grad_accumulation_n=1):
         if dense_allocation <= 0 or dense_allocation > 1:
             raise Exception('Dense allocation must be on the interval (0, 1]. Got: %f' % dense_allocation)
 
         self.model = model
         self.optimizer = optimizer
         self.sparsity_distribution = sparsity_distribution
         self.static_topo = static_topo
+        self.grad_accumulation_n = grad_accumulation_n
         self.ignore_linear_layers = ignore_linear_layers
         self.backward_masks = None
 
+        assert self.grad_accumulation_n > 0 and self.grad_accumulation_n < delta
         assert self.sparsity_distribution in ('uniform', )
 
         self.W, self._linear_layers_mask = get_W(model, return_linear_layers_mask=True)

@@ -200,6 +211,19 @@ def apply_mask_to_gradients(self):
 
             w.grad *= mask
 
+
+    def check_if_backward_hook_should_accumulate_grad(self):
+        """
+        Used by the backward hooks. Basically just checks how far away the next rigl step is,
+        if it's within `self.grad_accumulation_n` steps, return True.
+        """
+
+        if self.step >= self.T_end:
+            return False
+
+        steps_til_next_rigl_step = self.delta_T - (self.step % self.delta_T)
+        return steps_til_next_rigl_step <= self.grad_accumulation_n
+
 
     def cosine_annealing(self):
         return self.alpha / 2 * (1 + np.cos((self.step * np.pi) / self.T_end))
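
The core change: instead of cloning the dense gradient on every backward pass, the hook now keeps a rolling average of the dense gradients over the last `grad_accumulation_n` backward passes before each scheduled RigL step, and frees the buffer otherwise. Below is a minimal, self-contained sketch of that behaviour; `should_accumulate` mirrors `check_if_backward_hook_should_accumulate_grad`, with `step`, `delta_T`, and `T_end` passed as plain arguments instead of scheduler attributes. This is an illustration, not the library API.

```python
import torch

def should_accumulate(step, delta_T, T_end, grad_accumulation_n):
    # Mirrors RigLScheduler.check_if_backward_hook_should_accumulate_grad:
    # accumulate only when the next RigL step is at most grad_accumulation_n steps away.
    if step >= T_end:
        return False
    steps_til_next_rigl_step = delta_T - (step % delta_T)
    return steps_til_next_rigl_step <= grad_accumulation_n

delta_T, T_end, n = 100, 1000, 4
dense_grad = None

for step in range(200):
    grad = torch.randn(8, 8)  # stand-in for one layer's dense gradient
    if should_accumulate(step, delta_T, T_end, n):
        if dense_grad is None:
            dense_grad = torch.zeros_like(grad)
        dense_grad += grad / n   # rolling average over the last n gradients
    else:
        dense_grad = None        # free the buffer between RigL steps

# With delta_T=100 and n=4, should_accumulate(...) is True for steps 96-99,
# 196-199, and so on, so by the time a topology update is due, dense_grad holds
# the mean of the last 4 dense gradients rather than a single gradient sample.
```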

sagemaker/rigl.ipynb (1 addition & 2 deletions)

@@ -11,7 +11,6 @@
 "sagemaker_session = sagemaker.Session()\n",
 "\n",
 "bucket = sagemaker_session.default_bucket()\n",
-"prefix = 'sagemaker/rigl'\n",
 "\n",
 "role = sagemaker.get_execution_role()"
 ]
@@ -66,6 +65,7 @@
 " 'static-topo': 0,\n",
 " 'alpha': 0.3,\n",
 " 'delta': 100,\n",
+"# 'grad-accumulation-n': 4, # if using a smaller batch size, this may be useful\n",
 " 'batch-size': 1024,\n",
 " 'lr': 0.1,\n",
 "# 'lr-warmup-end': 5,\n",
@@ -86,7 +86,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# estimator.fit(file_system_input) # train with FSx Lustre as input\n",
 "estimator.fit('s3://imagenet-compressed-oregon') # use imagenet s3 bucket"
 ]
 }
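
For context, the commented-out option above would sit in the notebook's hyperparameter dict roughly as follows if enabled. Only the keys and the other values come from the diff; 4 is just an example value, and the surrounding estimator setup is not shown here.

```python
# Illustrative fragment of the notebook's hyperparameter dict with the new
# option uncommented; only the keys and neighbouring values come from the diff.
hyperparameters = {
    'static-topo': 0,
    'alpha': 0.3,
    'delta': 100,
    'grad-accumulation-n': 4,  # if using a smaller batch size, this may be useful
    'batch-size': 1024,
    'lr': 0.1,
}
```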

setup.py (1 addition & 1 deletion)

@@ -5,7 +5,7 @@
 
 setup(
     name="rigl-torch",
-    version="0.2",
+    version="0.3",
     author="Dyllan McCreary",
     author_email="[email protected]",
     description="Implementation of Google Research's \"RigL\" sparse model training method in PyTorch.",

tests/test_rigl.py (2 additions & 2 deletions)

@@ -15,9 +15,9 @@
 arch = 'resnet50'
 image_dimensionality = (3, 224, 224)
 num_classes = 1000
-max_iters = 6
+max_iters = 15
 T_end = int(max_iters * 0.75)
-delta = 2
+delta = 3
 dense_allocation = 0.1
 criterion = torch.nn.functional.cross_entropy

train_imagenet_rigl.py (3 additions & 1 deletion)

@@ -51,6 +51,8 @@
                     help='percentage of dense parameters allowed. if None, pruning will not be used. must be on the interval (0, 1]')
 parser.add_argument('--delta', default=100, type=int,
                     help='delta param for pruning')
+parser.add_argument('--grad-accumulation-n', default=1, type=int,
+                    help='number of gradients to accumulate before scoring for rigl')
 parser.add_argument('--alpha', default=0.3, type=float,
                     help='alpha param for pruning')
 parser.add_argument('--static-topo', default=0, type=int, help='if 1, use random sparsity topo and remain static')
@@ -291,7 +293,7 @@ def main_worker(gpu, ngpus_per_node, args):
     if args.dense_allocation is not None:
         total_iterations = args.epochs * len(train_loader)
         T_end = int(0.75 * total_iterations)  # (stop tweaking topology after 75% of training)
-        pruner = RigLScheduler(model, optimizer, dense_allocation=args.dense_allocation, T_end=T_end, delta=args.delta, alpha=args.alpha, static_topo=args.static_topo)
+        pruner = RigLScheduler(model, optimizer, dense_allocation=args.dense_allocation, T_end=T_end, delta=args.delta, alpha=args.alpha, static_topo=args.static_topo, grad_accumulation_n=args.grad_accumulation_n)
         print('pruning with dense allocation: %f & T_end=%i' % (args.dense_allocation, T_end))
         print(pruner)
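
The new flag is threaded straight through to the scheduler constructor. Here is a hedged sketch of calling RigLScheduler directly with `grad_accumulation_n`, assuming the import path implied by rigl_torch/RigL.py; the model, optimizer, and iteration counts are illustrative, not taken from the repo.

```python
import torch
from rigl_torch.RigL import RigLScheduler  # assumed import path, based on the file layout

# Illustrative model and optimizer; any torch model the scheduler supports would do.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3),
    torch.nn.Flatten(),
    torch.nn.Linear(16 * 30 * 30, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

total_iterations = 1000               # e.g. epochs * len(train_loader)
T_end = int(0.75 * total_iterations)  # stop tweaking topology after 75% of training
pruner = RigLScheduler(model, optimizer,
                       dense_allocation=0.1,
                       T_end=T_end,
                       delta=100,
                       alpha=0.3,
                       static_topo=0,
                       grad_accumulation_n=4)  # must be > 0 and < delta (asserted in __init__)
```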

train_mnist_rigl.py (3 additions & 1 deletion)

@@ -91,6 +91,8 @@ def main():
                         help='percentage of dense parameters allowed. if None, pruning will not be used. must be on the interval (0, 1]')
     parser.add_argument('--delta', default=100, type=int,
                         help='delta param for pruning')
+    parser.add_argument('--grad-accumulation-n', default=1, type=int,
+                        help='number of gradients to accumulate before scoring for rigl')
     parser.add_argument('--alpha', default=0.3, type=float,
                         help='alpha param for pruning')
     parser.add_argument('--static-topo', default=0, type=int, help='if 1, use random sparsity topo and remain static')
@@ -154,7 +156,7 @@ def main():
     pruner = lambda: True
     if args.dense_allocation is not None:
         T_end = int(0.75 * args.epochs * len(train_loader))
-        pruner = RigLScheduler(model, optimizer, dense_allocation=args.dense_allocation, alpha=args.alpha, delta=args.delta, static_topo=args.static_topo, T_end=T_end, ignore_linear_layers=False)
+        pruner = RigLScheduler(model, optimizer, dense_allocation=args.dense_allocation, alpha=args.alpha, delta=args.delta, static_topo=args.static_topo, T_end=T_end, ignore_linear_layers=False, grad_accumulation_n=args.grad_accumulation_n)
 
     print(model)
     for epoch in range(1, args.epochs + 1):
