Skip to content

Commit 5b2fb86

Browse files
author
McCrearyD
committed
store dense_allocation
1 parent f2ed1e5 commit 5b2fb86

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

rigl_torch/RigL.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_di
5555
# modify optimizer.step() function to call "reset_momentum" after
5656
_create_step_wrapper(self, optimizer)
5757

58+
self.dense_allocation = dense_allocation
5859
self.N = [torch.numel(w) for w in self.W]
5960

6061
if state_dict is not None:
@@ -71,12 +72,16 @@ def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_di
7172
# define sparsity allocation
7273
self.S = []
7374
for i, (W, is_linear) in enumerate(zip(self.W, self._linear_layers_mask)):
74-
if i == 0 and self.sparsity_distribution == 'uniform':
75-
# when using uniform sparsity, the first layer is always 100% dense
75+
# when using uniform sparsity, the first layer is always 100% dense
76+
# UNLESS there is only 1 layer
77+
is_first_layer = i == 0
78+
if is_first_layer and self.sparsity_distribution == 'uniform' and len(self.W) > 1:
7679
self.S.append(0)
80+
7781
elif is_linear and self.ignore_linear_layers:
7882
# if choosing to ignore linear layers, keep them 100% dense
7983
self.S.append(0)
84+
8085
else:
8186
self.S.append(1-dense_allocation)
8287

@@ -115,6 +120,7 @@ def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_di
115120

116121
def state_dict(self):
117122
obj = {
123+
'dense_allocation': self.dense_allocation,
118124
'S': self.S,
119125
'N': self.N,
120126
'hyperparams': {

0 commit comments

Comments (0)