Skip to content

Commit 4ba7550

Browse files
authored
Merge pull request #427 from ml-struct-bio/v3.4.3
v3.4.3: Making movies, improving filtering interface, and fixing landscape analysis
2 parents 196365d + ec32932 commit 4ba7550

File tree

13 files changed

+1036
-433
lines changed

13 files changed

+1036
-433
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ cryoDRGN installation, training and analysis. A brief quick start is provided be
1919
For any feedback, questions, or bugs, please file a Github issue or start a Github discussion.
2020

2121

22-
### New in Version 3.4.x
23-
* [NEW] `cryodrgn plot_classes` for analysis visualizations colored by a given set of class labels
22+
### Updates in Version 3.4.x
23+
* [NEW] `cryodrgn_utils plot_classes` for analysis visualizations colored by a given set of class labels
24+
* [NEW] `cryodrgn_utils make_movies` for animations of `analyze*` output volumes
2425
* implementing [automatic mixed-precision training](https://pytorch.org/docs/stable/amp.html)
2526
for ab-initio reconstruction for 2-4x speedup
2627
* support for RELION 3.1 .star files with separate optics tables, np.float16 number formats used in RELION .mrcs outputs
@@ -33,7 +34,7 @@ For any feedback, questions, or bugs, please file a Github issue or start a Gith
3334
* official support for Python 3.11
3435

3536

36-
### New in Version 3.x
37+
### Updates in Version 3.x
3738

3839
The official release of [cryoDRGN-ET](https://www.biorxiv.org/content/10.1101/2023.08.18.553799v1) for heterogeneous subtomogram analysis.
3940

cryodrgn/command_line.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
since automated scanning for command modules is computationally non-trivial.
1212
1313
"""
14+
1415
import argparse
1516
import os
1617
from importlib import import_module
@@ -122,6 +123,7 @@ def util_commands() -> None:
122123
"fsc",
123124
"gen_mask",
124125
"invert_contrast",
126+
"make_movies",
125127
"phase_flip",
126128
"plot_classes",
127129
"plot_fsc",

cryodrgn/commands/abinit_het.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,10 @@ def train(
485485
# We do this in pose-supervised train_vae
486486

487487
if scaler is not None:
488-
amp_mode = torch.cuda.amp.autocast_mode.autocast()
488+
try:
489+
amp_mode = torch.amp.autocast("cuda")
490+
except AttributeError:
491+
amp_mode = torch.cuda.amp.autocast_mode.autocast()
489492
else:
490493
amp_mode = contextlib.nullcontext()
491494

@@ -898,7 +901,7 @@ def main(args):
898901
)
899902
if in_dim % 8 != 0:
900903
logger.warning(
901-
f"Masked input image dimension {in_dim} is not a mutiple of 8 "
904+
f"Masked input image dimension {in_dim} is not a multiple of 8 "
902905
"-- AMP training speedup is not optimized!"
903906
)
904907

@@ -907,7 +910,10 @@ def main(args):
907910
model, optim = amp.initialize(model, optim, opt_level="O1")
908911
# mixed precision with pytorch (v1.6+)
909912
except: # noqa: E722
910-
scaler = torch.cuda.amp.grad_scaler.GradScaler()
913+
try:
914+
scaler = torch.amp.GradScaler("cuda")
915+
except AttributeError:
916+
scaler = torch.cuda.amp.grad_scaler.GradScaler()
911917

912918
if args.load == "latest":
913919
args = get_latest(args)

cryodrgn/commands/abinit_homo.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,10 @@ def train(
416416
)
417417

418418
if scaler is not None:
419-
amp_mode = torch.cuda.amp.autocast_mode.autocast()
419+
try:
420+
amp_mode = torch.amp.autocast("cuda")
421+
except AttributeError:
422+
amp_mode = torch.cuda.amp.autocast_mode.autocast()
420423
else:
421424
amp_mode = contextlib.nullcontext()
422425

@@ -676,7 +679,10 @@ def main(args):
676679
model, optim = amp.initialize(model, optim, opt_level="O1")
677680
# mixed precision with pytorch (v1.6+)
678681
except: # noqa: E722
679-
scaler = torch.cuda.amp.grad_scaler.GradScaler()
682+
try:
683+
scaler = torch.amp.GradScaler("cuda")
684+
except AttributeError:
685+
scaler = torch.cuda.amp.grad_scaler.GradScaler()
680686

681687
sorted_poses = []
682688
if args.load:

cryodrgn/commands/analyze_landscape.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -348,47 +348,55 @@ def plot(i, j):
348348

349349
kmeans_labels = utils.load_pkl(os.path.join(outdir, f"kmeans{K}", "labels.pkl"))
350350
kmeans_counts = Counter(kmeans_labels)
351-
for i in range(M):
352-
vol_i = np.where(labels == i)[0]
353-
logger.info(f"State {i}: {len(vol_i)} volumes")
351+
for cluster_i in range(M):
352+
vol_indices = np.where(labels == cluster_i)[0]
353+
logger.info(f"State {cluster_i}: {len(vol_indices)} volumes")
354354
if vol_ind is not None:
355-
vol_i = np.arange(K)[vol_ind][vol_i]
355+
vol_indices = np.arange(K)[vol_ind][vol_indices]
356+
357+
vol_fls = [
358+
os.path.join(kmean_dir, f"vol_{vol_start_index + vol_i:03d}.mrc")
359+
for vol_i in vol_indices
360+
]
361+
vol_i_all = torch.stack(
362+
[torch.Tensor(parse_mrc(vol_fl)[0]) for vol_fl in vol_fls]
363+
)
356364

357-
vol_fl = os.path.join(kmean_dir, f"vol_{vol_start_index+i:03d}.mrc")
358-
vol_i_all = torch.stack([torch.Tensor(parse_mrc(vol_fl)[0]) for i in vol_i])
359-
nparticles = np.array([kmeans_counts[i] for i in vol_i])
365+
nparticles = np.array([kmeans_counts[vol_i] for vol_i in vol_indices])
360366
vol_i_mean = np.average(vol_i_all, axis=0, weights=nparticles)
361367
vol_i_std = (
362368
np.average((vol_i_all - vol_i_mean) ** 2, axis=0, weights=nparticles) ** 0.5
363369
)
370+
364371
write_mrc(
365-
os.path.join(subdir, f"state_{i}_mean.mrc"),
372+
os.path.join(subdir, f"state_{cluster_i}_mean.mrc"),
366373
vol_i_mean.astype(np.float32),
367374
Apix=Apix,
368375
)
369376
write_mrc(
370-
os.path.join(subdir, f"state_{i}_std.mrc"),
377+
os.path.join(subdir, f"state_{cluster_i}_std.mrc"),
371378
vol_i_std.astype(np.float32),
372379
Apix=Apix,
373380
)
374381

375-
os.makedirs(os.path.join(subdir, f"state_{i}"), exist_ok=True)
376-
for v in vol_i:
377-
os.symlink(
378-
os.path.join(kmean_dir, f"vol_{vol_start_index+v:03d}.mrc"),
379-
os.path.join(subdir, f"state_{i}", f"vol_{vol_start_index+v:03d}.mrc"),
380-
)
382+
statedir = os.path.join(subdir, f"state_{cluster_i}")
383+
os.makedirs(statedir, exist_ok=True)
384+
for vol_i in vol_indices:
385+
kmean_fl = os.path.join(kmean_dir, f"vol_{vol_start_index+vol_i:03d}.mrc")
386+
sub_fl = os.path.join(statedir, f"vol_{vol_start_index+vol_i:03d}.mrc")
387+
os.symlink(kmean_fl, sub_fl)
381388

382-
particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_i)
383-
logger.info(f"State {i}: {len(particle_ind)} particles")
389+
particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_indices)
390+
logger.info(f"State {cluster_i}: {len(particle_ind)} particles")
384391
if particle_ind_orig is not None:
385392
utils.save_pkl(
386393
particle_ind_orig[particle_ind],
387-
os.path.join(subdir, f"state_{i}_particle_ind.pkl"),
394+
os.path.join(subdir, f"state_{cluster_i}_particle_ind.pkl"),
388395
)
389396
else:
390397
utils.save_pkl(
391-
particle_ind, os.path.join(subdir, f"state_{i}_particle_ind.pkl")
398+
particle_ind,
399+
os.path.join(subdir, f"state_{cluster_i}_particle_ind.pkl"),
392400
)
393401

394402
# plot clustering results

0 commit comments

Comments
 (0)