from torch import nn

from layer_to_layer_pytorch.helpers import enumerator, iterator, zipper
+from layer_to_layer_pytorch.loss import L2LLoss
from layer_to_layer_pytorch.types import Device, LossFn, TensorOrTensorArray

@@ -54,7 +55,7 @@ def zero_grad(self) -> None:

    @torch.no_grad()
    def forward(self, batch: torch.Tensor, **kwargs) -> torch.Tensor:
-        layers: nn.ModuleList = getattr(self.main_model, self.layers_attr)
+        layers: nn.ModuleList = self._get_layers()

        # layer by layer forward pass. only activations are stored
        for idx, l in enumerator(
@@ -72,13 +73,7 @@ def forward(self, batch: torch.Tensor, **kwargs) -> torch.Tensor:
            else:
                input = self._activations[idx - 1]

-            # forward with microbatching
-            batch_size = input.shape[0]
-            microbatch_size = (
-                batch_size
-                if self.microbatch_size is None
-                else self.microbatch_size
-            )
+            microbatch_size = self._get_microbatch_size(input)
            num_steps: int = input.shape[0] // microbatch_size

            for microbatch in iterator(
@@ -103,12 +98,13 @@ def calculate_gradients(
        target: torch.Tensor,
        loss_fn: LossFn,
        loss_kwargs: dict = None,
+        skip_last_layer: bool = False,
        **forward_kwargs,
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
        if loss_kwargs is None:
            loss_kwargs = {}
        # layer by layer backward pass (in reverse order)
-        layers: nn.ModuleList = getattr(self.main_model, self.layers_attr)
+        layers: nn.ModuleList = self._get_layers()
        losses: List[torch.Tensor] = []
        num_steps_in_loss: int = 1

@@ -122,26 +118,38 @@ def calculate_gradients(
            layer = copy.deepcopy(l).to(self.gpu_device)
            f_idx: int = self.num_layers - idx - 1

+            if idx == 0 and skip_last_layer:
+                microbatch_size = self._get_microbatch_size(
+                    self._activations[f_idx]
+                )
+                num_steps: int = (
+                    self._activations[f_idx].shape[0] // microbatch_size
+                )
+                self._copy_grad_to_main_model(
+                    idx,
+                    num_steps,
+                    local_params=layer.parameters(),
+                    main_params=layers[f_idx].parameters(),
+                )
+                continue
+
            for param in layer.parameters():
                param.requires_grad = True

            input: torch.Tensor
            output: torch.Tensor

-            if f_idx == 0:
+            if idx == 0:  # last layer
+                input = self._activations[f_idx]
+                output = target
+            elif f_idx == 0:  # first layer
                input = batch
-                output = self._grads[idx - 1]
-            else:
+                output = self._activations[f_idx]
+            else:  # any other layer
                input = self._activations[f_idx - 1]
-                output = target
-
-            batch_size = input.shape[0]
-            microbatch_size = (
-                batch_size
-                if self.microbatch_size is None
-                else self.microbatch_size
-            )
+                output = self._activations[f_idx]

+            microbatch_size = self._get_microbatch_size(input)
            num_steps: int = input.shape[0] // microbatch_size
            if idx == 0:
                num_steps_in_loss = num_steps
@@ -160,7 +168,12 @@ def calculate_gradients(

                microtarget = microtarget.to(self.gpu_device)

-                activation: torch.Tensor = layer(microbatch, **forward_kwargs)
+                if idx == 0:
+                    activation = microbatch
+                else:
+                    activation: torch.Tensor = layer(
+                        microbatch, **forward_kwargs
+                    )

                if idx == 0:
                    loss = loss_fn(activation, microtarget, **loss_kwargs)
@@ -172,27 +185,58 @@ def calculate_gradients(
                    activation.backward(microtarget)
                self._grads[idx].append(microbatch.grad.cpu())

-            for local_param, main_param in zip(
-                layer.parameters(), layers[f_idx].parameters()
-            ):
-                if main_param.grad is None:
-                    main_param.grad = local_param.grad.cpu() / num_steps
-                else:
-                    main_param.grad += local_param.grad.cpu() / num_steps
+            self._copy_grad_to_main_model(
+                idx,
+                num_steps,
+                local_params=layer.parameters(),
+                main_params=layers[f_idx].parameters(),
+            )

+        self._grads = list(reversed(self._grads))
+
+        if not skip_last_layer:
            with torch.no_grad():
-                self._grads[idx] = (
-                    torch.cat(self._grads[idx], dim=0).cpu() / num_steps
-                )
+                loss_value = torch.tensor(np.sum(losses) / num_steps_in_loss)

-        self._grads = list(reversed(self._grads))
-        with torch.no_grad():
-            loss_value = torch.tensor(np.sum(losses) / num_steps_in_loss)
+            return loss_value

-        return loss_value
+        return None

    def __call__(self, batch: torch.Tensor) -> torch.Tensor:
        return self.forward(batch)

+    def _get_microbatch_size(self, batch: torch.Tensor) -> int:
+        batch_size = batch.shape[0]
+        return (
+            batch_size if self.microbatch_size is None else self.microbatch_size
+        )
+
+    def _get_layers(self) -> nn.ModuleList:
+        return getattr(self.main_model, self.layers_attr)
+
+    def _copy_grad_to_main_model(
+        self, idx: int, num_steps: int, local_params, main_params
+    ):
+        for local_param, main_param in zip(local_params, main_params):
+            if main_param.grad is None:
+                main_param.grad = local_param.grad.cpu() / num_steps
+            else:
+                main_param.grad += local_param.grad.cpu() / num_steps
+
+        with torch.no_grad():
+            self._grads[idx] = (
+                torch.cat(self._grads[idx], dim=0).cpu() / num_steps
+            )
+
+    def l2l_loss(
+        self, loss_fn: LossFn, store_grad_on_calc: bool = True, **forward_kwargs
+    ) -> L2LLoss:
+        return L2LLoss(
+            model=self,
+            loss_fn=loss_fn,
+            store_grad_on_calc=store_grad_on_calc,
+            **forward_kwargs,
+        )
+


__all__ = ["Layer2Layer"]
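For context, a minimal usage sketch of the refactored pieces follows. Only the names taken from the diff above (Layer2Layer, forward, calculate_gradients, l2l_loss, layers_attr, microbatch_size) are grounded; TinyModel, the constructor arguments shown in the comments, and the mse_loss choice are assumptions for illustration, not the library's documented API.

import torch
from torch import nn

# Hypothetical model whose sub-layers live in a `layers` attribute, which is
# what `_get_layers()` resolves through `layers_attr`.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Microbatch arithmetic mirroring `_get_microbatch_size` / `num_steps`:
# a batch of 64 with microbatch_size=16 is processed in 4 steps; with
# microbatch_size=None the whole batch is a single step.
batch = torch.randn(64, 16)
target = torch.randn(64, 1)
microbatch_size = 16
num_steps = batch.shape[0] // microbatch_size  # -> 4

# Assumed wrapper usage (constructor arguments are not shown in this diff):
# l2l = Layer2Layer(TinyModel(), layers_attr="layers", microbatch_size=16)
# out = l2l.forward(batch)   # layer-by-layer forward; activations are stored
# loss = l2l.calculate_gradients(batch, target, loss_fn=nn.functional.mse_loss)
# criterion = l2l.l2l_loss(nn.functional.mse_loss)  # returns the new L2LLoss helper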
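A small self-contained check (not from the repository) illustrating why `_copy_grad_to_main_model` divides each local gradient by `num_steps`: for a mean-reduced loss over equal-size microbatches, averaging the per-microbatch gradients reproduces the full-batch gradient.

import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Full-batch gradient of a mean-reduced MSE loss.
layer.zero_grad()
nn.functional.mse_loss(layer(x), y).backward()
full_grad = layer.weight.grad.clone()

# Accumulate microbatch gradients, each scaled by 1 / num_steps.
num_steps = 2
acc = torch.zeros_like(full_grad)
for xb, yb in zip(x.chunk(num_steps), y.chunk(num_steps)):
    layer.zero_grad()
    nn.functional.mse_loss(layer(xb), yb).backward()
    acc += layer.weight.grad / num_steps

# The averaged microbatch gradients match the full-batch gradient.
assert torch.allclose(acc, full_grad, atol=1e-5)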