19 changes: 13 additions & 6 deletions README.md
@@ -7,6 +7,9 @@ Official repository for contrast-agnostic segmentation of the spinal cord.
This repo contains all the code for training the contrast-agnostic model. The code for training is based on the [nnUNetv2 framework](https://github.com/MIC-DKFZ/nnUNet). The segmentation model is available as part of [Spinal Cord Toolbox (SCT)](https://spinalcordtoolbox.com/stable/user_section/command-line/deepseg/spinalcord.html) via the `sct_deepseg` functionality.
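For reference, the released model can be applied to a single image through SCT's `sct_deepseg` interface. A minimal sketch is shown below, assuming a recent SCT release; the exact subcommand and flag names may differ between SCT versions, so check `sct_deepseg -h` and the linked documentation for your installation:

```bash
# Segment the spinal cord on a T2w image with the contrast-agnostic model via SCT.
# NOTE: the subcommand/flag names below are an assumption based on the linked SCT
# documentation and may differ between SCT versions.
sct_deepseg spinalcord -i sub-001_T2w.nii.gz -o sub-001_T2w_seg.nii.gz
```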


<img width="1540" alt="lifelong_ca_final" src="https://github.com/user-attachments/assets/c35d445c-d2ec-4bca-9995-e16371972cbf" />


### Citation Information

If you find this work and/or code useful for your research, please cite our paper:
@@ -26,8 +29,6 @@ note = {Shared authorship -- authors contributed equally}
}
```

**TODO**: add lifelong learning figure


## Table of contents
* [Training the model ](#training-the-model)
@@ -49,7 +50,7 @@ note = {Shared authorship -- authors contributed equally}

1. Create a conda environment with the following command:
```bash
conda create -n contrast_agnostic python=3.9
conda create -n contrast_agnostic python=3.9.16
```

2. Activate the environment with the following command:
@@ -64,8 +65,8 @@ git clone https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord.git

3. Install the required packages with the following command:
```bash
cd contrast-agnostic-softseg-spinalcord/nnUnet
pip install -r requirements.txt
cd contrast-agnostic-softseg-spinalcord
pip install -r nnUnet/requirements.txt
```

> **Reviewer (Member):** Is it intentional that `nnUnet/requirements.txt` does not install nnUNet itself?
>
> **Reviewer (Member):** Without nnUNet installed in the `contrast_agnostic` env, I obviously get: *(screenshot of the resulting error)*
>
> **Author (Collaborator):** Yes, it is intentional. I think the nnUNet installation should be done outside the contrast-agnostic repo.

> **Note**
@@ -74,11 +75,17 @@ pip install -r requirements.txt

### Step 2: Train the model

The script `scripts/train_contrast_agnostic.sh` downloads the datasets from git-annex, creates datalists, converts them into nnUNet-specific format, and trains the model. More instructions about what variables to set and which datasets to use can be found in the script itself. Once these variables are set, the script can be run simply as follows:
The script `scripts/train_contrast_agnostic.sh` downloads the datasets from git-annex, creates datalists, converts them into nnUNet-specific format, and trains the model. More instructions about what variables to set and which datasets to use can be found in the script itself. Once these variables are set, run:

```bash
bash scripts/train_contrast_agnostic.sh
```

> [!IMPORTANT]
> The script `train_contrast_agnostic.sh` will NOT run out-of-the-box. User-specific variables, such as the path where the datasets are downloaded and the path to the nnUNet repository, need to be set. Info about which variables to set can be found in the script itself.

> [!IMPORTANT]
> You might need to run the `train_contrast_agnostic.sh` script in a virtual terminal such as `tmux` or `screen`.
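
For example, a long training run can be kept alive after disconnecting from the server roughly as follows (a sketch; the session name is arbitrary, and the user-specific variables mentioned above still need to be set in the script first):

```bash
# start a detachable terminal session and run training inside it,
# then detach with Ctrl-b d (tmux keeps the job running)
tmux new -s contrast_agnostic
bash scripts/train_contrast_agnostic.sh

# later, re-attach to check progress
tmux attach -t contrast_agnostic
```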
<!--
TODO: move to csa_qc_evaluation folder
## 5. Computing morphometric measures (CSA)
26 changes: 17 additions & 9 deletions anima_metrics/compute_anima_metrics_spine_generic.py
@@ -102,11 +102,13 @@ def get_parser():
help='Path to the folder containing nifti images of test predictions AND GTs'
' when method == monai, else path to the xml files containing the pre-computed'
' ANIMA metrics when method == [deepseg*, propseg]')
parser.add_argument('--out-folder', required=False, type=str, default=None,
help='Path to the folder containing the output xml files containing the ANIMA metrics.')
parser.add_argument('-dname', '--dataset-name', required=True, type=str, choices=STANDARD_DATASETS,
help='Dataset name used for storing on git-annex. For region-based metrics, '
'append "-region" to the dataset name. Default: spine-generic')
parser.add_argument('--method', required=True, type=str, default='monai',
choices=['monai', 'synthseg', 'deepseg2d', 'deepseg3d', 'propseg', 'v20', 'v30'],
choices=['monai', 'synthseg', 'deepseg2d', 'deepseg3d', 'propseg', 'v20', 'v30', 'scisegv2'],
help='Segmentation method to compute metrics for. Default: monai')

return parser
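
As a rough usage sketch for the new options: the prediction-folder argument is truncated in this hunk, so `--pred-folder` below is a hypothetical flag name and the paths are placeholders; check `get_parser()` for the actual argument names.

```bash
# Reuse pre-computed ANIMA XML files (via --out-folder) instead of recomputing them,
# and evaluate the new "scisegv2" method.
python anima_metrics/compute_anima_metrics_spine_generic.py \
    --pred-folder <path/to/predictions> \
    --out-folder <path/to/existing/xmls> \
    --dataset-name spine-generic \
    --method scisegv2
```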
@@ -120,10 +122,13 @@ def get_test_metrics_by_dataset(pred_folder, output_folder, anima_binaries_path,
if method == 'v20':
pred_suffix = 'seg_v20' # '_pred.nii.gz'
elif method == 'v30':
pred_suffix = 'seg_nnunet-AllRandInit3D_bin'
# pred_suffix = 'seg_nnunet-AllRandInit3D_bin'
pred_suffix = 'seg_nnunet-AllRandInit3D'
elif method == 'deepseg2d':
pred_suffix = 'seg_deepseg_2d'
gt_suffix = 'softseg_bin' # '_softseg_gt.nii.gz'
elif method == 'scisegv2':
pred_suffix = 'seg_scisegv2'
gt_suffix = "label-SC_seg" #'seg-manual' # 'softseg_bin'
if data_set in STANDARD_DATASETS:
# glob all the predictions and GTs and get the last three digits of the filename
pred_files = sorted(glob.glob(f"{pred_folder}/**/**/*_{pred_suffix}.nii.gz"))
@@ -202,9 +207,13 @@ def main():
print(f"Saving ANIMA performance metrics to {output_folder}")

# Get all XML filepaths where ANIMA performance metrics are saved for each hold-out subject
if method in ['monai', 'synthseg', 'v20', 'deepseg2d', 'v30']:
subject_filepaths = get_test_metrics_by_dataset(pred_folder, output_folder, anima_binaries_path,
data_set=dataset_name, method=method)
if method in ['monai', 'synthseg', 'v20', 'deepseg2d', 'v30', 'scisegv2']:
if args.out_folder:
subject_filepaths = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if f.endswith('.xml')]
else:
print("Computing ANIMA metrics from scratch as no output folder is provided!")
subject_filepaths = get_test_metrics_by_dataset(pred_folder, output_folder, anima_binaries_path,
data_set=dataset_name, method=method)
elif method in ['deepseg3d', 'propseg']:
subject_filepaths = sorted(glob.glob(f"{pred_folder}/*.xml"))
else:
@@ -216,7 +225,7 @@ def main():
# Update the test metrics dictionary by iterating over all subjects
for subject_filepath in subject_filepaths:
subject = os.path.split(subject_filepath)[-1].split('_')[0]
contrast = os.path.split(subject_filepath)[-1].split('_')[1]
contrast = 'T2' #os.path.split(subject_filepath)[-1].split('_')[1] # T2w
root_node = ET.parse(source=subject_filepath).getroot()

# create a dictionary to store the metrics for each subject
@@ -252,8 +261,7 @@ def main():
# get the list of contrasts
contrasts = sorted(list(test_metrics[list(test_metrics.keys())[0]].keys()))
print(f"Contrasts: {contrasts}")
metrics = ['Dice', 'RelativeVolumeError', 'SurfaceDistance', 'HausdorffDistance']

metrics = ['Jaccard', 'Dice', 'RelativeVolumeError', 'SurfaceDistance', 'HausdorffDistance']
metrics_per_contrast = {}
metrics_avg_all = {}

147 changes: 145 additions & 2 deletions csa_generate_figures/analyse_csa_across.py
@@ -12,6 +12,9 @@
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import friedmanchisquare, normaltest, wilcoxon
import scikit_posthocs as sp
from loguru import logger

# Setting the hue order as specified
HUE_ORDER = ["softseg_bin", "deepseg_2d", "v20", "nnunet-AllRandInit3D_bin"]
@@ -33,6 +36,36 @@
'nnunet-AllRandInit3D_bin': '#a6d854',
}

def format_pvalue(p_value, decimal_places=3, include_space=False, include_equal=True):
"""
Format p-value.
If the p-value is lower than alpha, format it to "<0.001", otherwise, round it to three decimals

:param p_value: input p-value as a float
:param decimal_places: number of decimal places the p-value will be rounded
:param include_space: include space or not (e.g., ' = 0.06')
:param include_equal: include equal sign ('=') to the p-value (e.g., '=0.06') or not (e.g., '0.06')
:return: p_value: the formatted p-value (e.g., '<0.05') as a str
"""
if include_space:
space = ' '
else:
space = ''

# If the p-value is lower than alpha, return '<alpha' (e.g., <0.001)
for alpha in [0.001, 0.01, 0.05]:
if p_value < alpha:
p_value = space + "<" + space + str(alpha)
break
# Otherwise (p-value >= 0.05), round it to the number of decimals specified by decimal_places
else:
if include_equal:
p_value = space + '=' + space + str(round(p_value, decimal_places))
else:
p_value = space + str(round(p_value, decimal_places))

return p_value
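
For instance, the helper collapses small p-values to the nearest significance threshold and rounds the rest; a quick sketch of the expected behavior:

```python
# assuming format_pvalue is imported from this module
print(format_pvalue(0.0004))                     # -> '<0.001'
print(format_pvalue(0.03))                       # -> '<0.05'
print(format_pvalue(0.123, include_space=True))  # -> ' = 0.123'
```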


def save_figure(file_path, save_fname):
plt.tight_layout()
@@ -97,7 +130,6 @@ def extract_contrast_and_details(filename, across="Method"):
raise ValueError(f'Unknown analysis type: {across}. Choices: [Method, Resolution, Threshold].')



def generate_figure_std(data, file_path, across="Method", metric="csa", hue_order=HUE_ORDER):
"""
Generate violinplot showing STD across participants for each method
@@ -334,7 +366,6 @@ def generate_figure_csa(file_path, data, method=None, threshold=None):
save_figure(file_path, save_fname)



def generate_figure_csa_all_methods(file_path, data):
"""
Generate a single violinplot showing CSA for each contrast across all methods
@@ -464,10 +495,117 @@ def generate_figure_csa_all_methods(file_path, data):
plt.show()


def compute_statistical_tests(df_avg_csa, exp_type='default'):
"""
Compute statistical tests to compare the STD of CSA across methods.
"""
# Compute mean and std across contrasts for each method
# contains columns: 'Method', 'Participant', 'mean', 'std'
df = df_avg_csa.groupby(['Method', 'Participant'])['MEAN(area)'].agg(['mean', 'std']).reset_index()
# drop method 'softseg_bin'
df = df[df['Method'] != 'softseg_bin'] # we don't want pair-wise comparisons with GT (only methods)

# create df for each method
method_dfs = {}
for method in df['Method'].unique():
method_dfs[method] = df[df['Method'] == method][['Participant', 'std']].set_index('Participant')

# check the length of each method dataframe
lengths = [len(method_dfs[method]) for method in method_dfs]
if len(set(lengths)) != 1:
logger.info("Error: Not all methods have the same number of participants. Cannot perform Friedman test.")
return

# for method in df['Method'].unique():
# if method == 'v20':
# # just a sanity check to see if average of std values matches what we have on STD CSA plot
# # print mean and std of the 'std' column
# mean_std = method_dfs[method]['std'].mean()
# std_std = method_dfs[method]['std'].std()
# print(f'Method: {method}, Mean of STD: {mean_std:.3f}, Std of STD: {std_std:.3f}')

# Check normality of the data for each method using D'Agostino and Pearson's test
for method in method_dfs:
stat, p_value = normaltest(method_dfs[method]['std'])
logger.info(f'Normality test for method {method}: stat={stat}, p-value (formatted): {format_pvalue(p_value)}, p-value: {p_value}')
if p_value < 0.05:
logger.info(f"The data for method {method} is not normally distributed (reject H0). Consider using a non-parametric test.\n")
else:
logger.info(f"The data for method {method} is normally distributed (fail to reject H0)")

if exp_type == 'ablation':

# check if nnunet-AllInferred3D_bin|nnunet-AllRandInit3D_bin are in method_dfs
methods_to_compare = ['nnunet-AllInferred3D_bin', 'nnunet-AllRandInit3D_bin']
for method in methods_to_compare:
if method not in method_dfs:
logger.info(f"Error: Method {method} not found in the data. Cannot perform Wilcoxon signed-rank test.")
return

# NOTE: because we only have 2 related samples (i.e., methods): original model and one trained with recursive GT.
# we use Wilcoxon signed-rank test
stat, p_value = wilcoxon(method_dfs['nnunet-AllInferred3D_bin']['std'], method_dfs['nnunet-AllRandInit3D_bin']['std'])
logger.info(f'Wilcoxon signed-rank test statistic: {stat}, p-value (formatted): {format_pvalue(p_value)}, p-value: {p_value}')
if p_value < 0.05:
logger.info("There is a significant difference between the two methods (reject H0).\n")

else:
# NOTE: Why Friedman? Because we have more than 2 related samples (i.e., methods) and we want to compare their distributions.
# Related (paired) samples because the same participants (i.e. test set) are used for each method.
stat, p_value = friedmanchisquare(*[method_dfs[method]['std'] for method in method_dfs])
logger.info(f'Friedman test statistic: {stat}, p-value (formatted): {format_pvalue(p_value)}, p-value: {p_value}')
if p_value < 0.05:
logger.info("There is a significant difference between the methods (reject H0). Perform post-hoc test to identify which pairs are different.\n")

# reframe the df to have all methods in a column and the metric values in another column
df_friedman = pd.DataFrame()
for method in method_dfs:
temp_df = method_dfs[method].reset_index()
temp_df['Method'] = method
df_friedman = pd.concat([df_friedman, temp_df])
df_friedman = df_friedman.rename(columns={'std': 'Metric'})

# NOTE: As per https://scikit-posthocs.readthedocs.io/en/latest/generated/scikit_posthocs.posthoc_dunn.html
    # Dunn's test is suitable after a Kruskal-Wallis test
# # Perform post-hoc test using Dunn's test with Holm correction
# p_posthoc = sp.posthoc_dunn(df_friedman, val_col='Metric', group_col='Method', p_adjust='holm')
# logger.info("Post-hoc Dunn's test p-values (Holm corrected):")
# logger.info(f"\n{p_posthoc}")

# NOTE: As per https://scikit-posthocs.readthedocs.io/en/latest/generated/scikit_posthocs.posthoc_nemenyi_friedman.html
# there exists a post-hoc test specifically for Friedman test, which is the Nemenyi test
# Perform post-hoc test using nemenyi test
# get a non-melted version of df_friedman (can be used with default setting of scikit-posthocs)
df_wide = df_friedman.pivot(index='Participant', columns='Method', values='Metric')

p_posthoc = sp.posthoc_nemenyi_friedman(df_wide)
logger.info("Post-hoc Nemenyi test p-values:")
logger.info(f"\n{p_posthoc}")

# check if p_posthoc is symmetric
assert (p_posthoc.values == p_posthoc.values.T).all(), "Post-hoc p-value matrix is not symmetric"

# p_posthoc is a square matrix with the p-values for each pair of methods. The diagonal is always 1
# (i.e., method compared to itself). So, the matrix is symmetric (i.e., p-value for method A vs B is the same as B vs A)
# we only need to print the upper triangle of the matrix
logger.info("Significant differences (p < 0.05) between methods:")
for i in range(len(p_posthoc)):
for j in range(i+1, len(p_posthoc)):
method1 = p_posthoc.index[i]
method2 = p_posthoc.columns[j]
p_val = p_posthoc.iloc[i, j]
if p_val < 0.05:
logger.info(f"\t{method1} vs {method2}: p-value (formatted): {format_pvalue(p_val)}, p-value: {p_val}")
else:
logger.info("No significant difference between the methods (fail to reject H0)")


def main(args, analysis_type="methods"):
# Load the CSV file containing averaged (across slices) C2-C3 CSA
data_avg_csa = pd.read_csv(args.i)

logger.add(os.path.join(os.path.dirname(args.i), 'log.txt'), rotation='10 MB', level='INFO')

# Apply the function to extract participant ID
data_avg_csa['Participant'] = data_avg_csa['Filename'].apply(fetch_participant_id)

@@ -483,6 +621,9 @@ def main(args, analysis_type="methods"):
# Generate violinplot showing STD across participants for each method
generate_figure_std(data_avg_csa, file_path=args.i, metric="csa")

# Compute statistical tests
compute_statistical_tests(data_avg_csa, exp_type='ablation' if args.ablation else 'default')

if args.i_dice is not None:
# Generate violinplot showing average slicewise Dice scores across participants for each method

@@ -532,5 +673,7 @@ def main(args, analysis_type="methods"):
help="Path to the CSV file containing averaged slice-wise Dice scores for each contrast, method, and subjects")
parser.add_argument('-a', type=str, default="methods",
help='Options to analyse CSA across. Choices: [methods, resolutions]')
parser.add_argument('--ablation', action='store_true',
                        help='If set, generate STD figures and statistical test results for the recursive/ablation experiments.')
args = parser.parse_args()
main(args, analysis_type=args.a)
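
A hedged example of running the analysis with the new flag (assuming the averaged-CSA CSV is passed via `-i`, as implied by `args.i` above; file names are placeholders):

```bash
# Compare STD of CSA across methods and run the paired statistical tests,
# using the ablation (recursive-GT) comparison path.
python csa_generate_figures/analyse_csa_across.py \
    -i csa_c2c3_all_methods.csv \
    -a methods \
    --ablation
```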
2 changes: 2 additions & 0 deletions datasplits/datasplit_basel-mp2rage_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: basel-mp2rage
dataset_version_commit: 1efa01bc306292bc043f9f6a6ea8c6ed4d6c44fd
test:
- sub-C069
- sub-C090
2 changes: 2 additions & 0 deletions datasplits/datasplit_canproco_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: canproco
dataset_version_commit: a04d89739c769dc03f23fcda183df62c62f586a9
test:
- sub-cal080
- sub-cal085
2 changes: 2 additions & 0 deletions datasplits/datasplit_data-multi-subject_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: data-multi-subject
dataset_version_commit: a0738046538232df8e09eba8d98899eada9c11d5
test:
- sub-barcelona06
- sub-beijingPrisma01
2 changes: 2 additions & 0 deletions datasplits/datasplit_dcm-brno_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: dcm-brno
dataset_version_commit: 3dacde7ee16f0dfc27508fe0bf8f1919cfc7eb4d
test:
- sub-1860B6472B
- sub-2295B4676B
2 changes: 2 additions & 0 deletions datasplits/datasplit_dcm-zurich-lesions-20231115_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: dcm-zurich-lesions-20231115
dataset_version_commit:
test:
- sub-11
- sub-12
2 changes: 2 additions & 0 deletions datasplits/datasplit_dcm-zurich-lesions_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: dcm-zurich-lesions
dataset_version_commit: d214e0603fcd3879317fe0a0b4cd634ee2a92f1d
test:
- sub-09
- sub-16
2 changes: 2 additions & 0 deletions datasplits/datasplit_dcm-zurich_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: dcm-zurich
dataset_version_commit: 83dab50d8138bbc1f8e4f18672e651e988d1e000
test:
- sub-260155
- sub-296085
2 changes: 2 additions & 0 deletions datasplits/datasplit_lumbar-epfl_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: lumbar-epfl
dataset_version_commit: c6685fc4762daea3ec6f184b128b7fe19acad2b8
test:
- sub-05
- sub-11
2 changes: 2 additions & 0 deletions datasplits/datasplit_lumbar-vanderbilt_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: lumbar-vanderbilt
dataset_version_commit: 81fc970a6515ec27d90c0dda5935b5179a10305e
test:
- sub-140549
- sub-242142
2 changes: 2 additions & 0 deletions datasplits/datasplit_sci-colorado_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: sci-colorado
dataset_version_commit: 1518ecd184b8a89bc1a1197eb5ae4caf5c608fb9
test:
- sub-5575
- sub-5629
2 changes: 2 additions & 0 deletions datasplits/datasplit_sci-paris_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: sci-paris
dataset_version_commit: 0a0d252c95e2400038f86e80bde85ffba0ffff0e
test:
- sub-045
- sub-049
2 changes: 2 additions & 0 deletions datasplits/datasplit_sci-zurich_seed50.yaml
@@ -1,3 +1,5 @@
dataset_name: sci-zurich
dataset_version_commit: ac1e679a91e5befac1bcd09ba451daddf2a25d1b
test:
- sub-zh117
- sub-zh119