Skip to content

Commit 09c6f61

Browse files
authored
Merge pull request #437 from ml-struct-bio/v3.4.4
v3.4.4: Support for Python v3.12, fixing batch iteration, analyzing convergence
2 parents 4ba7550 + b3a6695 commit 09c6f61

File tree

106 files changed

+297
-3157
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

106 files changed

+297
-3157
lines changed

.github/workflows/docs.yml

Lines changed: 0 additions & 44 deletions
This file was deleted.

.github/workflows/tests.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ jobs:
1515
runs-on: ${{ matrix.os }}
1616
strategy:
1717
matrix:
18-
python: [ '3.9', '3.10', '3.11' ]
18+
python: [ '3.10', '3.11' , '3.12' ]
1919
os: [ macos-latest, ubuntu-latest ]
2020
include:
21-
- python: '3.9'
22-
torch: '1.12'
2321
- python: '3.10'
24-
torch: '2.1'
22+
torch: '1.12'
2523
- python: '3.11'
24+
torch: '2.1'
25+
- python: '3.12'
2626
torch: '2.4'
2727
fail-fast: false
2828

cryodrgn/analysis.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,9 +334,14 @@ def scatter_color(
334334
sc = plt.scatter(x, y, s=s, alpha=alpha, rasterized=True, cmap=cmap, c=c)
335335
cbar = plt.colorbar(sc)
336336
cbar.set_alpha(1)
337-
cbar.draw_all()
337+
338+
if hasattr(cbar, "draw_all"):
339+
cbar.draw_all()
340+
else:
341+
cbar._draw_all()
338342
if label:
339343
cbar.set_label(label)
344+
340345
return fig, ax
341346

342347

cryodrgn/command_line.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def util_commands() -> None:
112112
_get_commands(
113113
cmd_dir=os.path.join(os.path.dirname(__file__), "commands_utils"),
114114
cmds=[
115+
"analyze_convergence",
115116
"add_psize",
116117
"clean",
117118
"concat_pkls",

cryodrgn/commands/abinit_het.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ def add_args(parser):
8080
parser.add_argument(
8181
"--seed", type=int, default=np.random.randint(0, 100000), help="Random seed"
8282
)
83+
parser.add_argument(
84+
"--shuffle-seed",
85+
type=int,
86+
default=None,
87+
help="Random seed for data shuffling",
88+
)
8389

8490
group = parser.add_argument_group("Dataset loading")
8591
group.add_argument(
@@ -603,12 +609,17 @@ def eval_z(
603609
use_tilt=False,
604610
ctf_params=None,
605611
shuffler_size=0,
612+
seed=None,
606613
):
607614
assert not model.training
608-
z_mu_all = []
609-
z_logvar_all = []
615+
616+
z_mu_all, z_logvar_all = list(), list()
610617
data_generator = dataset.make_dataloader(
611-
data, batch_size=batch_size, shuffler_size=shuffler_size, shuffle=False
618+
data,
619+
batch_size=batch_size,
620+
shuffler_size=shuffler_size,
621+
shuffle=False,
622+
seed=seed,
612623
)
613624

614625
for minibatch in data_generator:
@@ -638,9 +649,8 @@ def eval_z(
638649
z_mu, z_logvar = _model.encode(*input_)
639650
z_mu_all.append(z_mu.detach().cpu().numpy())
640651
z_logvar_all.append(z_logvar.detach().cpu().numpy())
641-
z_mu_all = np.vstack(z_mu_all)
642-
z_logvar_all = np.vstack(z_logvar_all)
643-
return z_mu_all, z_logvar_all
652+
653+
return np.vstack(z_mu_all), np.vstack(z_logvar_all)
644654

645655

646656
def save_checkpoint(
@@ -814,9 +824,7 @@ def main(args):
814824
datadir=args.datadir,
815825
window_r=args.window_r,
816826
)
817-
818-
Nimg = data.N
819-
D = data.D
827+
Nimg, D = data.N, data.D
820828

821829
if args.encode_mode == "conv":
822830
assert D - 1 == 64, "Image size must be 64x64 for convolutional encoder"
@@ -983,25 +991,28 @@ def main(args):
983991
)
984992

985993
data_iterator = dataset.make_dataloader(
986-
data, batch_size=args.batch_size, shuffler_size=args.shuffler_size
994+
data,
995+
batch_size=args.batch_size,
996+
shuffler_size=args.shuffler_size,
997+
seed=args.shuffle_seed,
987998
)
988999

9891000
# pretrain decoder with random poses
9901001
global_it = 0
9911002
logger.info("Using random poses for {} iterations".format(args.pretrain))
992-
while global_it < args.pretrain:
993-
for batch in data_iterator:
994-
global_it += len(batch[0])
995-
batch = (
996-
(batch[0].to(device), None)
997-
if tilt is None
998-
else (batch[0].to(device), batch[1].to(device))
999-
)
1000-
loss = pretrain(model, lattice, optim, batch, tilt=ps.tilt, zdim=args.zdim)
1001-
if global_it % args.log_interval == 0:
1002-
logger.info(f"[Pretrain Iteration {global_it}] loss={loss:4f}")
1003-
if global_it > args.pretrain:
1004-
break
1003+
for batch in data_iterator:
1004+
global_it += len(batch[0])
1005+
batch = (
1006+
(batch[0].to(device), None)
1007+
if tilt is None
1008+
else (batch[0].to(device), batch[1].to(device))
1009+
)
1010+
loss = pretrain(model, lattice, optim, batch, tilt=ps.tilt, zdim=args.zdim)
1011+
if global_it % args.log_interval == 0:
1012+
logger.info(f"[Pretrain Iteration {global_it}] loss={loss:4f}")
1013+
1014+
if global_it >= args.pretrain:
1015+
break
10051016

10061017
# reset model after pretraining
10071018
if args.reset_optim_after_pretrain:
@@ -1147,6 +1158,7 @@ def main(args):
11471158
use_tilt=tilt is not None,
11481159
ctf_params=ctf_params,
11491160
shuffler_size=args.shuffler_size,
1161+
seed=args.shuffle_seed,
11501162
)
11511163
save_checkpoint(
11521164
model,
@@ -1181,6 +1193,8 @@ def main(args):
11811193
device,
11821194
use_tilt=tilt is not None,
11831195
ctf_params=ctf_params,
1196+
shuffler_size=args.shuffler_size,
1197+
seed=args.shuffle_seed,
11841198
)
11851199
save_checkpoint(
11861200
model,

cryodrgn/commands/abinit_homo.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ def add_args(parser):
7979
parser.add_argument(
8080
"--seed", type=int, default=np.random.randint(0, 100000), help="Random seed"
8181
)
82+
parser.add_argument(
83+
"--shuffle-seed",
84+
type=int,
85+
default=None,
86+
help="Random seed for data shuffling",
87+
)
88+
8289
parser.add_argument(
8390
"--uninvert-data",
8491
dest="invert_data",
@@ -599,9 +606,7 @@ def main(args):
599606
datadir=args.datadir,
600607
window_r=args.window_r,
601608
)
602-
603-
D = data.D
604-
Nimg = data.N
609+
D, Nimg = data.D, data.N
605610

606611
# load ctf
607612
if args.ctf is not None:
@@ -706,25 +711,28 @@ def main(args):
706711
start_epoch = 0
707712

708713
data_iterator = dataset.make_dataloader(
709-
data, batch_size=args.batch_size, shuffler_size=args.shuffler_size
714+
data,
715+
batch_size=args.batch_size,
716+
shuffler_size=args.shuffler_size,
717+
seed=args.shuffle_seed,
710718
)
711719

712720
# pretrain decoder with random poses
713721
global_it = 0
714722
logger.info("Using random poses for {} iterations".format(args.pretrain))
715-
while global_it < args.pretrain:
716-
for batch in data_iterator:
717-
global_it += len(batch[0])
718-
batch = (
719-
(batch[0].to(device), None)
720-
if tilt is None
721-
else (batch[0].to(device), batch[1].to(device))
722-
)
723-
loss = pretrain(model, lattice, optim, batch, tilt=ps.tilt)
724-
if global_it % args.log_interval == 0:
725-
logger.info(f"[Pretrain Iteration {global_it}] loss={loss:4f}")
726-
if global_it > args.pretrain:
727-
break
723+
for batch in data_iterator:
724+
global_it += len(batch[0])
725+
batch = (
726+
(batch[0].to(device), None)
727+
if tilt is None
728+
else (batch[0].to(device), batch[1].to(device))
729+
)
730+
loss = pretrain(model, lattice, optim, batch, tilt=ps.tilt)
731+
if global_it % args.log_interval < args.batch_size:
732+
logger.info(f"[Pretrain Iteration {global_it}] loss={loss:4f}")
733+
if global_it >= args.pretrain:
734+
break
735+
728736
out_mrc = "{}/pretrain.reconstruct.mrc".format(args.outdir)
729737
model.eval()
730738
vol = model.eval_volume(lattice.coords, lattice.D, lattice.extent, tuple(data.norm))
@@ -808,7 +816,7 @@ def main(args):
808816
base_poses.append((ind_np, base_pose))
809817
# logging
810818
loss_accum += loss_item * len(batch[0])
811-
if batch_it % args.log_interval == 0:
819+
if batch_it % args.log_interval < args.batch_size:
812820
logger.info(
813821
"# [Train Epoch: {}/{}] [{}/{} images] loss={:.4f}".format(
814822
epoch + 1, args.num_epochs, batch_it, Nimg, loss_item

cryodrgn/commands/analyze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def main(args: argparse.Namespace) -> None:
454454

455455
else:
456456
use_apix = 1.0
457-
logger.info("cannot find A/px in CTF parameters, " "defaulting to A/px=1.0")
457+
logger.info("Cannot find A/px in CTF parameters, defaulting to A/px=1.0")
458458

459459
if E == -1:
460460
zfile = f"{workdir}/z.pkl"

cryodrgn/commands/train_nn.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ def add_args(parser: argparse.ArgumentParser) -> None:
8282
parser.add_argument(
8383
"--seed", type=int, default=np.random.randint(0, 100000), help="Random seed"
8484
)
85+
parser.add_argument(
86+
"--shuffle-seed",
87+
type=int,
88+
default=None,
89+
help="Random seed for data shuffling",
90+
)
8591

8692
group = parser.add_argument_group("Dataset loading")
8793
group.add_argument(
@@ -415,9 +421,7 @@ def main(args: argparse.Namespace) -> None:
415421
datadir=args.datadir,
416422
window_r=args.window_r,
417423
)
418-
419-
D = data.D
420-
Nimg = data.N
424+
D, Nimg = data.D, data.N
421425

422426
# instantiate model
423427
# if args.pe_type != 'none': assert args.l_extent == 0.5
@@ -532,7 +536,10 @@ def main(args: argparse.Namespace) -> None:
532536

533537
# train
534538
data_generator = dataset.make_dataloader(
535-
data, batch_size=args.batch_size, shuffler_size=args.shuffler_size
539+
data,
540+
batch_size=args.batch_size,
541+
shuffler_size=args.shuffler_size,
542+
seed=args.shuffle_seed,
536543
)
537544

538545
epoch = None
@@ -561,7 +568,7 @@ def main(args: argparse.Namespace) -> None:
561568
if pose_optimizer is not None and epoch >= args.pretrain:
562569
pose_optimizer.step()
563570
loss_accum += loss_item * len(ind)
564-
if batch_it % args.log_interval == 0:
571+
if batch_it % args.log_interval < args.batch_size:
565572
logger.info(
566573
"# [Train Epoch: {}/{}] [{}/{} images] loss={:.6f}".format(
567574
epoch + 1, args.num_epochs, batch_it, Nimg, loss_item

0 commit comments

Comments (0)