Skip to content

Commit 196365d

Browse files
authored
Merge pull request #419 from ml-struct-bio/v3.4.2
v3.4.2: AMP for ab-initio reconstruction; faster pose parsing
2 parents 5bf68c5 + 25948f7 commit 196365d

15 files changed

+541
-279
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ For any feedback, questions, or bugs, please file a Github issue or start a Gith
2121

2222
### New in Version 3.4.x
2323
* [NEW] `cryodrgn plot_classes` for analysis visualizations colored by a given set of class labels
24-
* support for RELION 3.1 .star files with separate optics tables
25-
* support for np.float16 number formats used in RELION .mrcs outputs
24+
* implementing [automatic mixed-precision training](https://pytorch.org/docs/stable/amp.html)
25+
for ab-initio reconstruction for 2-4x speedup
26+
* support for RELION 3.1 .star files with separate optics tables, np.float16 number formats used in RELION .mrcs outputs
2627
* `cryodrgn backproject_voxel` produces cryoSPARC-style FSC curve plots with phase-randomization correction of
2728
automatically generated tight masks
2829
* `cryodrgn downsample` can create a new .star or .txt image stack from the corresponding stack format instead of

cryodrgn/commands/abinit_het.py

Lines changed: 147 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import pickle
1818
import sys
19+
import contextlib
1920
import logging
2021
from datetime import datetime as dt
2122
import numpy as np
@@ -32,6 +33,11 @@
3233
from cryodrgn.models import HetOnlyVAE, unparallelize
3334
from cryodrgn.pose_search import PoseSearch
3435

36+
try:
37+
import apex.amp as amp # type: ignore # PYR01
38+
except ImportError:
39+
pass
40+
3541
logger = logging.getLogger(__name__)
3642

3743

@@ -223,6 +229,12 @@ def add_args(parser):
223229
type=int,
224230
help="If set, reset the optimizer every N epochs",
225231
)
232+
group.add_argument(
233+
"--no-amp",
234+
action="store_false",
235+
dest="amp",
236+
help="Do not use mixed-precision training for accelerating training",
237+
)
226238
group.add_argument(
227239
"--multigpu",
228240
action="store_true",
@@ -451,6 +463,8 @@ def train(
451463
enc_only=False,
452464
poses=None,
453465
ctf_params=None,
466+
use_amp=False,
467+
scaler=None,
454468
):
455469
y, yt = minibatch
456470
use_tilt = yt is not None
@@ -470,87 +484,104 @@ def train(
470484
# TODO: Center image?
471485
# We do this in pose-supervised train_vae
472486

473-
# VAE inference of z
474-
model.train()
475-
optim.zero_grad()
476-
input_ = (y, yt) if use_tilt else (y,)
477-
if ctf_i is not None:
478-
input_ = (x * ctf_i.sign() for x in input_) # phase flip by the ctf
479-
480-
_model = unparallelize(model)
481-
assert isinstance(_model, HetOnlyVAE)
482-
z_mu, z_logvar = _model.encode(*input_)
483-
z = _model.reparameterize(z_mu, z_logvar)
484-
485-
lamb = eq_loss = None
486-
if equivariance is not None:
487-
lamb, equivariance_loss = equivariance
488-
eq_loss = equivariance_loss(y, z_mu)
489-
490-
# pose inference
491-
if poses is not None: # use provided poses
492-
rot = poses[0]
493-
trans = poses[1]
494-
else: # pose search
495-
model.eval()
496-
with torch.no_grad():
497-
rot, trans, _base_pose = ps.opt_theta_trans(
498-
y,
499-
z=z,
500-
images_tilt=None if enc_only else yt,
501-
ctf_i=ctf_i,
502-
)
503-
model.train()
504-
505-
# reconstruct circle of pixels instead of whole image
506-
mask = lattice.get_circular_mask(L)
487+
if scaler is not None:
488+
amp_mode = torch.cuda.amp.autocast_mode.autocast()
489+
else:
490+
amp_mode = contextlib.nullcontext()
507491

508-
def gen_slice(R):
509-
slice_ = model(lattice.coords[mask] @ R, z).view(B, -1)
492+
with amp_mode:
493+
# VAE inference of z
494+
model.train()
495+
optim.zero_grad()
496+
input_ = (y, yt) if use_tilt else (y,)
510497
if ctf_i is not None:
511-
slice_ *= ctf_i.view(B, -1)[:, mask]
512-
return slice_
498+
input_ = (x * ctf_i.sign() for x in input_) # phase flip by the ctf
513499

514-
def translate(img):
515-
img = lattice.translate_ht(img, trans.unsqueeze(1), mask)
516-
return img.view(B, -1)
500+
_model = unparallelize(model)
501+
assert isinstance(_model, HetOnlyVAE)
502+
z_mu, z_logvar = _model.encode(*input_)
503+
z = _model.reparameterize(z_mu, z_logvar)
504+
505+
lamb = eq_loss = None
506+
if equivariance is not None:
507+
lamb, equivariance_loss = equivariance
508+
eq_loss = equivariance_loss(y, z_mu)
509+
510+
# pose inference
511+
if poses is not None: # use provided poses
512+
rot = poses[0]
513+
trans = poses[1]
514+
else: # pose search
515+
model.eval()
516+
with torch.no_grad():
517+
rot, trans, _base_pose = ps.opt_theta_trans(
518+
y,
519+
z=z,
520+
images_tilt=None if enc_only else yt,
521+
ctf_i=ctf_i,
522+
)
523+
model.train()
517524

518-
y = y.view(B, -1)[:, mask]
519-
if use_tilt:
520-
yt = yt.view(B, -1)[:, mask]
521-
y = translate(y)
522-
if use_tilt:
523-
yt = translate(yt)
525+
# reconstruct circle of pixels instead of whole image
526+
mask = lattice.get_circular_mask(L)
524527

525-
if use_tilt:
526-
gen_loss = 0.5 * F.mse_loss(gen_slice(rot), y) + 0.5 * F.mse_loss(
527-
gen_slice(bnb.tilt @ rot), yt # type: ignore # noqa: F821
528-
)
529-
else:
530-
gen_loss = F.mse_loss(gen_slice(rot), y)
528+
def gen_slice(R):
529+
slice_ = model(lattice.coords[mask] @ R, z).view(B, -1)
530+
if ctf_i is not None:
531+
slice_ *= ctf_i.view(B, -1)[:, mask]
532+
return slice_
531533

532-
# latent loss
533-
kld = torch.mean(
534-
-0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
535-
)
536-
if torch.isnan(kld):
537-
logger.info(z_mu[0])
538-
logger.info(z_logvar[0])
539-
raise RuntimeError("KLD is nan")
534+
def translate(img):
535+
img = lattice.translate_ht(img, trans.unsqueeze(1), mask)
536+
return img.view(B, -1)
540537

541-
if beta_control is None:
542-
loss = gen_loss + beta * kld / mask.sum().float()
543-
else:
544-
loss = gen_loss + beta_control * (beta - kld) ** 2 / mask.sum().float()
538+
y = y.view(B, -1)[:, mask]
539+
if use_tilt:
540+
yt = yt.view(B, -1)[:, mask]
541+
y = translate(y)
542+
if use_tilt:
543+
yt = translate(yt)
545544

546-
if loss is not None and eq_loss is not None:
547-
loss += lamb * eq_loss
545+
if use_tilt:
546+
gen_loss = 0.5 * F.mse_loss(gen_slice(rot), y) + 0.5 * F.mse_loss(
547+
gen_slice(bnb.tilt @ rot), yt # type: ignore # noqa: F821
548+
)
549+
else:
550+
gen_loss = F.mse_loss(gen_slice(rot), y)
548551

549-
loss.backward()
552+
# latent loss
553+
kld = torch.mean(
554+
-0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=1), dim=0
555+
)
556+
if torch.isnan(kld):
557+
logger.info(z_mu[0])
558+
logger.info(z_logvar[0])
559+
raise RuntimeError("KLD is nan")
560+
561+
if beta_control is None:
562+
loss = gen_loss + beta * kld / mask.sum().float()
563+
else:
564+
loss = gen_loss + beta_control * (beta - kld) ** 2 / mask.sum().float()
565+
566+
if loss is not None and eq_loss is not None:
567+
loss += lamb * eq_loss
568+
569+
if use_amp:
570+
if scaler is not None:
571+
scaler.scale(loss).backward()
572+
scaler.step(optim)
573+
scaler.update()
574+
else: # apex.amp mixed precision
575+
with amp.scale_loss(loss, optim) as scaled_loss:
576+
scaled_loss.backward()
577+
optim.step()
578+
else:
579+
loss.backward()
580+
optim.step()
550581

551-
optim.step()
552582
save_pose = [rot.detach().cpu().numpy()]
553583
save_pose.append(trans.detach().cpu().numpy())
584+
554585
return (
555586
gen_loss.item(),
556587
kld.item(),
@@ -833,6 +864,51 @@ def main(args):
833864

834865
optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
835866

867+
# Mixed precision training
868+
scaler = None
869+
if args.amp:
870+
if args.batch_size % 8 != 0:
871+
logger.warning(
872+
f"Batch size {args.batch_size} not divisible by 8 "
873+
f"and thus not optimal for AMP training!"
874+
)
875+
if (D - 1) % 8 != 0:
876+
logger.warning(
877+
f"Image size {D - 1} not divisible by 8 "
878+
f"and thus not optimal for AMP training!"
879+
)
880+
881+
if args.pdim % 8 != 0:
882+
logger.warning(
883+
f"Decoder hidden layer dimension {args.pdim} not divisible by 8 "
884+
f"and thus not optimal for AMP training!"
885+
)
886+
887+
# also check e.g. enc_mask dim?
888+
if args.qdim % 8 != 0:
889+
logger.warning(
890+
f"Decoder hidden layer dimension {args.qdim} not divisible by 8 "
891+
f"and thus not optimal for AMP training!"
892+
)
893+
894+
if args.zdim % 8 != 0:
895+
logger.warning(
896+
f"Z dimension {args.zdim} is not a multiple of 8 "
897+
"-- AMP training speedup is not optimized!"
898+
)
899+
if in_dim % 8 != 0:
900+
logger.warning(
901+
f"Masked input image dimension {in_dim} is not a mutiple of 8 "
902+
"-- AMP training speedup is not optimized!"
903+
)
904+
905+
# mixed precision with apex.amp
906+
try:
907+
model, optim = amp.initialize(model, optim, opt_level="O1")
908+
# mixed precision with pytorch (v1.6+)
909+
except: # noqa: E722
910+
scaler = torch.cuda.amp.grad_scaler.GradScaler()
911+
836912
if args.load == "latest":
837913
args = get_latest(args)
838914

@@ -1007,6 +1083,8 @@ def main(args):
10071083
enc_only=args.enc_only,
10081084
poses=p,
10091085
ctf_params=ctf_i,
1086+
use_amp=args.amp,
1087+
scaler=scaler,
10101088
)
10111089
# logging
10121090
poses.append((ind.cpu().numpy(), pose))

0 commit comments

Comments
 (0)