@@ -80,6 +80,12 @@ def add_args(parser):
8080 parser .add_argument (
8181 "--seed" , type = int , default = np .random .randint (0 , 100000 ), help = "Random seed"
8282 )
83+ parser .add_argument (
84+ "--shuffle-seed" ,
85+ type = int ,
86+ default = None ,
87+ help = "Random seed for data shuffling" ,
88+ )
8389
8490 group = parser .add_argument_group ("Dataset loading" )
8591 group .add_argument (
@@ -603,12 +609,17 @@ def eval_z(
603609 use_tilt = False ,
604610 ctf_params = None ,
605611 shuffler_size = 0 ,
612+ seed = None ,
606613):
607614 assert not model .training
608- z_mu_all = []
609- z_logvar_all = []
615+
616+ z_mu_all , z_logvar_all = list (), list ()
610617 data_generator = dataset .make_dataloader (
611- data , batch_size = batch_size , shuffler_size = shuffler_size , shuffle = False
618+ data ,
619+ batch_size = batch_size ,
620+ shuffler_size = shuffler_size ,
621+ shuffle = False ,
622+ seed = seed ,
612623 )
613624
614625 for minibatch in data_generator :
@@ -638,9 +649,8 @@ def eval_z(
638649 z_mu , z_logvar = _model .encode (* input_ )
639650 z_mu_all .append (z_mu .detach ().cpu ().numpy ())
640651 z_logvar_all .append (z_logvar .detach ().cpu ().numpy ())
641- z_mu_all = np .vstack (z_mu_all )
642- z_logvar_all = np .vstack (z_logvar_all )
643- return z_mu_all , z_logvar_all
652+
653+ return np .vstack (z_mu_all ), np .vstack (z_logvar_all )
644654
645655
646656def save_checkpoint (
@@ -814,9 +824,7 @@ def main(args):
814824 datadir = args .datadir ,
815825 window_r = args .window_r ,
816826 )
817-
818- Nimg = data .N
819- D = data .D
827+ Nimg , D = data .N , data .D
820828
821829 if args .encode_mode == "conv" :
822830 assert D - 1 == 64 , "Image size must be 64x64 for convolutional encoder"
@@ -983,25 +991,28 @@ def main(args):
983991 )
984992
985993 data_iterator = dataset .make_dataloader (
986- data , batch_size = args .batch_size , shuffler_size = args .shuffler_size
994+ data ,
995+ batch_size = args .batch_size ,
996+ shuffler_size = args .shuffler_size ,
997+ seed = args .shuffle_seed ,
987998 )
988999
9891000 # pretrain decoder with random poses
9901001 global_it = 0
9911002 logger .info ("Using random poses for {} iterations" .format (args .pretrain ))
992- while global_it < args . pretrain :
993- for batch in data_iterator :
994- global_it += len ( batch [ 0 ])
995- batch = (
996- ( batch [ 0 ]. to ( device ), None )
997- if tilt is None
998- else ( batch [ 0 ]. to ( device ), batch [ 1 ]. to ( device ) )
999- )
1000- loss = pretrain ( model , lattice , optim , batch , tilt = ps . tilt , zdim = args . zdim )
1001- if global_it % args . log_interval == 0 :
1002- logger . info ( f"[Pretrain Iteration { global_it } ] loss= { loss :4f } " )
1003- if global_it > args .pretrain :
1004- break
1003+ for batch in data_iterator :
1004+ global_it += len ( batch [ 0 ])
1005+ batch = (
1006+ ( batch [ 0 ]. to ( device ), None )
1007+ if tilt is None
1008+ else ( batch [ 0 ]. to ( device ), batch [ 1 ]. to ( device ))
1009+ )
1010+ loss = pretrain ( model , lattice , optim , batch , tilt = ps . tilt , zdim = args . zdim )
1011+ if global_it % args . log_interval == 0 :
1012+ logger . info ( f"[Pretrain Iteration { global_it } ] loss= { loss :4f } " )
1013+
1014+ if global_it >= args .pretrain :
1015+ break
10051016
10061017 # reset model after pretraining
10071018 if args .reset_optim_after_pretrain :
@@ -1147,6 +1158,7 @@ def main(args):
11471158 use_tilt = tilt is not None ,
11481159 ctf_params = ctf_params ,
11491160 shuffler_size = args .shuffler_size ,
1161+ seed = args .shuffle_seed ,
11501162 )
11511163 save_checkpoint (
11521164 model ,
@@ -1181,6 +1193,8 @@ def main(args):
11811193 device ,
11821194 use_tilt = tilt is not None ,
11831195 ctf_params = ctf_params ,
1196+ shuffler_size = args .shuffler_size ,
1197+ seed = args .shuffle_seed ,
11841198 )
11851199 save_checkpoint (
11861200 model ,
0 commit comments