Skip to content

Commit 661b4f6

Browse files
optimize correlations computing
1 parent ed33256 commit 661b4f6

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

postprocess/postprocess_service.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,38 @@ def get_dataset(self, path, corr: pd.DataFrame) -> pd.DataFrame:
2020
df = pd.read_csv(config, delimiter=';').astype(bool)
2121
df['id'] = conf_id
2222
df['from_ref'] = corr.loc[(corr['source'] == conf_id) & (corr['target'] == 'ref'), 'correlation'].values[0]
23-
df['from_mean'] = corr.loc[(corr['source'] == conf_id) & (corr['target'] == 'mean'), 'correlation'].values[0]
23+
df['from_mean'] = corr.loc[(corr['source'] == conf_id) & (corr['target'] == 'mean'), 'correlation'].values[
24+
0]
2425
dataframes.append(df)
2526

2627
return pd.concat(dataframes, ignore_index=True)
2728

2829
def get_all_correlations(self, path, ids: List[str]) -> pd.DataFrame:
29-
dataframes = []
30+
dataframe = pd.DataFrame(columns=['source', 'target', 'correlation'])
3031
niis = {'ref': os.path.join(path, 'ref', '_subject_id_01', 'result.nii'),
3132
'mean': os.path.join(path, 'mean_result.nii')}
3233
for conf_id in ids:
3334
niis[conf_id] = os.path.join(path, conf_id, '_subject_id_01', 'result.nii')
3435

3536
for id_src in niis:
3637
for id_tgt in niis:
37-
corr = self.corr_srv.get_correlation_coefficient(niis[id_tgt], niis[id_src], 'spearman')
38-
dataframes.append(pd.DataFrame([[id_src, id_tgt, corr]], columns=['source', 'target', 'correlation']))
38+
# This correlation may have already been calculated the other way
39+
if ((dataframe['source'] == id_tgt) & (dataframe['target'] == id_src)).any():
40+
corr = dataframe.loc[(dataframe['source'] == id_tgt) & (dataframe['target'] == id_src), 'correlation'].values[0]
41+
else:
42+
corr = self.corr_srv.get_correlation_coefficient(niis[id_tgt], niis[id_src], 'spearman')
43+
dataframe.append({'source': id_src, 'target': id_tgt, 'correlation': corr}, ignore_index=True)
3944

40-
return pd.concat(dataframes, ignore_index=True).sort_values(by='correlation', ascending=False)
45+
return dataframe.sort_values(by='correlation', ascending=False)
4146

4247
def get_mean_image(self, inputs: list, batch_size: int) -> nib.Nifti1Image:
4348
total_sum = None
4449
count = 0
4550

46-
for i in range(0, len(inputs), batch_size):
51+
total = len(inputs)
52+
53+
print(f"Summing up the [{total}] images")
54+
for i in range(0, total, batch_size):
4755
batch_paths = inputs[i:i + batch_size]
4856
batch_images = [nib.load(path).get_fdata() for path in batch_paths]
4957

@@ -60,11 +68,12 @@ def get_mean_image(self, inputs: list, batch_size: int) -> nib.Nifti1Image:
6068
total_sum += batch_sum
6169

6270
count += len(batch_paths)
71+
print(f"Summed [{count}] images.")
6372

64-
# Calculate the mean image
73+
print("Calculating the mean image...")
6574
mean_image = total_sum / count
6675

67-
# Create a new NIfTI image with the mean data
76+
print("Creating a new NIfTI image with the mean data...")
6877
mean_nifti = nib.Nifti1Image(mean_image, affine=nib.load(inputs[0]).affine)
69-
78+
print("Mean image created.")
7079
return mean_nifti

0 commit comments

Comments
 (0)