Commit 582bb33

🔧 fix multiple grad copy
1 parent 0bf37e0 commit 582bb33

1 file changed: layer_to_layer_pytorch/l2l.py (+8 -6 lines)

```diff
@@ -119,7 +119,9 @@ def backward(
             total=self.num_layers,
             leave=False,
         ):
-            layer = copy.deepcopy(l).to(self.gpu_device)
+            layer: nn.Module = copy.deepcopy(l).to(self.gpu_device)
+            for param in layer.parameters():
+                param.grad = None
             f_idx: int = self.num_layers - idx - 1
 
             # TODO: preserve re-calculations
```
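The new `for param in layer.parameters(): param.grad = None` loop makes sure the deep-copied GPU layer starts with no gradient state before the micro-batch loop accumulates into it. Because `backward()` adds to `.grad` rather than replacing it, any gradient already sitting on the copy would silently end up in the values later copied back to the main model. A standalone sketch of that accumulation behaviour (illustrative only, not code from this repo):

```python
# Standalone sketch: backward() accumulates into .grad instead of
# overwriting it, so a working copy of a layer should start from
# grad=None before micro-batch gradients are accumulated on it.
import torch
import torch.nn as nn

layer = nn.Linear(4, 1)
x = torch.randn(8, 4)

layer(x).sum().backward()
first = layer.weight.grad.clone()

layer(x).sum().backward()                        # adds on top of the old grad
assert torch.allclose(layer.weight.grad, 2 * first)

for p in layer.parameters():                     # the reset added in this commit
    p.grad = None

layer(x).sum().backward()
assert torch.allclose(layer.weight.grad, first)  # clean accumulation again
```

In recent PyTorch releases the same reset can also be written as `layer.zero_grad(set_to_none=True)`.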
```diff
@@ -194,11 +196,11 @@ def backward(
 
                 self._grads[idx].append(microbatch.grad.cpu())
 
-                self._copy_grad_to_main_model(
-                    num_steps,
-                    local_params=layer.parameters(),
-                    main_params=layers[f_idx].parameters(),
-                )
+            self._copy_grad_to_main_model(
+                num_steps,
+                local_params=layer.parameters(),
+                main_params=layers[f_idx].parameters(),
+            )
 
             with torch.no_grad():
                 self._grads[idx] = (
```
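The second hunk changes only indentation: the `_copy_grad_to_main_model` call moves out of the inner micro-batch loop, so gradients are transferred to the main model once per layer, after all micro-batches have been accumulated, rather than once per micro-batch. This is the "multiple grad copy" the commit title refers to. The helper itself is not shown in this diff; the sketch below is a hypothetical reimplementation (the function body, the averaging by `num_steps`, and the CPU transfer are assumptions) of what such a copy step plausibly does:

```python
# Hypothetical stand-in for the repo's _copy_grad_to_main_model;
# the real helper in l2l.py may differ.
from typing import Iterable

import torch
import torch.nn as nn


def copy_grad_to_main_model(
    num_steps: int,
    local_params: Iterable[nn.Parameter],
    main_params: Iterable[nn.Parameter],
) -> None:
    """Add the local (GPU) layer's accumulated gradients onto the main
    (CPU) layer's parameters, averaged over the micro-batch steps."""
    with torch.no_grad():
        for local_p, main_p in zip(local_params, main_params):
            if local_p.grad is None:
                continue
            grad = local_p.grad.detach().cpu() / num_steps
            if main_p.grad is None:
                main_p.grad = grad
            else:
                main_p.grad += grad
```

With k micro-batches, the old placement inside the loop would add micro-batch 1's partial gradient k times, micro-batch 2's k - 1 times, and so on; after this commit each layer's accumulated gradient reaches the main parameters exactly once.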
