We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 582bb33 commit 0dc6691Copy full SHA for 0dc6691
‎layer_to_layer_pytorch/l2l.py
@@ -57,6 +57,10 @@ def zero_grad(self) -> None:
57
58
self._reset_activations()
59
60
+ def _zero_layer_grad(self, layer: nn.Module) -> None:
61
+ for param in layer.parameters():
62
+ param.grad = None
63
+
64
@torch.no_grad()
65
def forward(self, batch: torch.Tensor, **kwargs) -> torch.Tensor:
66
layers: nn.ModuleList = self._get_layers()
@@ -119,9 +123,8 @@ def backward(
119
123
total=self.num_layers,
120
124
leave=False,
121
125
):
126
+ self._zero_layer_grad(l)
122
127
layer: nn.Module = copy.deepcopy(l).to(self.gpu_device)
- for param in layer.parameters():
- param.grad = None
128
f_idx: int = self.num_layers - idx - 1
129
130
# TODO: preserve re-calculations
0 commit comments