Skip to content

Commit 605deb8

Browse files
forklady42, Claude, and pre-commit-ci[bot]
authored
Add gradient checkpointing (#33)
Adds gradient checkpointing, which trades compute for memory, to residual blocks and upsampling layers. --------- Co-authored-by: Claude <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a84076f commit 605deb8

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

src/electrai/lightning.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(self, cfg):
1818
K1=int(cfg.kernel_size1),
1919
K2=int(cfg.kernel_size2),
2020
normalize=cfg.normalize,
21+
use_checkpoint=getattr(cfg, "use_checkpoint", True),
2122
)
2223
self.loss_fn = NormMAE()
2324

src/electrai/model/srgan_layernorm_pbc.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66

77
import torch
88
import torch.nn as nn
9+
from torch.utils.checkpoint import checkpoint
910

1011

1112
class ResidualBlock(nn.Module):
12-
def __init__(self, in_features, K=3):
13+
def __init__(self, in_features, K=3, use_checkpoint=True):
1314
super().__init__()
15+
self.use_checkpoint = use_checkpoint
1416
self.conv_block = nn.Sequential(
1517
nn.Conv3d(
1618
in_features,
@@ -34,7 +36,11 @@ def __init__(self, in_features, K=3):
3436
)
3537

3638
def forward(self, x):
37-
return x + self.conv_block(x)
39+
if self.use_checkpoint and self.training:
40+
# Use gradient checkpointing to save memory during training
41+
return x + checkpoint(self.conv_block, x, use_reentrant=False)
42+
else:
43+
return x + self.conv_block(x)
3844

3945

4046
class PixelShuffle3d(nn.Module):
@@ -67,16 +73,19 @@ def __init__(
6773
K1=5,
6874
K2=3,
6975
normalize=True,
76+
use_checkpoint=True,
7077
):
7178
"""
7279
This net upscales each axis by 2**n_upscale_layers
7380
C = channel size in most of layers
7481
K1 = kernel size in the first and last layers
7582
K2 = kernel size in Res blocks
83+
use_checkpoint = enable gradient checkpointing to save memory
7684
"""
7785
super().__init__()
7886
self.n_upscale_layers = n_upscale_layers
7987
self.normalize = normalize
88+
self.use_checkpoint = use_checkpoint
8089

8190
# First layer
8291
self.conv1 = nn.Sequential(
@@ -92,7 +101,10 @@ def __init__(
92101
)
93102

94103
# Residual blocks
95-
res_blocks = [ResidualBlock(C, K=K2) for _ in range(n_residual_blocks)]
104+
res_blocks = [
105+
ResidualBlock(C, K=K2, use_checkpoint=use_checkpoint)
106+
for _ in range(n_residual_blocks)
107+
]
96108
self.res_blocks = nn.Sequential(*res_blocks)
97109

98110
# Second conv layer post residual blocks

0 commit comments

Comments (0)