Skip to content

Commit 0dc6691

Browse files
committed
🔧 Update
1 parent 582bb33 commit 0dc6691

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

‎layer_to_layer_pytorch/l2l.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def zero_grad(self) -> None:
5757

5858
self._reset_activations()
5959

60+
def _zero_layer_grad(self, layer: nn.Module) -> None:
61+
for param in layer.parameters():
62+
param.grad = None
63+
6064
@torch.no_grad()
6165
def forward(self, batch: torch.Tensor, **kwargs) -> torch.Tensor:
6266
layers: nn.ModuleList = self._get_layers()
@@ -119,9 +123,8 @@ def backward(
119123
total=self.num_layers,
120124
leave=False,
121125
):
126+
self._zero_layer_grad(l)
122127
layer: nn.Module = copy.deepcopy(l).to(self.gpu_device)
123-
for param in layer.parameters():
124-
param.grad = None
125128
f_idx: int = self.num_layers - idx - 1
126129

127130
# TODO: preserve re-calculations

0 commit comments

Comments
 (0)