@@ -348,47 +348,55 @@ def plot(i, j):
348348
349349 kmeans_labels = utils .load_pkl (os .path .join (outdir , f"kmeans{ K } " , "labels.pkl" ))
350350 kmeans_counts = Counter (kmeans_labels )
351- for i in range (M ):
352- vol_i = np .where (labels == i )[0 ]
353- logger .info (f"State { i } : { len (vol_i )} volumes" )
351+ for cluster_i in range (M ):
352+ vol_indices = np .where (labels == cluster_i )[0 ]
353+ logger .info (f"State { cluster_i } : { len (vol_indices )} volumes" )
354354 if vol_ind is not None :
355- vol_i = np .arange (K )[vol_ind ][vol_i ]
355+ vol_indices = np .arange (K )[vol_ind ][vol_indices ]
356+
357+ vol_fls = [
358+ os .path .join (kmean_dir , f"vol_{ vol_start_index + vol_i :03d} .mrc" )
359+ for vol_i in vol_indices
360+ ]
361+ vol_i_all = torch .stack (
362+ [torch .Tensor (parse_mrc (vol_fl )[0 ]) for vol_fl in vol_fls ]
363+ )
356364
357- vol_fl = os .path .join (kmean_dir , f"vol_{ vol_start_index + i :03d} .mrc" )
358- vol_i_all = torch .stack ([torch .Tensor (parse_mrc (vol_fl )[0 ]) for i in vol_i ])
359- nparticles = np .array ([kmeans_counts [i ] for i in vol_i ])
365+ nparticles = np .array ([kmeans_counts [vol_i ] for vol_i in vol_indices ])
360366 vol_i_mean = np .average (vol_i_all , axis = 0 , weights = nparticles )
361367 vol_i_std = (
362368 np .average ((vol_i_all - vol_i_mean ) ** 2 , axis = 0 , weights = nparticles ) ** 0.5
363369 )
370+
364371 write_mrc (
365- os .path .join (subdir , f"state_{ i } _mean.mrc" ),
372+ os .path .join (subdir , f"state_{ cluster_i } _mean.mrc" ),
366373 vol_i_mean .astype (np .float32 ),
367374 Apix = Apix ,
368375 )
369376 write_mrc (
370- os .path .join (subdir , f"state_{ i } _std.mrc" ),
377+ os .path .join (subdir , f"state_{ cluster_i } _std.mrc" ),
371378 vol_i_std .astype (np .float32 ),
372379 Apix = Apix ,
373380 )
374381
375- os .makedirs ( os . path .join (subdir , f"state_{ i } " ), exist_ok = True )
376- for v in vol_i :
377- os . symlink (
378- os .path .join (kmean_dir , f"vol_{ vol_start_index + v :03d} .mrc" ),
379- os .path .join (subdir , f"state_ { i } " , f" vol_{ vol_start_index + v :03d} .mrc" ),
380- )
382+ statedir = os .path .join (subdir , f"state_{ cluster_i } " )
383+ os . makedirs ( statedir , exist_ok = True )
384+ for vol_i in vol_indices :
385+ kmean_fl = os .path .join (kmean_dir , f"vol_{ vol_start_index + vol_i :03d} .mrc" )
386+ sub_fl = os .path .join (statedir , f"vol_{ vol_start_index + vol_i :03d} .mrc" )
387+ os . symlink ( kmean_fl , sub_fl )
381388
382- particle_ind = analysis .get_ind_for_cluster (kmeans_labels , vol_i )
383- logger .info (f"State { i } : { len (particle_ind )} particles" )
389+ particle_ind = analysis .get_ind_for_cluster (kmeans_labels , vol_indices )
390+ logger .info (f"State { cluster_i } : { len (particle_ind )} particles" )
384391 if particle_ind_orig is not None :
385392 utils .save_pkl (
386393 particle_ind_orig [particle_ind ],
387- os .path .join (subdir , f"state_{ i } _particle_ind.pkl" ),
394+ os .path .join (subdir , f"state_{ cluster_i } _particle_ind.pkl" ),
388395 )
389396 else :
390397 utils .save_pkl (
391- particle_ind , os .path .join (subdir , f"state_{ i } _particle_ind.pkl" )
398+ particle_ind ,
399+ os .path .join (subdir , f"state_{ cluster_i } _particle_ind.pkl" ),
392400 )
393401
394402 # plot clustering results
0 commit comments