@@ -16,6 +16,7 @@
 import os
 import pickle
 import sys
+import contextlib
 import logging
 from datetime import datetime as dt
 import numpy as np
@@ -32,6 +33,11 @@
 from cryodrgn.models import HetOnlyVAE, unparallelize
 from cryodrgn.pose_search import PoseSearch
 
+try:
+    import apex.amp as amp  # type: ignore  # PYR01
+except ImportError:
+    pass
+
 logger = logging.getLogger(__name__)
 
 
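The guarded import above makes apex an optional dependency: if apex is missing, the name `amp` is simply never bound, and the `try/except` around `amp.initialize` in `main` (below) falls back to PyTorch's native AMP. A minimal sketch of the same optional-import pattern; the `HAS_APEX` flag is illustrative and not part of this change:

```python
# Sketch: optional-dependency import (HAS_APEX is hypothetical, not in this diff)
try:
    import apex.amp as amp  # NVIDIA apex, if installed
    HAS_APEX = True
except ImportError:
    HAS_APEX = False  # fall back to torch.cuda.amp later
```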
@@ -223,6 +229,12 @@ def add_args(parser):
         type=int,
         help="If set, reset the optimizer every N epochs",
     )
+    group.add_argument(
+        "--no-amp",
+        action="store_false",
+        dest="amp",
+        help="Do not use mixed-precision training to accelerate training",
+    )
     group.add_argument(
         "--multigpu",
         action="store_true",
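With `action="store_false"` and `dest="amp"`, argparse stores `args.amp = True` by default and flips it to `False` only when `--no-amp` is passed, so mixed precision is enabled unless explicitly disabled. A standalone sketch of this default-on flag pattern:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no-amp", action="store_false", dest="amp",
    help="Disable mixed-precision training",
)
print(parser.parse_args([]).amp)            # True  (AMP on by default)
print(parser.parse_args(["--no-amp"]).amp)  # False
```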
@@ -451,6 +463,8 @@ def train(
     enc_only=False,
     poses=None,
     ctf_params=None,
+    use_amp=False,
+    scaler=None,
 ):
     y, yt = minibatch
     use_tilt = yt is not None
@@ -470,87 +484,104 @@
     # TODO: Center image?
     # We do this in pose-supervised train_vae
 
-    # VAE inference of z
-    model.train()
-    optim.zero_grad()
-    input_ = (y, yt) if use_tilt else (y,)
-    if ctf_i is not None:
-        input_ = (x * ctf_i.sign() for x in input_)  # phase flip by the ctf
-
-    _model = unparallelize(model)
-    assert isinstance(_model, HetOnlyVAE)
-    z_mu, z_logvar = _model.encode(*input_)
-    z = _model.reparameterize(z_mu, z_logvar)
-
-    lamb = eq_loss = None
-    if equivariance is not None:
-        lamb, equivariance_loss = equivariance
-        eq_loss = equivariance_loss(y, z_mu)
-
-    # pose inference
-    if poses is not None:  # use provided poses
-        rot = poses[0]
-        trans = poses[1]
-    else:  # pose search
-        model.eval()
-        with torch.no_grad():
-            rot, trans, _base_pose = ps.opt_theta_trans(
-                y,
-                z=z,
-                images_tilt=None if enc_only else yt,
-                ctf_i=ctf_i,
-            )
-        model.train()
-
-    # reconstruct circle of pixels instead of whole image
-    mask = lattice.get_circular_mask(L)
+    if scaler is not None:
+        amp_mode = torch.cuda.amp.autocast()
+    else:
+        amp_mode = contextlib.nullcontext()
 
-    def gen_slice(R):
-        slice_ = model(lattice.coords[mask] @ R, z).view(B, -1)
+    with amp_mode:
+        # VAE inference of z
+        model.train()
+        optim.zero_grad()
+        input_ = (y, yt) if use_tilt else (y,)
         if ctf_i is not None:
-            slice_ *= ctf_i.view(B, -1)[:, mask]
-        return slice_
+            input_ = (x * ctf_i.sign() for x in input_)  # phase flip by the ctf
 
-    def translate(img):
-        img = lattice.translate_ht(img, trans.unsqueeze(1), mask)
-        return img.view(B, -1)
+        _model = unparallelize(model)
+        assert isinstance(_model, HetOnlyVAE)
+        z_mu, z_logvar = _model.encode(*input_)
+        z = _model.reparameterize(z_mu, z_logvar)
+
+        lamb = eq_loss = None
+        if equivariance is not None:
+            lamb, equivariance_loss = equivariance
+            eq_loss = equivariance_loss(y, z_mu)
+
+        # pose inference
+        if poses is not None:  # use provided poses
+            rot = poses[0]
+            trans = poses[1]
+        else:  # pose search
+            model.eval()
+            with torch.no_grad():
+                rot, trans, _base_pose = ps.opt_theta_trans(
+                    y,
+                    z=z,
+                    images_tilt=None if enc_only else yt,
+                    ctf_i=ctf_i,
+                )
+            model.train()
 
-    y = y.view(B, -1)[:, mask]
-    if use_tilt:
-        yt = yt.view(B, -1)[:, mask]
-    y = translate(y)
-    if use_tilt:
-        yt = translate(yt)
+        # reconstruct circle of pixels instead of whole image
+        mask = lattice.get_circular_mask(L)
 
-    if use_tilt:
-        gen_loss = 0.5 * F.mse_loss(gen_slice(rot), y) + 0.5 * F.mse_loss(
-            gen_slice(bnb.tilt @ rot), yt  # type: ignore  # noqa: F821
-        )
-    else:
-        gen_loss = F.mse_loss(gen_slice(rot), y)
+        def gen_slice(R):
+            slice_ = model(lattice.coords[mask] @ R, z).view(B, -1)
+            if ctf_i is not None:
+                slice_ *= ctf_i.view(B, -1)[:, mask]
+            return slice_
 
-    # latent loss
-    kld = torch.mean(
-        -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
-    )
-    if torch.isnan(kld):
-        logger.info(z_mu[0])
-        logger.info(z_logvar[0])
-        raise RuntimeError("KLD is nan")
+        def translate(img):
+            img = lattice.translate_ht(img, trans.unsqueeze(1), mask)
+            return img.view(B, -1)
 
-    if beta_control is None:
-        loss = gen_loss + beta * kld / mask.sum().float()
-    else:
-        loss = gen_loss + beta_control * (beta - kld) ** 2 / mask.sum().float()
+        y = y.view(B, -1)[:, mask]
+        if use_tilt:
+            yt = yt.view(B, -1)[:, mask]
+        y = translate(y)
+        if use_tilt:
+            yt = translate(yt)
 
-    if loss is not None and eq_loss is not None:
-        loss += lamb * eq_loss
+        if use_tilt:
+            gen_loss = 0.5 * F.mse_loss(gen_slice(rot), y) + 0.5 * F.mse_loss(
+                gen_slice(bnb.tilt @ rot), yt  # type: ignore  # noqa: F821
+            )
+        else:
+            gen_loss = F.mse_loss(gen_slice(rot), y)
 
-    loss.backward()
+        # latent loss
+        kld = torch.mean(
+            -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
+        )
+        if torch.isnan(kld):
+            logger.info(z_mu[0])
+            logger.info(z_logvar[0])
+            raise RuntimeError("KLD is nan")
+
+        if beta_control is None:
+            loss = gen_loss + beta * kld / mask.sum().float()
+        else:
+            loss = gen_loss + beta_control * (beta - kld) ** 2 / mask.sum().float()
+
+        if loss is not None and eq_loss is not None:
+            loss += lamb * eq_loss
+
+    if use_amp:
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optim)
+            scaler.update()
+        else:  # apex.amp mixed precision
+            with amp.scale_loss(loss, optim) as scaled_loss:
+                scaled_loss.backward()
+            optim.step()
+    else:
+        loss.backward()
+        optim.step()
 
-    optim.step()
     save_pose = [rot.detach().cpu().numpy()]
     save_pose.append(trans.detach().cpu().numpy())
+
     return (
         gen_loss.item(),
         kld.item(),
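For reference, the native-PyTorch branch added above follows the standard `autocast` + `GradScaler` recipe: run the forward pass and loss under `autocast`, then scale the loss before `backward()` so small fp16 gradients don't underflow. A minimal, self-contained sketch of that recipe (the model, batch, and loss here are placeholders, not cryodrgn code; requires a CUDA device):

```python
import torch

model = torch.nn.Linear(8, 8).cuda()  # placeholder model
optim = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 8, device="cuda")  # placeholder batch
with torch.cuda.amp.autocast():       # forward pass in mixed precision
    loss = (model(x) - x).pow(2).mean()
scaler.scale(loss).backward()         # scaled backward avoids fp16 underflow
scaler.step(optim)                    # unscales grads; skips step on inf/nan
scaler.update()                       # adapts the scale factor for next step
optim.zero_grad()
```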
@@ -833,6 +864,51 @@ def main(args):
 
     optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
 
+    # Mixed precision training
+    scaler = None
+    if args.amp:
+        if args.batch_size % 8 != 0:
+            logger.warning(
+                f"Batch size {args.batch_size} not divisible by 8 "
+                f"and thus not optimal for AMP training!"
+            )
+        if (D - 1) % 8 != 0:
+            logger.warning(
+                f"Image size {D - 1} not divisible by 8 "
+                f"and thus not optimal for AMP training!"
+            )
+
+        if args.pdim % 8 != 0:
+            logger.warning(
+                f"Decoder hidden layer dimension {args.pdim} not divisible by 8 "
+                f"and thus not optimal for AMP training!"
+            )
+
+        # also check e.g. enc_mask dim?
+        if args.qdim % 8 != 0:
+            logger.warning(
+                f"Encoder hidden layer dimension {args.qdim} not divisible by 8 "
+                f"and thus not optimal for AMP training!"
+            )
+
+        if args.zdim % 8 != 0:
+            logger.warning(
+                f"Z dimension {args.zdim} is not a multiple of 8 "
+                "-- AMP training speedup is not optimized!"
+            )
+        if in_dim % 8 != 0:
+            logger.warning(
+                f"Masked input image dimension {in_dim} is not a multiple of 8 "
+                "-- AMP training speedup is not optimized!"
+            )
+
+        # mixed precision with apex.amp
+        try:
+            model, optim = amp.initialize(model, optim, opt_level="O1")
+        # mixed precision with pytorch (v1.6+)
+        except:  # noqa: E722
+            scaler = torch.cuda.amp.GradScaler()
+
     if args.load == "latest":
         args = get_latest(args)
 
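The multiple-of-8 warnings above reflect the Tensor Core alignment rule: fp16 matrix multiplies hit the fast path only when the participating GEMM dimensions are multiples of 8, so misaligned batch, image, or layer sizes still train correctly but lose much of the AMP speedup. A small hypothetical helper that factors out the repeated check (`warn_if_unaligned` is illustrative only, not part of this diff; `args`, `D`, and `in_dim` are assumed from the surrounding function):

```python
import logging

logger = logging.getLogger(__name__)

def warn_if_unaligned(name: str, dim: int, multiple: int = 8) -> None:
    """Warn when a dimension will miss the Tensor Core fast path under AMP."""
    if dim % multiple != 0:
        logger.warning(
            f"{name} = {dim} is not a multiple of {multiple} "
            "-- AMP training speedup is not optimized!"
        )

# usage mirroring the checks above:
# warn_if_unaligned("Batch size", args.batch_size)
# warn_if_unaligned("Image size", D - 1)
```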
@@ -1007,6 +1083,8 @@ def main(args):
                 enc_only=args.enc_only,
                 poses=p,
                 ctf_params=ctf_i,
+                use_amp=args.amp,
+                scaler=scaler,
             )
             # logging
             poses.append((ind.cpu().numpy(), pose))