11import math
22import torch
33from torch import nn , einsum
4+ from torch .utils .checkpoint import checkpoint_sequential
45from einops import rearrange , reduce
56from einops .layers .torch import Rearrange
67
@@ -262,7 +263,8 @@ def __init__(
262263 attn_dim_key = 64 ,
263264 dropout_rate = 0.4 ,
264265 attn_dropout = 0.05 ,
265- pos_dropout = 0.01
266+ pos_dropout = 0.01 ,
267+ use_checkpointing = False
266268 ):
267269 super ().__init__ ()
268270 self .dim = dim
@@ -359,6 +361,10 @@ def __init__(
359361 nn .Softplus ()
360362 ), output_heads ))
361363
364+ # use checkpointing on transformer trunk
365+
366+ self .use_checkpointing = use_checkpointing
367+
362368 def set_target_length (self , target_length ):
363369 crop_module = self ._trunk [- 2 ]
364370 crop_module .target_length = target_length
@@ -370,7 +376,20 @@ def trunk(self):
@property
def heads(self):
    """Return the object stored on this instance as ``self._heads``."""
    output_heads = self._heads
    return output_heads
373-
379+
def trunk_checkpointed(self, x):
    """Run the trunk on ``x`` with activation checkpointing on the transformer.

    Applies, in order: ``self.stem``, ``self.conv_tower``, the first entry
    of ``self.transformer``, the remaining entries of ``self.transformer``
    (checkpointed, one segment per entry), ``self.crop_final`` and
    ``self.final_pointwise``.

    Args:
        x: Input passed to ``self.stem``.

    Returns:
        The output of ``self.final_pointwise``.
    """
    x = self.stem(x)
    x = self.conv_tower(x)

    # The first entry of ``self.transformer`` is applied outside the
    # checkpointed region (per the TODO below it is the rearrange layer).
    x = self.transformer[0](x)

    # TODO: move the rearrange out of the ``self.transformer`` sequential
    # module, transfer all weights to the new module arrangement, and
    # checkpoint directly on ``self.transformer``.
    transformer_blocks = self.transformer[1:]
    num_blocks = len(transformer_blocks)

    # ``checkpoint_sequential`` cannot split the work into zero segments,
    # so skip it entirely when there is nothing after the first entry.
    if num_blocks > 0:
        # The slice is passed as-is; no need to re-wrap it in a fresh
        # ``nn.Sequential`` on every forward pass. Each block is its own
        # checkpoint segment.
        x = checkpoint_sequential(transformer_blocks, num_blocks, x)

    x = self.crop_final(x)
    x = self.final_pointwise(x)
    return x
392+
374393 def forward (
375394 self ,
376395 x ,
@@ -390,7 +409,8 @@ def forward(
390409 if no_batch :
391410 x = rearrange (x , '... -> () ...' )
392411
393- x = self ._trunk (x )
412+ trunk_fn = self .trunk_checkpointed if self .use_checkpointing else self ._trunk
413+ x = trunk_fn (x )
394414
395415 if no_batch :
396416 x = rearrange (x , '() ... -> ...' )
0 commit comments