@@ -55,6 +55,7 @@ def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_di
         # modify optimizer.step() function to call "reset_momentum" after
         _create_step_wrapper(self, optimizer)
 
+        self.dense_allocation = dense_allocation
         self.N = [torch.numel(w) for w in self.W]
 
         if state_dict is not None:
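For reference, `self.N` simply caches the element count of every tracked weight tensor. A minimal, self-contained sketch of the same computation on a hypothetical two-layer model:

```python
import torch
import torch.nn as nn

# hypothetical toy model; stands in for whatever `model` is passed to __init__
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
W = [m.weight for m in model if isinstance(m, nn.Linear)]
N = [torch.numel(w) for w in W]
print(N)  # [32, 16] -- element counts of the 8x4 and 2x8 weight matrices
```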
@@ -71,12 +72,16 @@ def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_di
         # define sparsity allocation
         self.S = []
         for i, (W, is_linear) in enumerate(zip(self.W, self._linear_layers_mask)):
-            if i == 0 and self.sparsity_distribution == 'uniform':
-                # when using uniform sparsity, the first layer is always 100% dense
+            # when using uniform sparsity, the first layer is always 100% dense
+            # UNLESS there is only 1 layer
+            is_first_layer = i == 0
+            if is_first_layer and self.sparsity_distribution == 'uniform' and len(self.W) > 1:
                 self.S.append(0)
+
             elif is_linear and self.ignore_linear_layers:
                 # if choosing to ignore linear layers, keep them 100% dense
                 self.S.append(0)
+
             else:
                 self.S.append(1 - dense_allocation)
 
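To see what the revised branch changes, here is a standalone sketch of the allocation loop; the function name and the example layer lists are hypothetical stand-ins for `self.W`, `self._linear_layers_mask`, and the scheduler's flags:

```python
def sparsity_allocation(linear_mask, dense_allocation,
                        sparsity_distribution='uniform',
                        ignore_linear_layers=False):
    # mirrors the loop above: one sparsity value in [0, 1) per layer
    S = []
    for i, is_linear in enumerate(linear_mask):
        is_first_layer = i == 0
        if is_first_layer and sparsity_distribution == 'uniform' and len(linear_mask) > 1:
            # keep the first layer dense -- unless it is the only layer
            S.append(0)
        elif is_linear and ignore_linear_layers:
            S.append(0)
        else:
            S.append(1 - dense_allocation)
    return S

print(sparsity_allocation([False, False, True], dense_allocation=0.1))
# [0, 0.9, 0.9] -> dense first layer, 90% sparse elsewhere
print(sparsity_allocation([False], dense_allocation=0.1))
# [0.9] -> a single-layer model is no longer forced to stay dense (the fix)
```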
@@ -115,6 +120,7 @@ def __init__(self, model, optimizer, dense_allocation=1, T_end=None, sparsity_di
 
     def state_dict(self):
         obj = {
+            'dense_allocation': self.dense_allocation,
             'S': self.S,
             'N': self.N,
             'hyperparams': {
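Persisting `dense_allocation` lets a checkpoint carry the target density alongside `S` and `N`. A hedged round-trip sketch, assuming this repo's `RigLScheduler` constructor (whose signature appears in the hunk headers above) and a hypothetical `T_end` value:

```python
# save: the scheduler's state_dict now includes 'dense_allocation'
scheduler = RigLScheduler(model, optimizer, dense_allocation=0.1, T_end=T_end)
torch.save({'rigl': scheduler.state_dict()}, 'checkpoint.pt')

# resume: the stored value survives the round trip and can be re-validated
ckpt = torch.load('checkpoint.pt')
assert ckpt['rigl']['dense_allocation'] == 0.1
scheduler = RigLScheduler(model, optimizer, dense_allocation=0.1, T_end=T_end,
                          state_dict=ckpt['rigl'])
```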