Draft
100 commits
75a5f85
Add notebook debugging SSL pretraining transformations
valosekj Mar 8, 2024
e37ff59
Fix typo
valosekj Mar 8, 2024
80fd7db
Add script for Self-Supervised Pre-training using Vision Transformer …
valosekj Mar 8, 2024
4dc7e81
Fix typo
valosekj Mar 8, 2024
3a9baef
Do sanity dimension check after transformations
valosekj Mar 8, 2024
4517f7a
Use `patch_size=(16, 16, 16)` and `batch_size = 2`
valosekj Mar 8, 2024
a621a9a
rename 'crop_size' to 'spatial_size'
valosekj Mar 9, 2024
516f1ec
use 'contrast-agnostic' data augmentation for Training of the fine-tu…
valosekj Mar 9, 2024
fd8f943
Add script to finetune the 3D Single-Class Spinal Cord Lesion Segment…
valosekj Mar 9, 2024
d1b9c3d
update description
valosekj Mar 9, 2024
45ef357
use 'logger.info' instead of 'print'
valosekj Mar 9, 2024
0955689
make '--data' arg notrequired
valosekj Mar 9, 2024
9ea8bca
explicitly specify CUDA GPU number
valosekj Mar 9, 2024
3bc999a
change sliding_window_inference roi_size and batch_size
valosekj Mar 9, 2024
8e8a977
remove 'to_onehot_y=True' and 'softmax=True' because we work with sin…
valosekj Mar 9, 2024
42394e8
add notebook to debug transformations when using SC masks to crop pat…
valosekj Mar 12, 2024
622c3f6
rerun the notebook to show different slices
valosekj Mar 12, 2024
59f4db0
Add 'keys' arg
valosekj Mar 12, 2024
a3ae050
Add 'keys' arg
valosekj Mar 12, 2024
e1cfad1
Update 'RandCoarseDropoutd' params
valosekj Mar 12, 2024
d4e9995
Update comments
valosekj Mar 12, 2024
4c35d42
Fix imports
valosekj Mar 12, 2024
05ab59a
Change batch_size and NUM_WORKERS to 4
valosekj Mar 12, 2024
0b12c10
Make '--data' non required
valosekj Mar 12, 2024
f5881aa
change 'num_workers' to 0 to prevent 'RuntimeError: received 0 items …
valosekj Mar 12, 2024
7aa9858
Update input arg description
valosekj Mar 13, 2024
16691d7
Rerun the notebook
valosekj Mar 13, 2024
56ca187
Print hyper-parameters into the log file
valosekj Mar 13, 2024
ce7315a
Add '--cuda' input arg
valosekj Mar 13, 2024
1ebf833
track and save epoch time
valosekj Mar 15, 2024
208b2c1
Plot and save input and output validation images to see how the model…
valosekj Mar 15, 2024
5a9e69a
Add comment
valosekj Mar 15, 2024
eca6616
Do not plot 'outputs_v2' as it is a hidden representation
valosekj Mar 15, 2024
8e55985
Include the epoch number as master title
valosekj Mar 15, 2024
e7575c5
Add 'torch.multiprocessing.set_sharing_strategy('file_system')' to so…
valosekj Mar 15, 2024
3b6df14
Create validation_figures directory if it does not exist
valosekj Mar 15, 2024
59b061c
Use 3 leading zeros for the epoch number in the figures fname
valosekj Mar 15, 2024
f2d648b
Link issue
valosekj Mar 15, 2024
a6bc6ec
Add note for 'RandCropByPosNegLabeld'
valosekj Mar 15, 2024
c1553d6
Add 'number_of_holes' arg to specify the number of holes to be used f…
valosekj Mar 15, 2024
0e802df
typo
valosekj Mar 16, 2024
9255855
batch_size = 8
valosekj Mar 16, 2024
8296913
NUM_WORKERS = batch_size
valosekj Mar 16, 2024
93c5415
number_of_holes=5
valosekj Mar 16, 2024
a4182b5
Update transforms for training of the fine-tuned model
valosekj Mar 16, 2024
2573578
update comment, remove unused imports
valosekj Mar 16, 2024
cfb8f12
Add notebook with RandCoarseDropoutd transform debug
valosekj Mar 16, 2024
7f27340
Fix 'dropout_holes=True' and 'dropout_holes=False' comments
valosekj Mar 17, 2024
74b2f2c
remove unused 'max_spatial_size' arg
valosekj Mar 17, 2024
bdb2ac1
use 'fill_value=0' for 'RandCoarseDropoutd'
valosekj Mar 17, 2024
f7f3689
Plot also RandCoarseDropoutd dropout_holes=False fill_value=0
valosekj Mar 17, 2024
af9c6d0
Add note that the batch size is actually doubled (8*2=16), because we…
valosekj Mar 17, 2024
b416135
Add '--cuda' input arg
valosekj Mar 18, 2024
0303776
Remove 'AsDiscrete'
valosekj Mar 18, 2024
4534c57
Remove 'AsDiscrete'
valosekj Mar 18, 2024
94c8611
Add 'CUDA_NUM=args.cuda'
valosekj Mar 18, 2024
5fb325a
Add TODO to increase batch_size to 16
valosekj Mar 18, 2024
d1a03c8
Use 'roi_size' for 'define_finetune_train_transforms'
valosekj Mar 18, 2024
e3d8086
Use 'label_sc' to crop samples around the SC
valosekj Mar 18, 2024
7e628af
batch_size = 8
valosekj Mar 18, 2024
f832c94
NUM_WORKERS = batch_size
valosekj Mar 18, 2024
c68d65d
Add 'import torch.multiprocessing'
valosekj Mar 18, 2024
a7c7c77
Fix shape logging
valosekj Mar 18, 2024
1c6109c
Change 'img_size' to 'ROI_SIZE'
valosekj Mar 18, 2024
b92e7d5
'batch["label"]' --> 'batch["label_lesion"]'
valosekj Mar 18, 2024
f174fb9
Plot and save input and output validation images to see how the model…
valosekj Mar 18, 2024
532e24b
Fix ROI_SIZE for sliding_window_inference
valosekj Mar 18, 2024
5c10890
Crop samples of 64x64x64 also for Validation of the fine-tuned model
valosekj Mar 18, 2024
d33abfa
update docstring
valosekj Mar 18, 2024
efeccf9
log validation samples shapes
valosekj Mar 18, 2024
f34ece1
Save validation figure only if it contains a lesion
valosekj Mar 18, 2024
e0afafe
Plot GT together with image
valosekj Mar 19, 2024
e588a22
print unique values in the slice to see if it is binary
valosekj Mar 19, 2024
2d0b530
update output fig fname
valosekj Mar 19, 2024
988680e
Add 'AsDiscreted' for Training and Validation of the fine-tuned model
valosekj Mar 19, 2024
f675611
threshold val_labels_list and val_outputs_list by 0.5 threshold befor…
valosekj Mar 19, 2024
cc78896
add normalized relu normalization
naga-karthik Mar 19, 2024
92334ea
fix binarization bug
naga-karthik Mar 19, 2024
2c49ca2
remove 'logger.info(np.unique(output.detach().cpu().numpy()))'
valosekj Mar 19, 2024
fbe3dd7
overlay prediction over input image
valosekj Mar 19, 2024
ba39d28
Fix variable when getting probabilities from logits
valosekj Mar 20, 2024
c8d9a5d
Add debug lines
valosekj Mar 20, 2024
ac040a5
PEP8
valosekj Mar 20, 2024
1e2b61d
Set validation batch_size to 1
valosekj Mar 20, 2024
8f49545
improve comments
valosekj Mar 20, 2024
3a74e28
fix figure title
valosekj Mar 20, 2024
dea121c
comment 'AsDiscreted' transforms
valosekj Mar 20, 2024
013ca65
Make '--pretrained-model' non required to allow training from the scr…
valosekj Mar 21, 2024
f1aae1b
Add script to create spine-generic MSD dataset
valosekj Mar 22, 2024
e4c87be
run notebook again
valosekj Mar 22, 2024
4448356
Make 'create_msd_data.py' compatible with other BIDS datasets
valosekj Mar 23, 2024
a354a21
Add note that no testing set is created
valosekj Mar 23, 2024
141c96b
Add README with instructions on how to download T2w images from multi…
valosekj Mar 24, 2024
7f8e6bb
fix typo
valosekj Mar 24, 2024
46f1bd9
remove unused imports
valosekj Mar 24, 2024
afeb7c5
fix sc suffix for sci-paris
valosekj Mar 24, 2024
da1c131
Use os.path.abspath for 'args.path_data' and 'args.path_out'
valosekj Mar 24, 2024
ed42c68
Update logging message
valosekj Mar 24, 2024
7f28fed
Fix '.replace' to prevent '//' in the output string
valosekj Mar 24, 2024
99a936b
Add 'create_msd_data.py' commands to create MSD-style JSON datalists
valosekj Mar 24, 2024
70 changes: 70 additions & 0 deletions vit_unetr_ssl/README_multiple_datasets_pretraining.md
@@ -0,0 +1,70 @@
# Pre-training on multiple datasets

### Download datasets

Download T2w images and spinal cord segmentations for the following datasets.

```commandline
cd ~/duke/temp/janvalosek/ssl_pretraining_multiple_datasets
```

`spine-generic multi-subject` (n=267)

```commandline
git clone https://github.com/spine-generic/data-multi-subject
cd data-multi-subject
git checkout sb/156-add-preprocessed-images
git annex get $(find . -name "*space-other_T2w.nii.gz")
git annex get $(find . -name "*space-other_T2w_label-SC_seg.nii.gz")
```


`canproco` (n=413)

```commandline
git clone [email protected]:datasets/canproco
cd canproco
git annex dead here
git annex get $(find . -name "*ses-M0_T2w.nii.gz")
git annex get $(find . -name "*ses-M0_T2w_seg-manual.nii.gz")
```

`sci-colorado` (n=80)

```commandline
git clone [email protected]:datasets/sci-colorado
cd sci-colorado
git annex dead here
git annex get $(find . -name "*T2w.nii.gz")
git annex get $(find . -name "*T2w_seg-manual.nii.gz")
```

`dcm-zurich` (n=135)

```commandline
git clone [email protected]:datasets/dcm-zurich
cd dcm-zurich
git annex dead here
git annex get $(find . -name "*acq-axial_T2w.nii.gz")
git annex get $(find . -name "*acq-axial_T2w_label-SC_mask-manual.nii.gz")
```

`sci-paris` (n=14)

```commandline
git clone [email protected]:datasets/sci-paris
cd sci-paris
git annex dead here
git annex get $(find . -name "*T2w.nii.gz")
git annex get $(find . -name "*T2w_seg.nii.gz")
```
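Before creating the datalists, it can help to confirm that `git annex get` actually fetched the file contents. The sketch below is not part of this PR. It assumes that a file whose content has not been fetched appears as a symlink pointing to a missing target, which is how git-annex usually represents such files; the function name `split_fetched_unfetched` is hypothetical.

```python
import fnmatch
import os


def split_fetched_unfetched(dataset_root, pattern):
    """Split files matching `pattern` under `dataset_root` into fetched and unfetched.

    Assumption (not from this PR): a file whose content git-annex has not
    fetched is a symlink pointing to a missing target.
    """
    fetched, unfetched = [], []
    for dirpath, _, filenames in os.walk(dataset_root):
        for fname in fnmatch.filter(filenames, pattern):
            path = os.path.join(dirpath, fname)
            # os.path.exists() follows symlinks, so it is False for a dangling link
            if os.path.islink(path) and not os.path.exists(path):
                unfetched.append(path)
            else:
                fetched.append(path)
    return sorted(fetched), sorted(unfetched)


if __name__ == "__main__":
    # Same pattern as the 'git annex get' command used for dcm-zurich
    fetched, unfetched = split_fetched_unfetched("dcm-zurich", "*acq-axial_T2w.nii.gz")
    print(f"fetched: {len(fetched)}, not fetched: {len(unfetched)}")
```

A non-zero "not fetched" count suggests re-running the corresponding `git annex get` command for that dataset.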

### Create MSD-style JSON datalists

```commandline
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data data-multi-subject --dataset-name spine-generic --path-out . --split 0.8 0.2 --seed 42
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data canproco --dataset-name canproco --path-out . --split 0.8 0.2 --seed 42
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data sci-colorado --dataset-name sci-colorado --path-out . --split 0.8 0.2 --seed 42
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data dcm-zurich --dataset-name dcm-zurich --path-out . --split 0.8 0.2 --seed 42
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data sci-paris --dataset-name sci-paris --path-out . --split 0.8 0.2 --seed 42
```
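Each command above writes one datalist named `<dataset-name>_seed<seed>.json` (for example `spine-generic_seed42.json`) with `training` and `validation` lists. How the pre-training script consumes these files is not shown here; as a minimal sketch, assuming a single combined datalist is wanted, the per-dataset files could be merged as follows. The function name `merge_datalists` is hypothetical.

```python
import json
from pathlib import Path


def merge_datalists(json_paths):
    """Concatenate the 'training' and 'validation' lists of several datalists.

    Assumption: each file has the layout written by create_msd_data.py, i.e.
    top-level 'training' and 'validation' keys holding lists of
    {'image': ..., 'label_sc': ...} entries.
    """
    merged = {"training": [], "validation": []}
    for json_path in json_paths:
        with open(json_path) as f:
            datalist = json.load(f)
        for split in ("training", "validation"):
            merged[split].extend(datalist.get(split, []))
    merged["numTraining"] = len(merged["training"])
    merged["numValidation"] = len(merged["validation"])
    return merged


if __name__ == "__main__":
    # Collect every per-dataset datalist created with --seed 42 in the current folder
    merged = merge_datalists(sorted(Path(".").glob("*_seed42.json")))
    print(f"training: {merged['numTraining']}, validation: {merged['numValidation']}")
```

Merging after the per-dataset split keeps the 0.8/0.2 ratio within each dataset, so small datasets such as `sci-paris` remain represented in both sets.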
181 changes: 181 additions & 0 deletions vit_unetr_ssl/create_msd_data.py
@@ -0,0 +1,181 @@
"""
Create MSD-style JSON datalist file for BIDS datasets.
The following two keys are included in the JSON file: 'image' and 'label_sc'.

NOTE: the script is meant to be used for pre-training, meaning that the dataset is split into training and validation.
In other words, NO testing set is created.

The script has to be run for each dataset separately, meaning that one JSON file is created for each dataset.

Example usage:
    python create_msd_data.py
        --path-data /Users/user/data/spine-generic
        --dataset-name spine-generic
        --path-out /Users/user/data/spine-generic

    python create_msd_data.py
        --path-data /Users/user/data/dcm-zurich
        --dataset-name dcm-zurich
        --path-out /Users/user/data/dcm-zurich
"""

import os
import json
import argparse
from pathlib import Path
from loguru import logger
from sklearn.model_selection import train_test_split

# Image file suffixes (contrasts) for different datasets
contrast_dict = {
    'spine-generic': 'space-other_T2w',     # iso T2w (preprocessed data)
    'canproco': 'ses-M0_T2w',               # iso T2w (session M0)
    'dcm-zurich': 'acq-axial_T2w',          # axial T2w
    'sci-paris': 'T2w',                     # iso T2w
    'sci-colorado': 'T2w'                   # axial T2w
}

# Spinal cord segmentation file suffixes for different datasets
sc_fname_suffix_dict = {
    'spine-generic': 'label-SC_seg',
    'canproco': 'seg-manual',
    'dcm-zurich': 'label-SC_mask-manual',
    'sci-paris': 'seg-manual',
    'sci-colorado': 'seg-manual'
}


def get_parser():
    parser = argparse.ArgumentParser(description='Create MSD-style JSON datalist file for BIDS datasets.')

    parser.add_argument('--path-data', required=True, type=str,
                        help='Path to BIDS dataset. Example: /Users/user/data/dcm-zurich')
    parser.add_argument('--dataset-name', required=True, type=str,
                        help='Name of the dataset. Example: spine-generic or dcm-zurich.')
    parser.add_argument('--path-out', type=str, required=True,
                        help='Path to the output directory where dataset json is saved')
    parser.add_argument('--split', nargs='+', type=float, default=[0.8, 0.2],
                        help='Ratios of training and validation 0-1. '
                             'Example: --split 0.8 0.2')
    parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility")

    return parser


def main():
    args = get_parser().parse_args()

    dataset = os.path.abspath(args.path_data)
    dataset_name = args.dataset_name
    # Note: only 'val_ratio' is passed to train_test_split; the training ratio is its complement
    train_ratio, val_ratio = args.split
    seed = args.seed
    path_out = os.path.abspath(args.path_out)

    # Check if the dataset name is valid
    if dataset_name not in contrast_dict:
        raise ValueError(f"Dataset name {dataset_name} is not valid. Choose from {list(contrast_dict)}")

    contrast = contrast_dict[dataset_name]
    sc_fname_suffix = sc_fname_suffix_dict[dataset_name]
    datalist_fname = f"{dataset_name}_seed{seed}"

    train_images, val_images = {}, {}

    # For spine-generic, we add 'derivatives/data_preprocessed' to the path to use the preprocessed data with the same
    # resolution and orientation as the spinal cord segmentations
    if dataset_name == 'spine-generic':
        root = Path(dataset) / 'derivatives/data_preprocessed'
    else:
        root = Path(dataset)
    # Path to 'derivatives/labels' with spinal cord segmentations
    labels = Path(dataset) / 'derivatives/labels'

    # Check if the dataset paths exist
    if not os.path.exists(root):
        raise ValueError(f"Path {root} does not exist.")
    if not os.path.exists(labels):
        raise ValueError(f"Path {labels} does not exist.")

    logger.info(f"Root path: {root}")
    logger.info(f"Labels path: {labels}")

    # List the subject folders ('sub-*') located directly under the root folder.
    # Sorting makes the split reproducible, because os.listdir returns entries in arbitrary order.
    subjects = sorted(sub for sub in os.listdir(root) if sub.startswith("sub-"))

    # Get the training and validation splits
    # Note: we are doing SSL pre-training, so we don't need test set
    tr_subs, val_subs = train_test_split(subjects, test_size=val_ratio, random_state=args.seed)

    # recursively find the spinal cord segmentation files under 'derivatives/labels' for training and validation
    # subjects
    tr_seg_files = [str(path) for sub in tr_subs for path in
                    Path(labels).rglob(f"{sub}_{contrast}_{sc_fname_suffix}.nii.gz")]
    val_seg_files = [str(path) for sub in val_subs for path in
                     Path(labels).rglob(f"{sub}_{contrast}_{sc_fname_suffix}.nii.gz")]

    # update the train and validation images dicts; both the key and the value are the path to the segmentation file
    # (os.path.join returns its second argument unchanged because the segmentation path is absolute)
    train_images.update({sub: os.path.join(root, sub) for sub in tr_seg_files})
    val_images.update({sub: os.path.join(root, sub) for sub in val_seg_files})

    logger.info(f"Found spinal cord segmentations in the training set: {len(train_images)}")
    logger.info(f"Found spinal cord segmentations in the validation set: {len(val_images)}")

    # keys to be defined in the output JSON datalist
    params = {}
    params["dataset_name"] = dataset_name
    params["contrast"] = contrast
    params["labels"] = {
        "0": "background",
        "1": "sc-seg"
    }
    params["modality"] = {
        "0": "MRI"
    }
    params["numTraining"] = len(train_images)
    params["numValidation"] = len(val_images)
    params["seed"] = args.seed
    params["tensorImageSize"] = "3D"

    train_images_dict = {"training": train_images}
    val_images_dict = {"validation": val_images}

    all_images_list = [train_images_dict, val_images_dict]

    for images_dict in all_images_list:

        for name, images_list in images_dict.items():

            temp_list = []
            for label in images_list:

                temp_data_t2w = {}
                # create the image path by replacing the label path
                if dataset_name == 'spine-generic':
                    temp_data_t2w["image"] = (label.replace(f'_{sc_fname_suffix}', '')
                                              .replace('labels', 'data_preprocessed'))
                else:
                    temp_data_t2w["image"] = label.replace(f'_{sc_fname_suffix}', '').replace('/derivatives/labels', '')

                # Spinal cord segmentation file
                temp_data_t2w["label_sc"] = label

                if os.path.exists(temp_data_t2w["label_sc"]) and os.path.exists(temp_data_t2w["image"]):
                    temp_list.append(temp_data_t2w)
                else:
                    logger.info(f"Image or label does not exist, skipping: {label}")

            params[name] = temp_list
            logger.info(f"Number of images in {name} set: {len(temp_list)}")

    final_json = json.dumps(params, indent=4, sort_keys=False)
    os.makedirs(path_out, exist_ok=True)

    # Use a context manager so the file is closed even if writing fails
    path_json = os.path.join(path_out, f"{datalist_fname}.json")
    with open(path_json, "w") as json_file:
        json_file.write(final_json)
    logger.info(f"JSON file saved to {path_json}")


if __name__ == "__main__":
    main()
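
The image path in each datalist entry is derived from the segmentation path by plain string replacement. The sketch below isolates that step so it can be checked without any data on disk; it mirrors the two `.replace(...)` chains in `main()`, but the function name `image_path_from_label` and the example paths (including the `sub-01/anat` layout) are hypothetical.

```python
def image_path_from_label(label_path, sc_fname_suffix, dataset_name):
    """Derive the image path from a spinal cord segmentation path.

    Mirrors the string replacements in create_msd_data.py; nothing is read from disk.
    """
    # Drop the segmentation suffix, e.g. '_label-SC_mask-manual'
    image_path = label_path.replace(f"_{sc_fname_suffix}", "")
    if dataset_name == "spine-generic":
        # Preprocessed images live under 'derivatives/data_preprocessed'
        return image_path.replace("labels", "data_preprocessed")
    # For the other datasets, images live directly under the dataset root
    return image_path.replace("/derivatives/labels", "")


if __name__ == "__main__":
    # Hypothetical segmentation path following a BIDS-like layout
    label = "/data/dcm-zurich/derivatives/labels/sub-01/anat/sub-01_acq-axial_T2w_label-SC_mask-manual.nii.gz"
    print(image_path_from_label(label, "label-SC_mask-manual", "dcm-zurich"))
    # prints /data/dcm-zurich/sub-01/anat/sub-01_acq-axial_T2w.nii.gz
```

Because `str.replace` substitutes every occurrence, a dataset path that contains `labels` elsewhere (for example in a parent folder name) would also be rewritten in the `spine-generic` branch. The existence check in `main()` would then skip such entries.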
