Skip to content
143 changes: 112 additions & 31 deletions compute_metrics_reloaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
The script is compatible with both binary and multi-class segmentation tasks (e.g., nnunet region-based).
The metrics are computed for each unique label (class) in the reference (ground truth) image.

Authors: Jan Valosek
Authors: Jan Valosek, Naga Karthik
"""


Expand All @@ -41,6 +41,8 @@
import numpy as np
import nibabel as nib
import pandas as pd
import re
from tqdm import tqdm

from MetricsReloaded.metrics.pairwise_measures import BinaryPairwiseMeasures as BPM

Expand Down Expand Up @@ -81,6 +83,8 @@ def get_parser():
'see: https://metricsreloaded.readthedocs.io/en/latest/reference/metrics/metrics.html.')
parser.add_argument('-output', type=str, default='metrics.csv', required=False,
help='Path to the output CSV file to save the metrics. Default: metrics.csv')
parser.add_argument('-mask-type', type=str, default='chunks', required=False,
help='Type of the labels in the images. Options: [chunks, stitched]')

return parser

Expand Down Expand Up @@ -122,25 +126,13 @@ def get_images_in_folder(prediction, reference):
return prediction_files, reference_files


def compute_metrics_single_subject(prediction, reference, metrics):
def compute_metrics_single_subject(prediction_data, reference_data, metrics, metrics_dict):
"""
Compute MetricsReloaded metrics for a single subject
:param prediction: path to the nifti image with the prediction
:param reference: path to the nifti image with the reference (ground truth)
:param prediction: numpy array of the prediction mask
:param reference: numpy array of the reference mask (ground truth)
:param metrics: list of metrics to compute
"""
# load nifti images
print(f'Processing...')
print(f'\tPrediction: {os.path.basename(prediction)}')
print(f'\tReference: {os.path.basename(reference)}')
prediction_data = load_nifti_image(prediction)
reference_data = load_nifti_image(reference)

# check whether the images have the same shape and orientation
if prediction_data.shape != reference_data.shape:
raise ValueError(f'The prediction and reference (ground truth) images must have the same shape. '
f'The prediction image has shape {prediction_data.shape} and the ground truth image has '
f'shape {reference_data.shape}.')

# get all unique labels (classes)
# for example, for nnunet region-based segmentation, spinal cord has label 1, and lesions have label 2
Expand All @@ -152,14 +144,10 @@ def compute_metrics_single_subject(prediction, reference, metrics):
# Get the unique labels that are present in the reference OR prediction images
unique_labels = np.unique(np.concatenate((unique_labels_reference, unique_labels_prediction)))

# append entry into the output_list to store the metrics for the current subject
metrics_dict = {'reference': reference, 'prediction': prediction}

# loop over all unique labels, e.g., voxels with values 1, 2, ...
# by doing this, we can compute metrics for each label separately, e.g., separately for spinal cord and lesions
for label in unique_labels:
# create binary masks for the current label
print(f'\tLabel {label}')
prediction_data_label = np.array(prediction_data == label, dtype=float)
reference_data_label = np.array(reference_data == label, dtype=float)

Expand All @@ -176,7 +164,6 @@ def compute_metrics_single_subject(prediction, reference, metrics):
# Special case when both the reference and prediction images are empty
else:
label = 1
print(f'\tLabel {label} -- both the reference and prediction are empty')
bpm = BPM(prediction_data, reference_data, measures=metrics)
dict_seg = bpm.to_dict_meas()

Expand Down Expand Up @@ -216,8 +203,22 @@ def build_output_dataframe(output_list):
return df


def main():
def find_subject_session_chunk_in_path(path):
"""
Extracts subject and session identifiers from the given path.
:param path: Input path containing subject and session identifiers.
:return: Extracted subject and session identifiers or None if not found.
"""
# pattern = r'.*_(sub-m\d{6})_(ses-\d{8}).*_(chunk-\d{1})_.*'
pattern = r'.*_(sub-m\d{6}_ses-\d{8}).*_(chunk-\d{1})_.*'
match = re.search(pattern, path)
if match:
return match.group(1), match.group(2)
else:
return None, None, None


def main():
# parse command line arguments
parser = get_parser()
args = parser.parse_args()
Expand All @@ -232,26 +233,106 @@ def main():
if os.path.isdir(args.prediction) and os.path.isdir(args.reference):
# Get all files in the directories
prediction_files, reference_files = get_images_in_folder(args.prediction, args.reference)
# Loop over the subjects
for i in range(len(prediction_files)):
# Compute metrics for each subject
metrics_dict = compute_metrics_single_subject(prediction_files[i], reference_files[i], args.metrics)
# Append the output dictionary (representing a single reference-prediction pair per subject) to the
# output_list
output_list.append(metrics_dict)

if args.mask_type == 'chunks':

# get the subject, session, and chunk identifiers from the path
subjects_sessions = [find_subject_session_chunk_in_path(f)[0] for f in prediction_files if find_subject_session_chunk_in_path(f)]
subjects_sessions = list(set(subjects_sessions))

for sub_ses in tqdm(subjects_sessions, desc='Computing metrics for each subject'):
preds_per_sub_ses = [f for f in prediction_files if sub_ses in f]
refs_per_sub_ses = [f for f in reference_files if sub_ses in f]

preds_stack, refs_stack = [], []
for pred, ref in zip(preds_per_sub_ses, refs_per_sub_ses):
# load nifti images
prediction_data = load_nifti_image(pred)
reference_data = load_nifti_image(ref)

# check whether the images have the same shape and orientation
if prediction_data.shape != reference_data.shape:
raise ValueError(f'The prediction and reference (ground truth) images must have the same shape. '
f'The prediction image has shape {prediction_data.shape} and the ground truth image has '
f'shape {reference_data.shape}.')

preds_stack.append(prediction_data)
refs_stack.append(reference_data)

# min_shape = np.min([pred.shape for pred in preds_stack], axis=0)
max_shape = np.max([pred.shape for pred in preds_stack], axis=0)
max_shape_ref = np.max([ref.shape for ref in refs_stack], axis=0)

assert max_shape[0] == max_shape_ref[0], "The images must have the same shape at dim[0]"
assert max_shape[1] == max_shape_ref[1], "The images must have the same shape at dim[1]"
assert max_shape[2] == max_shape_ref[2], "The images must have the same shape at dim[2]"

# pad the images to the same shape
preds_stack = [np.pad(pred, ((0, max_shape[0] - pred.shape[0]), (0, max_shape[1] - pred.shape[1]), (0, max_shape[2] - pred.shape[2]))) for pred in preds_stack]
refs_stack = [np.pad(ref, ((0, max_shape[0] - ref.shape[0]), (0, max_shape[1] - ref.shape[1]), (0, max_shape[2] - ref.shape[2]))) for ref in refs_stack]

# stack the images
preds_stacked = np.stack(preds_stack, axis=-1)
refs_stacked = np.stack(refs_stack, axis=-1)

# create a new file name for reference and prediction
pred_fname = os.path.join(os.path.dirname(preds_per_sub_ses[0]), f'{sub_ses}_preds_stack.nii.gz')
ref_fname = os.path.join(os.path.dirname(refs_per_sub_ses[0]), f'{sub_ses}_refs_stack.nii.gz')

metrics_dict = {'reference': ref_fname, 'prediction': pred_fname}
# Compute metrics for each subject
metrics_dict = compute_metrics_single_subject(preds_stacked, refs_stacked, args.metrics, metrics_dict)

# append the dictionary to the output list
output_list.append(metrics_dict)

elif args.mask_type == 'stitched':
# Loop over the subjects
for i in tqdm(range(len(reference_files)), desc='Computing metrics for each subject'):

# Load nifti images
prediction_data = load_nifti_image(prediction_files[i])
reference_data = load_nifti_image(reference_files[i])

# append entry into the output_list to store the metrics for the current subject
metrics_dict = {'reference': reference_files[i], 'prediction': prediction_files[i]}
# Compute metrics for each subject
metrics_dict = compute_metrics_single_subject(prediction_data, reference_data, args.metrics, metrics_dict)

# Append the output dictionary (representing a single reference-prediction pair per subject) to the
# output_list
output_list.append(metrics_dict)

# Args.prediction and args.reference are paths nii.gz files from a SINGLE subject
else:
metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics)
# Load nifti images
prediction_data = load_nifti_image(args.prediction)
reference_data = load_nifti_image(args.reference)

metrics_dict = {'reference': args.reference, 'prediction': args.prediction}
metrics_dict = compute_metrics_single_subject(prediction_data, reference_data, args.metrics, metrics_dict)

# Append the output dictionary (representing a single reference-prediction pair per subject) to the output_list
output_list.append(metrics_dict)

# Convert JSON data to pandas DataFrame
df = build_output_dataframe(output_list)

# create a separate dataframe for columns where EmptyRef and EmptyPred is True
df_empty_masks = df[(df['EmptyRef'] == True) & (df['EmptyPred'] == True)]

# keep only the rows where either pred or ref is non-empty or both are non-empty
df = df[(df['EmptyRef'] == False) | (df['EmptyPred'] == False)]

# Compute mean and standard deviation of metrics across all subjects
df_mean = (df.drop(columns=['reference', 'prediction', 'EmptyRef', 'EmptyPred']).groupby('label').
agg(['mean', 'std']).reset_index())

# Convert multi-index to flat index
df_mean.columns = ['_'.join(col).strip() for col in df_mean.columns.values]
# Rename column `label_` back to `label`
df_mean.rename(columns={'label_': 'label'}, inplace=True)

# Rename columns
df.rename(columns={metric: METRICS_TO_NAME[metric] for metric in METRICS_TO_NAME}, inplace=True)
df_mean.rename(columns={metric: METRICS_TO_NAME[metric] for metric in METRICS_TO_NAME}, inplace=True)
Expand All @@ -272,4 +353,4 @@ def main():


if __name__ == '__main__':
main()
main()