225 changes: 225 additions & 0 deletions dataset_aggregation/agregate_unannotated_data.py
@@ -0,0 +1,225 @@
"""
This script agregrates data from multiple datasets into a json file.
The datasets are umass* (4 datasets), ms-nmo-beijing and ms-mayo-critical-lesions.
For inspiration, I used: https://github.com/ivadomed/ms-lesion-agnostic/blob/r20250626/dataset_analysis/msd_data_analysis.py

Input:
-data: path to the folder containing the datasets
-output: path to the output json file
-exclude-mayo: path to the file containing the list of subjects to exclude from the mayo dataset
Output:
None

Example:
python agregate_unannotated_data.py -data /path/to/data -output /path/to/output.json -exclude-mayo /path/to/exclude_mayo.yml

Author: Pierre-Louis Benveniste
"""
import os
import json
import argparse
from pathlib import Path
import yaml
from utils.image import Image
import numpy as np
from tqdm import tqdm


def parse_args():
parser = argparse.ArgumentParser(description="Aggregate unannotated data from multiple datasets into a json file.")
parser.add_argument("-data", type=str, required=True, help="Path to the folder containing the datasets.")
parser.add_argument("-output", type=str, required=True, help="Path to the output json file.")
parser.add_argument("-exclude-mayo", type=str, required=True, help="Path to the file containing the list of subjects to exclude from the mayo dataset.")
return parser.parse_args()


def get_acquisition_resolution_and_dimension(image_path, site):
"""
This function takes an image file as input and returns its acquisition, resolution and dimension.

Input:
image_path : str : Path to the image file

Returns:
acquisition : str : Acquisition of the image
orientation : str : Orientation of the image
resolution : list : Resolution of the image
dimension : list : Dimension of the image
field_strength : str : Field strength of the image
manufacturer : str : Manufacturer of the image
"""
    img = Image(str(image_path))
    img.change_orientation('RPI')
    # In the Image class, dim = (nx, ny, nz, nt, px, py, pz, pt):
    # indices 0-2 are the image dimensions (voxel counts) and indices 4-6
    # are the voxel sizes in mm.
    # Get the resolution
    resolution = list(img.dim[4:7])
    # Get image dimension
    dimension = list(img.dim[0:3])

    # Default values, overwritten below when the information is available
    acquisition = 'Missing'
    orientation = 'Missing'
    field_strength = 'Missing'
    manufacturer = 'Missing'

    # Infer orientation and acquisition from the file name
    image_name = str(image_path).split('/')[-1]
    if 'ax' in image_name:
        orientation = 'ax'
    elif 'sag' in image_name:
        orientation = 'sag'
    if '3D' in image_name:
        acquisition = '3D'
    # The mayo dataset images are treated as 2D axial by default
    if "mayo" in site:
        acquisition = '2D'
        orientation = 'ax'

    # If there is a json sidecar, it takes precedence over the file name
    json_path = str(image_path).replace('.nii.gz', '.json')
    if os.path.exists(json_path):
        with open(json_path, 'r') as f:
            json_data = json.load(f)
        series_description = json_data.get('SeriesDescription') or ''
        if 'sag' in series_description.lower():
            orientation = 'sag'
        elif 'ax' in series_description.lower():
            orientation = 'ax'
        # Fall back to the values inferred above when a field is absent
        acquisition = json_data.get('MRAcquisitionType') or acquisition
        field_strength = json_data.get('MagneticFieldStrength') or field_strength
        manufacturer = json_data.get('Manufacturer') or manufacturer

    return acquisition, orientation, resolution, dimension, field_strength, manufacturer
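
# For reference, a json sidecar containing the fields read above might look
# like this (illustrative values, not taken from any real subject):
#
# {
#     "SeriesDescription": "t2_tse_sag",
#     "MRAcquisitionType": "2D",
#     "MagneticFieldStrength": 3,
#     "Manufacturer": "Siemens"
# }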


def main():
# Parse arguments
args = parse_args()
data_path = Path(args.data)
output_path = Path(args.output)
mayo_exclude = Path(args.exclude_mayo)

# If output directory does not exist, create it
os.makedirs(output_path, exist_ok=True)

# Load the exclude file from the mayo dataset
with open(mayo_exclude, 'r') as file:
mayo_exclude = yaml.load(file, Loader=yaml.FullLoader)
mayo_exclude = mayo_exclude['slice_motion'] + mayo_exclude['intensity_spikes'] + mayo_exclude['contrast_issues']
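
    # For reference, the exclude file is expected to be a yaml file with the
    # three keys used above, each mapping to a list of file names, e.g.
    # (illustrative names):
    #
    # slice_motion:
    #     - sub-XXX_T2w.nii.gz
    # intensity_spikes:
    #     - sub-YYY_T2w.nii.gz
    # contrast_issues:
    #     - sub-ZZZ_T2w.nii.gz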

# List dataset paths
mayo_path = os.path.join(data_path, "ms-mayo-critical-lesions") # It contains T2w images
beijing_path = os.path.join(data_path, "ms-nmo-beijing") # It contains T1w images
path_umass_1 = os.path.join(data_path, 'umass-ms-ge-hdxt1.5')
path_umass_2 = os.path.join(data_path, 'umass-ms-ge-pioneer3')
path_umass_3 = os.path.join(data_path, 'umass-ms-siemens-espree1.5')
path_umass_4 = os.path.join(data_path, 'umass-ms-ge-excite1.5')

# Aggregate data
## MS-MAYO
imgs_mayo = list(Path(mayo_path).rglob('*_T2w.nii.gz'))
imgs_mayo = [i for i in imgs_mayo if 'derivatives' not in str(i)]
imgs_mayo = [i for i in imgs_mayo if str(i).split('/')[-1] not in mayo_exclude]
imgs_mayo = [str(i) for i in imgs_mayo]
print(f"Number of images in mayo dataset: {len(imgs_mayo)}")

    ## MS-NMO-BEIJING
    imgs_beijing = list(Path(beijing_path).rglob('*acq-sag_*T1w.nii.gz'))  # add acq-sag T1w images (the wildcard accounts for possible run-XX entities)
    imgs_beijing += list(Path(beijing_path).rglob('*axTseRst_*T2w.nii.gz'))  # add axTseRst T2w images (same wildcard for possible runs)
    imgs_beijing += list(Path(beijing_path).rglob('*sagTseRst_*T2w.nii.gz'))  # add sagTseRst T2w images (same wildcard for possible runs)
    imgs_beijing = [i for i in imgs_beijing if 'ocalizer' not in str(i)]  # remove localizer images (matches both 'Localizer' and 'localizer')
    imgs_beijing = [i for i in imgs_beijing if 'sub-MS' in str(i)]  # keep only MS subjects
    imgs_beijing = [str(i) for i in imgs_beijing]
    print(f"Number of images in beijing dataset: {len(imgs_beijing)}")

    ## UMASS 1
    imgs_umass_1 = list(Path(path_umass_1).rglob('*_T1w.nii.gz'))  # T1w images
    imgs_umass_1 = [i for i in imgs_umass_1 if '_acq-' not in str(i)]  # remove images with an acq-... entity (specific ones are re-added below)
    imgs_umass_1 = [i for i in imgs_umass_1 if 'ce-gad' not in str(i)]  # remove contrast-enhanced images
    imgs_umass_1 += list(Path(path_umass_1).rglob('*acq-FMPIR_T2w.nii.gz'))  # add acq-FMPIR T2w images
    imgs_umass_1 += list(Path(path_umass_1).rglob('*acq-ax_T1w.nii.gz'))  # add acq-ax T1w images
    imgs_umass_1 += list(Path(path_umass_1).rglob('*acq-ax_T2w.nii.gz'))  # add acq-ax T2w images
    imgs_umass_1 = [i for i in imgs_umass_1 if 'derivatives' not in str(i)]  # remove derivatives
    imgs_umass_1 = [str(i) for i in imgs_umass_1]
    print(f"Number of images in umass_1 dataset: {len(imgs_umass_1)}")

    ## UMASS 2
    imgs_umass_2 = list(Path(path_umass_2).rglob('*_T1w.nii.gz'))
    imgs_umass_2 = [i for i in imgs_umass_2 if '_ce-gad' not in str(i)]  # remove contrast-enhanced images
    imgs_umass_2 = [i for i in imgs_umass_2 if 'acq-3D' not in str(i)]  # remove acq-3D images (re-added with the exact suffix below)
    imgs_umass_2 += list(Path(path_umass_2).rglob('*acq-3D_T1w.nii.gz'))  # add acq-3D T1w images
    imgs_umass_2 += list(Path(path_umass_2).rglob('*acq-STIR_T2w.nii.gz'))  # add acq-STIR T2w images
    imgs_umass_2 += list(Path(path_umass_2).rglob('*acq-axial_T2w.nii.gz'))  # add acq-axial T2w images
    imgs_umass_2 = [i for i in imgs_umass_2 if 'derivatives' not in str(i)]  # remove derivatives
    imgs_umass_2 = [i for i in imgs_umass_2 if 'SHA256' not in str(i)]  # remove files with SHA256 in their path
    imgs_umass_2 = [str(i) for i in imgs_umass_2]
    print(f"Number of images in umass_2 dataset: {len(imgs_umass_2)}")


    ## UMASS 3
    imgs_umass_3 = list(Path(path_umass_3).rglob('*T1w.nii.gz'))
    imgs_umass_3 = [i for i in imgs_umass_3 if '_ce-gad' not in str(i)]  # remove contrast-enhanced images
    imgs_umass_3 += list(Path(path_umass_3).rglob('*T2w.nii.gz'))  # add all T2w images
    imgs_umass_3 = [i for i in imgs_umass_3 if 'acq-STIR' not in str(i) and 'acq-ax' not in str(i)]  # remove acq-STIR and acq-ax images (re-added with the exact suffixes below)
    imgs_umass_3 += list(Path(path_umass_3).rglob('*acq-STIR_T2w.nii.gz'))  # add acq-STIR T2w images
    imgs_umass_3 += list(Path(path_umass_3).rglob('*acq-ax_T2w.nii.gz'))  # add acq-ax T2w images
    imgs_umass_3 = [i for i in imgs_umass_3 if 'derivatives' not in str(i)]  # remove derivatives
    imgs_umass_3 = [i for i in imgs_umass_3 if 'SHA256' not in str(i)]  # remove files with SHA256 in their path
    imgs_umass_3 = [str(i) for i in imgs_umass_3]
    print(f"Number of images in umass_3 dataset: {len(imgs_umass_3)}")

    ## UMASS 4
    imgs_umass_4 = list(Path(path_umass_4).rglob('*T1w.nii.gz'))
    imgs_umass_4 = [i for i in imgs_umass_4 if '_ce-gad' not in str(i)]  # remove contrast-enhanced images
    imgs_umass_4 += list(Path(path_umass_4).rglob('*T2w.nii.gz'))  # add all T2w images
    imgs_umass_4 = [i for i in imgs_umass_4 if 'acq-STIR' not in str(i) and 'acq-ax' not in str(i)]  # remove acq-STIR and acq-ax images (re-added with the exact suffixes below)
    imgs_umass_4 += list(Path(path_umass_4).rglob('*acq-STIR_T2w.nii.gz'))  # add acq-STIR T2w images
    imgs_umass_4 += list(Path(path_umass_4).rglob('*acq-ax_T2w.nii.gz'))  # add acq-ax T2w images
    imgs_umass_4 = [i for i in imgs_umass_4 if 'SHA256' not in str(i)]  # remove files with SHA256 in their path
    imgs_umass_4 = [str(i) for i in imgs_umass_4]
    print(f"Number of images in umass_4 dataset: {len(imgs_umass_4)}")

# Aggregate all images:
all_imgs = imgs_mayo + imgs_beijing + imgs_umass_1 + imgs_umass_2 + imgs_umass_3 + imgs_umass_4

    # Now we iterate over all images to create a dictionary with the required information
    data_dict = {}
    for img in tqdm(all_imgs):
        # Get the subject ID:
        subject_id = img.split('/')[-1].split('_')[0]
        # Get site ID (name of the dataset folder directly under the data folder):
        site = os.path.relpath(img, data_path).split(os.sep)[0]
        # Contrast: taken from the BIDS suffix of the file name
        contrast = img.split('_')[-1].replace('.nii.gz', '')
        # Except for some particular cases:
        if 'TseRst' in img:
            contrast = 'TseRst_T2w'
        elif 'FMPIR_T2w' in img:
            contrast = 'FMPIR'
        elif 'STIR_T2w' in img:
            contrast = 'STIR'
        acquisition, orientation, resolution, dimension, field_strength, manufacturer = get_acquisition_resolution_and_dimension(img, site)
        resolution = [np.float64(i) for i in resolution]  # cast so the values are json-serializable
img_info = {
"path": img,
"subject_id": subject_id,
"site": site,
"contrast": contrast,
"acquisition": acquisition,
"orientation": orientation,
"resolution": resolution,
"dimension": dimension,
"field_strength": field_strength,
"manufacturer": manufacturer
}

# add to the dictionary
data_dict[img] = img_info
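
    # For reference, one entry of the resulting dictionary might look like this
    # (illustrative values, not taken from any real subject):
    #
    # "/path/to/data/ms-nmo-beijing/sub-MS001/anat/sub-MS001_acq-sag_T1w.nii.gz": {
    #     "path": "/path/to/data/ms-nmo-beijing/sub-MS001/anat/sub-MS001_acq-sag_T1w.nii.gz",
    #     "subject_id": "sub-MS001",
    #     "site": "ms-nmo-beijing",
    #     "contrast": "T1w",
    #     "acquisition": "2D",
    #     "orientation": "sag",
    #     "resolution": [0.5, 0.5, 3.0],
    #     "dimension": [320, 320, 15],
    #     "field_strength": 3,
    #     "manufacturer": "Siemens"
    # }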

    # For each field, print the set of unique values
    print(f"Unique contrasts: {set([data_dict[k]['contrast'] for k in data_dict])}")
    print(f"Unique acquisitions: {set([data_dict[k]['acquisition'] for k in data_dict])}")
    print(f"Unique orientations: {set([data_dict[k]['orientation'] for k in data_dict])}")
    print(f"Unique field strengths: {set([data_dict[k]['field_strength'] for k in data_dict])}")

# save the dictionary as a json file
json_file_path = os.path.join(output_path, 'unannotated_data.json')
with open(json_file_path, 'w') as f:
json.dump(data_dict, f, indent=4)
print(f"Data saved to {json_file_path}")

if __name__ == "__main__":
main()
151 changes: 151 additions & 0 deletions dataset_analysis/analyse_unannotated_data.py
@@ -0,0 +1,151 @@
"""
This file was created to analyze the msd dataset used for training and testing our dataset.
It takes as input the msd dataset and analysis the properties of the dataset.

Input:
--msd-data-path: path to the msd dataset in json format
--output-folder: path to the output folder where the analysis will be saved

Output:
None

Example:
python dataset_analysis/msd_data_analysis.py --msd-data-path /path/to/msd/data --output-folder /path/to/output/folder

Author: Pierre-Louis Benveniste
"""

import argparse
import os
import json
from pathlib import Path
from loguru import logger
import pandas as pd

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--msd-data-path', type=str, required=True, help='Path to the MSD dataset')
parser.add_argument('--output-folder', type=str, required=True, help='Path to the output folder')
return parser.parse_args()


def main():
# Parse arguments
args = parse_args()
msd_data_path = args.msd_data_path
output_folder = args.output_folder

# Build the output folder
os.makedirs(output_folder, exist_ok=True)

# Load the dataset
with open(msd_data_path, 'r') as f:
data = json.load(f)

# Create the logger file
log_file = os.path.join(output_folder, f'{Path(msd_data_path).name.split(".json")[0]}_analysis.txt')
# Clear the log file
with open(log_file, 'w') as f:
f.write('')
logger.add(log_file)

# Log some basic stuff
logger.info(f"MSD dataset: {Path(msd_data_path)}")
logger.info(f"Number of images: {len(data)}")

# Count the number of images per contrast
contrast_count = {}
for image in data:
image = data[image]
contrast = image['contrast']
if contrast not in contrast_count:
contrast_count[contrast] = 0
contrast_count[contrast] += 1
logger.info(f"Number of images per contrast: {contrast_count}")

# Count the number of images per site
site_count = {}
for image in data:
image = data[image]
site = image['site']
if site not in site_count:
site_count[site] = 0
site_count[site] += 1
logger.info(f"Number of images per site: {site_count}")

# We also count the number of subjects per site
subjects_per_site = {}
for image in data:
image = data[image]
dataset = image['site']
sub = image['subject_id']
subject = dataset + '/' + sub
if dataset not in subjects_per_site:
subjects_per_site[dataset] = set()
subjects_per_site[dataset].add(subject)
# Convert the sets to counts
for site in subjects_per_site:
subjects_per_site[site] = len(subjects_per_site[site])
logger.info(f"\n Number of subjects per site: {subjects_per_site}")

    # Build a DataFrame with one row per image (rows are collected in a list
    # and the DataFrame is created in a single pass, which avoids concatenating
    # inside the loop)
    rows = []
    for image in data:
        image = data[image]
        dataset = image['site']
        contrast = image['contrast']
        acquisition = image['acquisition']
        orientation = image['orientation']
        resolution = image['resolution']
        field_strength = image['field_strength']
        if field_strength is None:
            field_strength = 'Missing'  # groupby silently drops rows with a None key
        # Add the row for this image
        rows.append({
            'Site': dataset,
            'Contrast': contrast,
            'Acquisition': acquisition,
            'Orientation': orientation,
            'Count': 1,
            'Avg resolution (R-L)': resolution[0],
            'Std resolution (R-L)': resolution[0],
            'Avg resolution (P-A)': resolution[1],
            'Std resolution (P-A)': resolution[1],
            'Avg resolution (I-S)': resolution[2],
            'Std resolution (I-S)': resolution[2],
            'Number of subjects': dataset + '/' + image['subject_id'],  # per-image subject ID, counted with 'nunique' below
            'Field strength': field_strength
        })
    df = pd.DataFrame(rows)
    # Group the DataFrame by Site, Contrast, Acquisition, Orientation and Field strength;
    # sum the Count, average the resolutions (mean and std) and count the unique subjects
df_grouped = df.groupby(['Site', 'Contrast', 'Acquisition', 'Orientation', 'Field strength']).agg({
'Count': 'sum',
'Avg resolution (R-L)': 'mean',
'Std resolution (R-L)': 'std',
'Avg resolution (P-A)': 'mean',
'Std resolution (P-A)': 'std',
'Avg resolution (I-S)': 'mean',
'Std resolution (I-S)': 'std',
'Number of subjects': 'nunique'
})
# Reset the index
df_grouped = df_grouped.reset_index()
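
    # For reference, one grouped row might look like this (illustrative values):
    # Site=ms-nmo-beijing, Contrast=T1w, Acquisition=2D, Orientation=sag,
    # Field strength=3, Count=42, Avg resolution (R-L)=0.5, Std resolution (R-L)=0.1,
    # ..., Number of subjects=30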
# We add the number of subjects per site
subjects_per_site_series = pd.Series(subjects_per_site, name='# Participants')
df_grouped = df_grouped.merge(subjects_per_site_series, left_on='Site', right_index=True, how='left')
# Reorder the columns
df_grouped = df_grouped[['Site','# Participants','Field strength', 'Contrast', 'Acquisition', 'Orientation', 'Avg resolution (R-L)', 'Std resolution (R-L)', 'Avg resolution (P-A)', 'Std resolution (P-A)', 'Avg resolution (I-S)', 'Std resolution (I-S)', 'Count']]
# Log the DataFrame
logger.info("DataFrame with the number of images per site, contrast, acquisition, orientation and field strength:")
logger.info(df_grouped.to_string(index=False))

# Also save the DataFrame to a csv file
csv_file = os.path.join(output_folder, 'csv_file.csv')
df_grouped.to_csv(csv_file, index=False)

return None


if __name__ == '__main__':
main()