File tree 1 file changed +8
-6
lines changed
1 file changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -119,7 +119,9 @@ def backward(
119
119
total = self .num_layers ,
120
120
leave = False ,
121
121
):
122
- layer = copy .deepcopy (l ).to (self .gpu_device )
122
+ layer : nn .Module = copy .deepcopy (l ).to (self .gpu_device )
123
+ for param in layer .parameters ():
124
+ param .grad = None
123
125
f_idx : int = self .num_layers - idx - 1
124
126
125
127
# TODO: preserve re-calculations
@@ -194,11 +196,11 @@ def backward(
194
196
195
197
self ._grads [idx ].append (microbatch .grad .cpu ())
196
198
197
- self ._copy_grad_to_main_model (
198
- num_steps ,
199
- local_params = layer .parameters (),
200
- main_params = layers [f_idx ].parameters (),
201
- )
199
+ self ._copy_grad_to_main_model (
200
+ num_steps ,
201
+ local_params = layer .parameters (),
202
+ main_params = layers [f_idx ].parameters (),
203
+ )
202
204
203
205
with torch .no_grad ():
204
206
self ._grads [idx ] = (
You can’t perform that action at this time.
0 commit comments