@@ -20,30 +20,38 @@ def get_dataset(self, path, corr: pd.DataFrame) -> pd.DataFrame:
2020 df = pd .read_csv (config , delimiter = ';' ).astype (bool )
2121 df ['id' ] = conf_id
2222 df ['from_ref' ] = corr .loc [(corr ['source' ] == conf_id ) & (corr ['target' ] == 'ref' ), 'correlation' ].values [0 ]
23- df ['from_mean' ] = corr .loc [(corr ['source' ] == conf_id ) & (corr ['target' ] == 'mean' ), 'correlation' ].values [0 ]
23+ df ['from_mean' ] = corr .loc [(corr ['source' ] == conf_id ) & (corr ['target' ] == 'mean' ), 'correlation' ].values [
24+ 0 ]
2425 dataframes .append (df )
2526
2627 return pd .concat (dataframes , ignore_index = True )
2728
2829 def get_all_correlations (self , path , ids : List [str ]) -> pd .DataFrame :
29- dataframes = []
30+ dataframe = pd . DataFrame ( columns = [ 'source' , 'target' , 'correlation' ])
3031 niis = {'ref' : os .path .join (path , 'ref' , '_subject_id_01' , 'result.nii' ),
3132 'mean' : os .path .join (path , 'mean_result.nii' )}
3233 for conf_id in ids :
3334 niis [conf_id ] = os .path .join (path , conf_id , '_subject_id_01' , 'result.nii' )
3435
3536 for id_src in niis :
3637 for id_tgt in niis :
37- corr = self .corr_srv .get_correlation_coefficient (niis [id_tgt ], niis [id_src ], 'spearman' )
38- dataframes .append (pd .DataFrame ([[id_src , id_tgt , corr ]], columns = ['source' , 'target' , 'correlation' ]))
38+ # This correlation may have already been calculated the other way
39+ if ((dataframe ['source' ] == id_tgt ) & (dataframe ['target' ] == id_src )).any ():
40+ corr = dataframe .loc [(dataframe ['source' ] == id_tgt ) & (dataframe ['target' ] == id_src ), 'correlation' ].values [0 ]
41+ else :
42+ corr = self .corr_srv .get_correlation_coefficient (niis [id_tgt ], niis [id_src ], 'spearman' )
43+ dataframe .append ({'source' : id_src , 'target' : id_tgt , 'correlation' : corr }, ignore_index = True )
3944
40- return pd . concat ( dataframes , ignore_index = True ) .sort_values (by = 'correlation' , ascending = False )
45+ return dataframe .sort_values (by = 'correlation' , ascending = False )
4146
4247 def get_mean_image (self , inputs : list , batch_size : int ) -> nib .Nifti1Image :
4348 total_sum = None
4449 count = 0
4550
46- for i in range (0 , len (inputs ), batch_size ):
51+ total = len (inputs )
52+
53+ print (f"Summing up the [{ total } ] images" )
54+ for i in range (0 , total , batch_size ):
4755 batch_paths = inputs [i :i + batch_size ]
4856 batch_images = [nib .load (path ).get_fdata () for path in batch_paths ]
4957
@@ -60,11 +68,12 @@ def get_mean_image(self, inputs: list, batch_size: int) -> nib.Nifti1Image:
6068 total_sum += batch_sum
6169
6270 count += len (batch_paths )
71+ print (f"Summed [{ count } ] images." )
6372
64- # Calculate the mean image
73+ print ( "Calculating the mean image..." )
6574 mean_image = total_sum / count
6675
67- # Create a new NIfTI image with the mean data
76+ print ( "Creating a new NIfTI image with the mean data..." )
6877 mean_nifti = nib .Nifti1Image (mean_image , affine = nib .load (inputs [0 ]).affine )
69-
78+ print ( "Mean image created." )
7079 return mean_nifti
0 commit comments