Skip to content

Commit 90211c4

Browse files
committed
add checkpointing to the transformer trunk, to save memory when finetuning
1 parent 18614f7 commit 90211c4

2 files changed

Lines changed: 24 additions & 4 deletions

File tree

enformer_pytorch/enformer_pytorch.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
import torch
33
from torch import nn, einsum
4+
from torch.utils.checkpoint import checkpoint_sequential
45
from einops import rearrange, reduce
56
from einops.layers.torch import Rearrange
67

@@ -262,7 +263,8 @@ def __init__(
262263
attn_dim_key = 64,
263264
dropout_rate = 0.4,
264265
attn_dropout = 0.05,
265-
pos_dropout = 0.01
266+
pos_dropout = 0.01,
267+
use_checkpointing = False
266268
):
267269
super().__init__()
268270
self.dim = dim
@@ -359,6 +361,10 @@ def __init__(
359361
nn.Softplus()
360362
), output_heads))
361363

364+
# use checkpointing on transformer trunk
365+
366+
self.use_checkpointing = use_checkpointing
367+
362368
def set_target_length(self, target_length):
363369
crop_module = self._trunk[-2]
364370
crop_module.target_length = target_length
@@ -370,7 +376,20 @@ def trunk(self):
370376
@property
371377
def heads(self):
372378
return self._heads
373-
379+
380+
def trunk_checkpointed(self, x):
381+
x = self.stem(x)
382+
x = self.conv_tower(x)
383+
x = self.transformer[0](x)
384+
385+
# todo: move the rearrange out of the self.transformer sequential module, transfer all weights to the rearranged module, then checkpoint directly on self.transformer
386+
transformer_blocks = self.transformer[1:]
387+
x = checkpoint_sequential(nn.Sequential(*transformer_blocks), len(transformer_blocks), x)
388+
389+
x = self.crop_final(x)
390+
x = self.final_pointwise(x)
391+
return x
392+
374393
def forward(
375394
self,
376395
x,
@@ -390,7 +409,8 @@ def forward(
390409
if no_batch:
391410
x = rearrange(x, '... -> () ...')
392411

393-
x = self._trunk(x)
412+
trunk_fn = self.trunk_checkpointed if self.use_checkpointing else self._trunk
413+
x = trunk_fn(x)
394414

395415
if no_batch:
396416
x = rearrange(x, '() ... -> ...')

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.1.20',
7+
version = '0.1.21',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)