diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..de46d15
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,14 @@
+totalspineseg/
+__pycache__
+output.csv
+dataset_split.csv
+model*
+*.pth
+20*
+models/
+**/wandb/
+**/*run*/
+myenv/
+saved_batches/
+*.csv
+mil_model_*
\ No newline at end of file
diff --git a/MIL_training/augment.py b/MIL_training/augment.py
new file mode 100644
index 0000000..c549bae
--- /dev/null
+++ b/MIL_training/augment.py
@@ -0,0 +1,224 @@
+import sys, argparse, textwrap
+import multiprocessing as mp
+from functools import partial
+from tqdm.contrib.concurrent import process_map
+from pathlib import Path
+import nibabel as nib
+import numpy as np
+import torchio as tio
+import gryds
+import scipy.ndimage as ndi
+from scipy.stats import norm
+import warnings
+
+
+warnings.filterwarnings("ignore")
+
+# module-level random state shared by the stochastic augmentations
+rs = np.random.RandomState()
+
+
+def aug_histogram_equalization(image1, seg, image2):
+    img_min1, img_max1 = image1.min(), image1.max()
+    img_min2, img_max2 = image2.min(), image2.max()
+
+    # equalize each image, then map it back to its original intensity range
+    image1_flattened = image1.flatten()
+    hist1, bins1 = np.histogram(image1_flattened, bins=256, range=[image1_flattened.min(), image1_flattened.max()])
+    cdf1 = hist1.cumsum()
+    cdf_normalized1 = cdf1 * (hist1.max() / cdf1.max())
+    image1 = np.interp(image1_flattened, bins1[:-1], cdf_normalized1).reshape(image1.shape)
+    image1 = np.interp(image1, (image1.min(), image1.max()), (img_min1, img_max1))
+
+    image2_flattened = image2.flatten()
+    hist2, bins2 = np.histogram(image2_flattened, bins=256, range=[image2_flattened.min(), image2_flattened.max()])
+    cdf2 = hist2.cumsum()
+    cdf_normalized2 = cdf2 * (hist2.max() / cdf2.max())
+    image2 = np.interp(image2_flattened, bins2[:-1], cdf_normalized2).reshape(image2.shape)
+    image2 = np.interp(image2, (image2.min(), image2.max()), (img_min2, img_max2))
+
+    return image1, seg, image2
+
+def aug_transform(image1, transform):
+    # standardize, apply the intensity transform, then restore the original range
+    img_min1, img_max1 = image1.min(), image1.max()
+
+    image1 = (image1 - image1.mean()) / image1.std()
+    image1 = np.interp(image1, (image1.min(), image1.max()), (0, 1))
+
+    image1 = transform(image1)
+
+    image1 = np.interp(image1, (image1.min(), image1.max()), (img_min1, img_max1))
+
+    return image1
+
+def aug_log(image1):
+    return aug_transform(image1, lambda x: np.log(1 + x))
+
+def aug_sqrt(image1):
+    return aug_transform(image1, np.sqrt)
+
+def aug_sin(image1):
+    return aug_transform(image1, np.sin)
+
+def aug_exp(image1):
+    return aug_transform(image1, np.exp)
+
+def aug_sig(image1):
+    return aug_transform(image1, lambda x: 1 / (1 + np.exp(-x)))
+
+def aug_laplace(image1):
+    return aug_transform(image1, lambda x: np.abs(ndi.laplace(x)))
+
+def aug_inverse(image1):
+    image1 = image1.min() + image1.max() - image1
+    return image1
+
+def aug_bspline(image1, seg, image2):
+    grid = rs.rand(3, 3, 3, 3)
+    bspline = gryds.BSplineTransformation((grid - .5) / 5)
+    grid[:, 0] += ((grid[:, 0] > 0) * 2 - 1) * .9
+    image1 = gryds.Interpolator(image1).transform(bspline).astype(np.float64)
+    image2 = gryds.Interpolator(image2).transform(bspline).astype(np.float64)
+    seg = gryds.Interpolator(seg, order=0).transform(bspline).astype(np.uint8)
+    return image1, seg, image2
+
+def aug_flip(image1, seg, image2):
+    subject = tio.RandomFlip(axes=('LR',))(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+def aug_aff(image1, seg, image2):
+    subject = tio.RandomAffine()(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+def aug_elastic(image1, seg, image2):
+    subject = tio.RandomElasticDeformation(max_displacement=40)(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+def aug_anisotropy(image1, seg, image2, downsampling=7):
+    subject = tio.RandomAnisotropy(downsampling=downsampling)(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+def aug_motion(image1, seg, image2):
+    subject = tio.RandomMotion()(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+def aug_ghosting(image1, seg, image2):
+    subject = tio.RandomGhosting()(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+def aug_spike(image1, seg, image2):
+    subject = tio.RandomSpike(intensity=(1, 2))(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+def aug_bias_field(image1, seg, image2):
+    subject = tio.RandomBiasField()(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+def aug_blur(image1, seg, image2):
+    subject = tio.RandomBlur()(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+def aug_noise(image1, seg, image2):
+    original_mean1, original_std1 = np.mean(image1), np.std(image1)
+    original_mean2, original_std2 = np.mean(image2), np.std(image2)
+
+    image1 = (image1 - original_mean1) / original_std1
+    image2 = (image2 - original_mean2) / original_std2
+
+    subject = tio.RandomNoise()(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    # restore the original intensity ranges on the noisy outputs
+    image1 = subject.image.data.squeeze().numpy().astype(np.float64) * original_std1 + original_mean1
+    image2 = subject.image2.data.squeeze().numpy().astype(np.float64) * original_std2 + original_mean2
+
+    return image1, subject.seg.data.squeeze().numpy().astype(np.uint8), image2
+
+def aug_swap(image1, seg, image2):
+    subject = tio.RandomSwap()(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+def aug_labels2image(image1, seg, image2, leave_background=0.5, classes=None):
+    _seg = seg
+    if classes:
+        _seg = combine_classes(seg, classes)
+    subject = tio.RandomLabelsToImage(label_key="seg", image_key="image")(tio.Subject(
+        seg=tio.LabelMap(tensor=np.expand_dims(_seg, axis=0))
+    ))
+    new_img = subject.image.data.squeeze().numpy().astype(np.float64)
+
+    if rs.rand() < leave_background:
+        img_min1, img_max1 = np.min(image1), np.max(image1)
+        _image1 = (image1 - img_min1) / (img_max1 - img_min1)
+
+        new_img_min, new_img_max = np.min(new_img), np.max(new_img)
+        new_img = (new_img - new_img_min) / (new_img_max - new_img_min)
+        new_img[_seg == 0] = _image1[_seg == 0]
+        new_img = np.interp(new_img, (new_img.min(), new_img.max()), (img_min1, img_max1))
+
+    return new_img, seg, image2
+
+def aug_gamma(image1, seg, image2):
+    subject = tio.RandomGamma()(tio.Subject(
+        image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
+        seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
+        image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
+    ))
+    return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)
+
+
+def parse_class(c):
+    # expand a spec like "1-3,5" into a tuple of class ids: (1, 2, 3, 5)
+    c = [_.split('-') for _ in c.split(',')]
+    c = tuple(__ for _ in c for __ in list(range(int(_[0]), int(_[-1]) + 1)))
+    return c
+
+def combine_classes(seg, classes):
+    # merge the listed class groups into consecutive labels 1..len(classes)
+    _seg = np.zeros_like(seg)
+    for i, c in enumerate(classes):
+        _seg[np.isin(seg, c)] = i + 1
+    return _seg
\ No newline at end of file
diff --git a/MIL_training/mil_definition.py b/MIL_training/mil_definition.py
new file mode 100644
index 0000000..dab321b
--- /dev/null
+++ b/MIL_training/mil_definition.py
@@ -0,0 +1,101 @@
+'''
+File defining the MIL model.
+Note that many hyperparameters could be exposed as arguments:
+the average-pooling size, the hidden dim, etc.
+The encoder could also be swapped for a different model.
+'''
+
+import torch
+import torch.nn as nn
+import timm
+
+
+class MILsection(nn.Module):
+    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=1):
+        super(MILsection, self).__init__()
+        self.num_layers = num_layers
+        # RNN cell, here a bidirectional GRU (hidden_dim and num_classes are
+        # currently unused and kept for future experiments)
+        if num_layers > 0:
+            self.rnn = nn.GRU(input_dim, input_dim//2, num_layers=num_layers,
+                              batch_first=True, dropout=0.1, bidirectional=True)
+        # attention layer, here a simple linear layer
+        self.attention = nn.Sequential(
+            nn.Tanh(),
+            nn.Linear(input_dim, 1)
+        )
+
+    def forward(self, bags):
+        """
+        Args:
+            bags: (batch_size, num_instances, input_dim)
+
+        Returns:
+            weighted_instances: (batch_size, input_dim), the attention-pooled
+            bag embedding
+        """
+
+        # the bag goes through the RNN
+        if self.num_layers > 0:
+            bags_rnn, _ = self.rnn(bags)
+        else:
+            bags_rnn = bags
+
+        # attention pooling
+        attn_scores = self.attention(bags_rnn).squeeze(-1)  # [batch_size, num_instances]
+        attn_weights = torch.softmax(attn_scores, dim=-1)  # [batch_size, num_instances]
+        weighted_instances = torch.bmm(attn_weights.unsqueeze(1), bags_rnn).squeeze(1)  # [batch_size, input_dim]
+
+        return weighted_instances
+
+
+class MILmodel(nn.Module):
+    # note: a fresh encoder instance must be passed to each MILmodel;
+    # otherwise the same encoder weights would be shared across models
+    def __init__(self, encoder, num_layers=1):
+        super(MILmodel, self).__init__()
+        # encoder
+        self.encoder = encoder
+        # flattening layer, applying pooling and flattening
+        # note that different pooling methods could be tried here
+        self.flatten = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Flatten(1)
+        )
+        # size of the features after encoding:
+        # the encoder outputs n feature maps
+        # that are flattened into a vector of size n
+        self.feature_size = self.encoder.num_features
+
+        # MIL section, with plenty of hyperparameters of its own
+        self.mil_section = MILsection(input_dim=self.feature_size,
+                                      hidden_dim=1024,
+                                      num_classes=3,
+                                      num_layers=num_layers)
+        # classifier output:
+        # a final linear layer outputs the prediction
+        self.classifier = nn.Linear(self.feature_size, 3)
+
+    def forward(self, x):
+        # x shape: (batch_size, 6, 1, 384, 384)
+        batch_size, num_instances, channels, H, W = x.shape
+
+        # reshape to process all instances through the encoder
+        x = x.reshape(-1, channels, H, W)  # shape: (batch_size * 6, 1, 384, 384)
+
+        # pass through the encoder
+        x = self.encoder.forward_features(x)  # shape: (batch_size * 6, feature_size, h', w')
+
+        # apply pooling and flatten
+        x = self.flatten(x)  # shape: (batch_size * 6, feature_size)
+
+        # reshape back to separate instances
+        x = x.reshape(batch_size, num_instances, self.feature_size)  # shape: (batch_size, 6, feature_size)
+
+        # pass through the MIL section
+        weighted_instances = self.mil_section(x)
+        # weighted_instances: (batch_size, feature_size)
+
+        # classification output
+        output = self.classifier(weighted_instances)  # shape: (batch_size, 3)
+
+        return output
diff --git a/MIL_training/prepare_nfn.py b/MIL_training/prepare_nfn.py
new file mode 100644
index 0000000..dcb8276
--- /dev/null
+++ b/MIL_training/prepare_nfn.py
@@ -0,0 +1,179 @@
+''' define the transformations to prepare the data for the NFN model
+    from the raw patches we obtain from the preprocessing pipeline
+'''
+
+import os
+import pandas as pd
+import torchio as tio
+from monai.transforms import (
+    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd, ConcatItemsd,
+    ToTensord, SpatialPadd, CenterSpatialCropd, NormalizeIntensityd,
+    RandRotated, RandSpatialCropd, RandBiasFieldd, Lambdad, Transform,
+    RandGaussianNoised, RandAffined, RandZoomd, Rand3DElasticd, Flipd,
+    SpatialCropd, Spacingd, RandLambdad, ResizeWithPadOrCropd,
+    RandGaussianSharpend, RandScaleIntensityd, RandFlipd
+)
+
+from torch.utils.data import DataLoader, ConcatDataset
+from monai.data import Dataset
+import matplotlib.pyplot as plt
+import nibabel as nib
+from augment import *
+
+
+# custom transform to extract slices from the 3D image
+# and put them in the MIL bag format
+class ExtractSlicesD_nfn(Transform):
+    def __init__(self, keys=['image'], target_size=(384, 384), verbose=False):
+        self.keys = keys
+        self.target_size = target_size
+        self.resize = tio.Resize(target_shape=(*target_size, 1))
+        self.verbose = verbose
+
+    def __call__(self, data):
+        d = dict(data)
+
+        for key in self.keys:
+            # get the image and remove the channel dimension: (1, 6, X, Y) -> (6, X, Y)
+            image = d[key].squeeze(0)
+            for i in range(image.shape[0]):
+                # extract the slice, add a channel dim for torchio,
+                # resize, then normalize
+                slice_2d = image[i, :, :]
+                slice_3d = slice_2d.unsqueeze(0).unsqueeze(-1)
+                if self.verbose:
+                    print(f"Shape before resize: {slice_3d.shape}")
+                slice_resized = self.resize(slice_3d)
+                if self.verbose:
+                    print(f"Shape after resize: {slice_resized.shape}")
+                # remove the z dimension that we added
+                slice_final = slice_resized.squeeze(-1)
+                d[f'slice_{i}'] = slice_final
+                if self.verbose:
+                    print(f"Final slice {i} shape: {slice_final.shape}")
+        return d
+
+
+# whole transform pipeline for dataloading
+def get_transforms_nfn(mode='basic'):
+
+    regular_transforms = Compose([
+        LoadImaged(keys=['image']),
+        EnsureChannelFirstd(keys=["image"]),
+        Spacingd(keys=['image'], pixdim=(4.0, 0.4, 0.4), mode=('bilinear')),
+    ])
+
+    if mode == 'basic':  # used for validation
+        common_transforms = Compose([
+            CenterSpatialCropd(keys=['image'], roi_size=(6, 100, 100)),
+            ResizeWithPadOrCropd(keys=['image'], spatial_size=(6, 100, 100)),
+            NormalizeIntensityd(keys=['image'], nonzero=True),
+        ])
+
+    elif mode == 'random':  # used for training
+        # same transforms but with random augmentations
+        common_transforms = Compose([
+            RandRotated(keys=['image'], prob=0.5, range_x=0.1),
+            RandFlipd(keys=['image'], prob=0.5, spatial_axis=0),
+            RandSpatialCropd(keys=['image'], roi_size=(6, 100, 100), random_size=False),
+            RandLambdad(keys=['image'], func=aug_sqrt, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_sin, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_exp, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_sig, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_laplace, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_inverse, prob=0.05),
+            RandBiasFieldd(keys=['image'], prob=0.05),
+            RandAffined(keys=['image'], prob=0.05, padding_mode="zeros", mode=["bilinear"]),
+
+            RandGaussianNoised(keys=['image'], mean=0.0, std=0.1, prob=0.05),
+            RandGaussianSharpend(keys=['image'], prob=0.05),
+
+            #Rand3DElasticd(keys=['image'], prob=0.05, padding_mode="zeros", mode=["bilinear"], sigma_range=(5, 7), magnitude_range=(50, 150)),
+
+            ResizeWithPadOrCropd(keys=['image'], spatial_size=(6, 100, 100)),
+            RandScaleIntensityd(keys=['image'], factors=(0.8, 1.2), prob=1),
+        ])
+
+    else:
+        raise ValueError(f"Unknown mode: {mode}")
+
+    # list of transforms for processing the 2D slices
+    slice_transforms = Compose([
+        # custom transform to extract and resize the slices
+        ExtractSlicesD_nfn(keys=['image'], target_size=(100, 100)),
+        # ensure all slices are tensors
+        ToTensord(
+            keys=[f'slice_{i}' for i in range(6)]
+        ),
+        # concatenate all slices into a bag
+        ConcatItemsd(
+            keys=[f'slice_{i}' for i in range(6)],
+            name='bag',
+            dim=0
+        ),
+        # ensure the bag has the correct shape
+        Lambdad(
+            keys=['bag'],
+            func=lambda x: x.reshape(6, 1, 100, 100)
+        )
+    ])
+
+    # combine the common transforms with the slice transforms
+    transforms = Compose([regular_transforms, common_transforms, slice_transforms])
+
+    return transforms
+
+# prepare the data for the NFN model:
+# random=True is for training, random=False (basic) for validation;
+# returns a Dataset mixing the left and right patches
+def prepare_data_nfn(data_dir, csv_file, random=True):
+    data = []
+
+    labels_df = pd.read_csv(csv_file)
+
+    counter = 0
+    # label conversion dictionary
+    text2int = {"Normal/Mild": 0, "Moderate": 1, "Severe": 2}
+
+    for subject in os.listdir(data_dir):
+        subject_dir = os.path.join(data_dir, subject, 'anat')
+        if os.path.isdir(subject_dir):
+            for file in os.listdir(subject_dir):
+                if '_patch.nii.gz' in file and 'foramen' in file and 'T1' in file:
+                    image_path = os.path.join(subject_dir, file)
+                    parts = image_path.split('_')
+                    disk_level = f"{parts[-5]}_{parts[-4]}"
+
+                    if os.path.exists(image_path):
+
+                        subject_id = (subject.replace('sub-', ''))
+                        if 'left' in file:
+                            orientation = 'left'
+                        elif 'right' in file:
+                            orientation = 'right'
+                        else:
+                            # skip files whose side cannot be identified
+                            continue
+                        label_column = (
+                            f'{orientation}_neural_foraminal_narrowing_{disk_level.lower()}'
+                        )
+
+                        # get the raw label
+                        label = labels_df.loc[
+                            labels_df['study_id'] == int(subject_id),
+                            label_column
+                        ].values[0]
+
+                        # convert the text label to a numeric value
+                        label_numeric = text2int.get(label, -1)
+                        if label_numeric != -1:
+                            counter += 1
+
+                            data.append({
+                                "image": image_path,
+                                "label": label_numeric
+                            })
+
+    print(f"Number of loaded data: {counter}")
+    return Dataset(data=data, transform=get_transforms_nfn(mode='random') if random else get_transforms_nfn(mode='basic'))
+
diff --git a/MIL_training/prepare_sas.py b/MIL_training/prepare_sas.py
new file mode 100644
index 0000000..fa08740
--- /dev/null
+++ b/MIL_training/prepare_sas.py
@@ -0,0 +1,175 @@
+''' define the transformations to prepare the data for the SAS model
+    from the raw patches we obtain from the preprocessing pipeline
+'''
+
+import os
+import pandas as pd
+import torchio as tio
+from monai.transforms import (
+    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd, ConcatItemsd,
+    ToTensord, SpatialPadd, CenterSpatialCropd, NormalizeIntensityd,
+    RandRotated, RandSpatialCropd, RandBiasFieldd, Lambdad, Transform,
+    RandGaussianNoised, RandAffined, RandZoomd, Rand3DElasticd, Flipd,
+    SpatialCropd, Spacingd, RandLambdad, ResizeWithPadOrCropd,
+    RandGaussianSharpend, RandScaleIntensityd, RandFlipd
+)
+
+from torch.utils.data import DataLoader, ConcatDataset
+from monai.data import Dataset
+import matplotlib.pyplot as plt
+import nibabel as nib
+from augment import *
+
+
+# custom transform to extract slices from the 3D image
+# and put them in the MIL bag format
+class ExtractSlicesD_sas(Transform):
+    def __init__(self, keys=['image'], target_size=(384, 384), verbose=False):
+        self.keys = keys
+        self.target_size = target_size
+        self.resize = tio.Resize(target_shape=(*target_size, 1))
+        self.verbose = verbose
+
+    def __call__(self, data):
+        d = dict(data)
+
+        for key in self.keys:
+            # get the image and remove the channel dimension: (1, 6, X, Y) -> (6, X, Y)
+            image = d[key].squeeze(0)
+            for i in range(image.shape[0]):
+                # extract the slice, add a channel dim for torchio,
+                # resize, then normalize
+                slice_2d = image[i, :, :]
+                slice_3d = slice_2d.unsqueeze(0).unsqueeze(-1)
+                if self.verbose:
+                    print(f"Shape before resize: {slice_3d.shape}")
+                slice_resized = self.resize(slice_3d)
+                if self.verbose:
+                    print(f"Shape after resize: {slice_resized.shape}")
+                # remove the z dimension that we added
+                slice_final = slice_resized.squeeze(-1)
+                d[f'slice_{i}'] = slice_final
+                if self.verbose:
+                    print(f"Final slice {i} shape: {slice_final.shape}")
+        return d
+
+# whole transform pipeline for dataloading
+def get_transforms_sas(mode='basic'):
+
+    regular_transforms = Compose([
+        LoadImaged(keys=['image']),
+        EnsureChannelFirstd(keys=["image"]),
+        Spacingd(keys=['image'], pixdim=(4.0, 0.4, 0.4), mode=('bilinear')),
+    ])
+
+    if mode == 'basic':  # used for validation
+        common_transforms = Compose([
+            CenterSpatialCropd(keys=['image'], roi_size=(6, 100, 100)),
+            ResizeWithPadOrCropd(keys=['image'], spatial_size=(6, 100, 100)),
+            NormalizeIntensityd(keys=['image'], nonzero=True),
+        ])
+
+    elif mode == 'random':  # used for training
+        # same transforms but with random augmentations
+        common_transforms = Compose([
+            RandRotated(keys=['image'], prob=0.5, range_x=0.1),
+            RandFlipd(keys=['image'], prob=0.5, spatial_axis=0),
+            RandSpatialCropd(keys=['image'], roi_size=(6, 100, 100), random_size=False),
+            RandLambdad(keys=['image'], func=aug_sqrt, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_sin, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_exp, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_sig, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_laplace, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_inverse, prob=0.05),
+            RandBiasFieldd(keys=['image'], prob=0.05),
+            RandAffined(keys=['image'], prob=0.05, padding_mode="zeros", mode=["bilinear"]),
+
+            RandGaussianNoised(keys=['image'], mean=0.0, std=0.1, prob=0.05),
+            RandGaussianSharpend(keys=['image'], prob=0.05),
+
+            #Rand3DElasticd(keys=['image'], prob=0.05, padding_mode="zeros", mode=["bilinear"], sigma_range=(5, 7), magnitude_range=(50, 150)),
+
+            ResizeWithPadOrCropd(keys=['image'], spatial_size=(6, 100, 100)),
+            RandScaleIntensityd(keys=['image'], factors=(0.8, 1.2), prob=1),
+        ])
+
+    else:
+        raise ValueError(f"Unknown mode: {mode}")
+
+    # list of transforms for processing the 2D slices
+    slice_transforms = Compose([
+        # custom transform to extract and resize the slices
+        ExtractSlicesD_sas(keys=['image'], target_size=(100, 100)),
+        # ensure all slices are tensors
+        ToTensord(
+            keys=[f'slice_{i}' for i in range(6)]
+        ),
+        # concatenate all slices into a bag
+        ConcatItemsd(
+            keys=[f'slice_{i}' for i in range(6)],
+            name='bag',
+            dim=0
+        ),
+        # ensure the bag has the correct shape
+        Lambdad(
+            keys=['bag'],
+            func=lambda x: x.reshape(6, 1, 100, 100)
+        )
+    ])
+
+    # combine the common transforms with the slice transforms
+    transforms = Compose([regular_transforms, common_transforms, slice_transforms])
+
+    return transforms
+
+
+# prepare the data for the SAS model:
+# random=True is for training, random=False (basic) for validation;
+# returns a Dataset mixing the left and right patches
+def prepare_data_sas(data_dir, csv_file, random=True):
+    data = []
+    labels_df = pd.read_csv(csv_file)
+
+    counter = 0
+    # label conversion dictionary
+    text2int = {"Normal/Mild": 0, "Moderate": 1, "Severe": 2}
+
+    for subject in os.listdir(data_dir):
+
+        subject_dir = os.path.join(data_dir, subject, 'anat')
+        if os.path.isdir(subject_dir):
+            for file in os.listdir(subject_dir):
+                if '_patch.nii.gz' in file and 'foramen' in file and 'T2' in file:
+                    image_path = os.path.join(subject_dir, file)
+                    parts = image_path.split('_')
+                    disk_level = f"{parts[-5]}_{parts[-4]}"
+
+                    if os.path.exists(image_path):
+
+                        subject_id = (subject.replace('sub-', ''))
+                        if 'left' in file:
+                            orientation = 'left'
+                        elif 'right' in file:
+                            orientation = 'right'
+                        else:
+                            # skip files whose side cannot be identified
+                            continue
+                        label_column = (
+                            f'{orientation}_subarticular_stenosis_{disk_level.lower()}'
+                        )
+                        # get the raw label
+                        label = labels_df.loc[
+                            labels_df['study_id'] == int(subject_id),
+                            label_column
+                        ].values[0]
+
+                        # convert the text label to a numeric value
+                        label_numeric = text2int.get(label, -1)
+                        if label_numeric != -1:
+                            counter += 1
+
+                            data.append({
+                                "image": image_path,
+                                "label": label_numeric
+                            })
+
+    print(f"Number of loaded data: {counter}")
+    return Dataset(data=data, transform=get_transforms_sas(mode='random') if random else get_transforms_sas(mode='basic'))
diff --git a/MIL_training/prepare_scs.py b/MIL_training/prepare_scs.py
new file mode 100644
index 0000000..6a22057
--- /dev/null
+++ b/MIL_training/prepare_scs.py
@@ -0,0 +1,171 @@
+''' define the transformations to prepare the data for the SCS model
+    from the raw patches we obtain from the preprocessing pipeline
+'''
+
+import os
+import pandas as pd
+import torchio as tio
+from monai.transforms import (
+    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd, ConcatItemsd,
+    ToTensord, SpatialPadd, CenterSpatialCropd, NormalizeIntensityd,
+    RandRotated, RandSpatialCropd, RandBiasFieldd, Lambdad, Transform,
+    RandGaussianNoised, RandAffined, RandZoomd, Rand3DElasticd, Flipd,
+    SpatialCropd, Spacingd, RandLambdad, ResizeWithPadOrCropd,
+    RandGaussianSharpend, RandScaleIntensityd
+)
+
+
+from torch.utils.data import DataLoader, ConcatDataset
+from monai.data import Dataset
+import matplotlib.pyplot as plt
+import nibabel as nib
+from augment import *
+
+
+# custom transform to extract slices from the 3D image
+# and put them in the MIL bag format
+class ExtractSlicesD_scs(Transform):
+    def __init__(self, keys=['image'], target_size=(384, 384), verbose=False):
+        self.keys = keys
+        self.target_size = target_size
+        self.resize = tio.Resize(target_shape=(*target_size, 1))
+        self.verbose = verbose
+
+    def __call__(self, data):
+        d = dict(data)
+
+        for key in self.keys:
+            # get the image and remove the channel dimension: (1, X, Y, 6) -> (X, Y, 6)
+            image = d[key].squeeze(0)
+            for i in range(image.shape[2]):
+                # extract the slice, add a channel dim for torchio,
+                # resize, then normalize
+                slice_2d = image[:, :, i]
+                slice_3d = slice_2d.unsqueeze(0).unsqueeze(-1)
+                if self.verbose:
+                    print(f"Shape before resize: {slice_3d.shape}")
+                slice_resized = self.resize(slice_3d)
+                if self.verbose:
+                    print(f"Shape after resize: {slice_resized.shape}")
+                # remove the z dimension that we added
+                slice_final = slice_resized.squeeze(-1)
+                d[f'slice_{i}'] = slice_final
+                if self.verbose:
+                    print(f"Final slice {i} shape: {slice_final.shape}")
+        return d
+
+# transformation pipeline for the data
+def get_transforms_scs(mode='basic'):
+
+    regular_transforms = Compose([
+        LoadImaged(keys=['image']),
+        EnsureChannelFirstd(keys=["image"]),
+        Spacingd(keys=['image'], pixdim=(0.4, 0.4, 4.4), mode=('bilinear')),  # resample the image
+    ])
+
+
+    if mode == 'basic':  # used for validation
+        common_transforms = Compose([
+            CenterSpatialCropd(keys=['image'], roi_size=(120, 80, 6)),
+            SpatialPadd(keys=['image'], spatial_size=(120, 80, 6)),  # pad to a fixed size
+            ScaleIntensityd(keys=['image']),
+            NormalizeIntensityd(keys=['image'], nonzero=True),
+        ])
+
+    elif mode == 'random':  # used for training
+        # same transforms but with random augmentations
+        common_transforms = Compose([
+            RandRotated(keys=['image'], prob=0.5, range_y=0.1),
+            RandSpatialCropd(keys=['image'], roi_size=(120, 80, 6), random_size=False),
+            RandLambdad(keys=['image'], func=aug_sqrt, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_sin, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_exp, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_sig, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_laplace, prob=0.05),
+            RandLambdad(keys=['image'], func=aug_inverse, prob=0.05),
+            RandBiasFieldd(keys=['image'], prob=0.05),
+            RandAffined(keys=['image'], prob=0.05, padding_mode="zeros", mode=["bilinear"]),
+
+            RandGaussianNoised(keys=['image'], mean=0.0, std=0.1, prob=0.05),
+            RandGaussianSharpend(keys=['image'], prob=0.05),
+
+            #Rand3DElasticd(keys=['image'], prob=0.05, padding_mode="zeros", mode=["bilinear"], sigma_range=(5, 7), magnitude_range=(50, 150)),
+
+            ResizeWithPadOrCropd(keys=['image'], spatial_size=(120, 80, 6)),
+            RandScaleIntensityd(keys=['image'], factors=(0.8, 1.2), prob=1),
+        ])
+
+    else:
+        raise ValueError(f"Unknown mode: {mode}")
+
+    # list of transforms for processing the 2D slices
+    slice_transforms = Compose([
+        # custom transform to extract and resize the slices
+        ExtractSlicesD_scs(keys=['image'], target_size=(384, 384)),
+
+        # ensure all slices are tensors
+        ToTensord(
+            keys=[f'slice_{i}' for i in range(6)]
+        ),
+        # concatenate all slices into a bag
+        ConcatItemsd(
+            keys=[f'slice_{i}' for i in range(6)],
+            name='bag',
+            dim=0
+        ),
+        # ensure the bag has the correct shape
+        Lambdad(
+            keys=['bag'],
+            func=lambda x: x.reshape(6, 1, 384, 384)
+        )
+    ])
+
+    # combine the common transforms with the slice transforms
+    transforms = Compose([regular_transforms, common_transforms, slice_transforms])
+
+    return transforms
+
+# prepare the data for the SCS model:
+# random=True is for training, random=False (basic) for validation;
+# returns a Dataset of the spinal canal patches
+def prepare_data_scs(data_dir, csv_file, random=True):
+    data = []
+    labels_df = pd.read_csv(csv_file)
+
+    counter = 0
+    # label conversion dictionary
+    text2int = {"Normal/Mild": 0, "Moderate": 1, "Severe": 2}
+
+    for subject in os.listdir(data_dir):
+        subject_dir = os.path.join(data_dir, subject, 'anat')
+        if os.path.isdir(subject_dir):
+            for file in os.listdir(subject_dir):
+                if '_patch.nii.gz' in file and 'foramen' not in file:
+                    image_path = os.path.join(subject_dir, file)
+                    parts = image_path.split('_')
+                    disk_level = f"{parts[-3]}_{parts[-2]}"
+
+                    if os.path.exists(image_path):
+
+                        subject_id = (subject.replace('sub-', ''))
+
+                        label_column = (
+                            f'spinal_canal_stenosis_{disk_level.lower()}'
+                        )
+                        # get the raw label
+                        label = labels_df.loc[
+                            labels_df['study_id'] == int(subject_id),
+                            label_column
+                        ].values[0]
+
+                        # convert the text label to a numeric value
+                        label_numeric = text2int.get(label, -1)
+                        if label_numeric != -1:
+                            counter += 1
+                            data.append({
+                                "image": image_path,
+                                "label": label_numeric
+                            })
+
+    print(f"Number of loaded data: {counter}")
+    return Dataset(data=data, transform=get_transforms_scs(mode='random') if random else get_transforms_scs(mode='basic'))
\ No newline at end of file
diff --git a/MIL_training/train_mil_nfn.py b/MIL_training/train_mil_nfn.py
new file mode 100644
index 0000000..d7a69a8
--- /dev/null
+++ b/MIL_training/train_mil_nfn.py
@@ -0,0 +1,205 @@
+''' train the NFN MIL model '''
+
+import os
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, ConcatDataset
+from training_utils import CosineAnnealingStabilizeLR, weight_challenge, train_epoch, validate, visualize_batch
+import wandb
+from tqdm import tqdm
+from prepare_nfn import prepare_data_nfn
+from mil_definition import MILmodel
+import numpy as np
+import matplotlib.pyplot as plt
+import random
+import json
+import math
+import timm
+
+
+# main function to train the NFN MIL model
+def train_model_nfn(
+    encoder,
+    data_dir,
+    csv_file,
+    num_epochs=20,
+    batch_size=8,
+    learning_rate=1e-4,
+    encoder_lr=1e-5,             # learning rate specific to the ConvNeXt encoder
+    freeze_encoder_epoch=5,      # epoch from which the ConvNeXt is frozen
+    encoder_cosine_epochs=3,     # epochs for the encoder cosine schedule to reach its minimum
+    other_cosine_epochs=6,       # epochs for the rest of the model to reach its minimum
+    eta_min_factor_encoder=0.04, # factor giving the encoder's eta_min (relative to encoder_lr)
+    eta_min_factor_other=0.04,   # factor giving the rest's eta_min (relative to learning_rate)
+    num_layers=1,
+    device='cuda'
+):
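+    # (Hedged sketch, for orientation only.) After a run, the checkpoint folder
+    # created below is expected to contain:
+    #   mil_model_nfn/
+    #     best_mil_model.pth   # model + optimizer/scheduler state at the best val loss
+    #     config.json          # copy of the wandb config
+    #     best_loss.json       # best validation loss reached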
+    # Initialize wandb
+    wandb.init(
+        project="lumbar-mil-nfn",
+        config={
+            "epochs": num_epochs,
+            "batch_size": batch_size,
+            "learning_rate": learning_rate,
+            "encoder_lr": encoder_lr,
+            "freeze_encoder_epoch": freeze_encoder_epoch,
+            "encoder_cosine_epochs": encoder_cosine_epochs,
+            "other_cosine_epochs": other_cosine_epochs,
+            "eta_min_factor_encoder": eta_min_factor_encoder,
+            "eta_min_factor_other": eta_min_factor_other,
+            "scheduler": "CosineAnnealing",
+            "architecture": "ConvNeXt-Small-MIL",
+            "num_layers": num_layers
+        }
+    )
+
+    # create the checkpoint folder in the current directory
+    folder_name = f"mil_model_nfn"
+    os.makedirs(folder_name, exist_ok=True)
+
+    # Prepare data
+    train_dir = os.path.join(data_dir, 'training')
+    val_dir = os.path.join(data_dir, 'validation')
+
+    # Create datasets
+    train_data = prepare_data_nfn(train_dir, csv_file, random=True)
+    val_data = prepare_data_nfn(val_dir, csv_file, random=False)
+
+    # Create dataloaders
+    train_loader = DataLoader(train_data, batch_size=batch_size,
+                              shuffle=True, num_workers=0)
+    val_loader = DataLoader(val_data, batch_size=batch_size,
+                            shuffle=False, num_workers=0)
+
+    # Initialize the model with the encoder passed as argument
+    model = MILmodel(encoder=encoder, num_layers=num_layers).to(device)
+
+    # Loss function - CrossEntropyLoss with the challenge class weights
+    #criterion_encoder = nn.CrossEntropyLoss()
+    #criterion_no_encoder = nn.CrossEntropyLoss(weight=weight_challenge)
+    criterion = nn.CrossEntropyLoss(weight=weight_challenge)
+
+    # separate the ConvNeXt parameters from the rest of the model
+    encoder_params = model.encoder.parameters()
+    other_params = [p for n, p in model.named_parameters() if not n.startswith('encoder')]
+
+    encoder_optimizer = optim.AdamW(encoder_params, lr=encoder_lr, weight_decay=0.01)
+    other_optimizer = optim.AdamW(other_params, lr=learning_rate, weight_decay=0.01)
+
+    encoder_scheduler = CosineAnnealingStabilizeLR(
+        encoder_optimizer,                                # pass the entire optimizer
+        T_max=len(train_loader) * encoder_cosine_epochs,  # schedule length specific to the encoder
+        eta_min=encoder_lr * eta_min_factor_encoder       # minimum learning rate for the encoder
+    )
+    other_scheduler = CosineAnnealingStabilizeLR(
+        other_optimizer,                                  # pass the entire optimizer
+        T_max=len(train_loader) * other_cosine_epochs,    # schedule length for the rest of the model
+        eta_min=learning_rate * eta_min_factor_other      # minimum learning rate for the rest of the model
+    )
+
+    # Log initial learning rates and minimum values
+    wandb.log({
+        "initial_encoder_lr": encoder_lr,
+        "initial_other_lr": learning_rate,
+        "min_encoder_lr": encoder_lr * eta_min_factor_encoder,
+        "min_other_lr": learning_rate * eta_min_factor_other,
+        "encoder_cosine_epochs": encoder_cosine_epochs,
+        "other_cosine_epochs": other_cosine_epochs,
+        "eta_min_factor_encoder": eta_min_factor_encoder,
+        "eta_min_factor_other": eta_min_factor_other
+    })
+
+    # Training loop
+    best_val_loss = float('inf')
+    for epoch in range(num_epochs):
+        print(f"\nEpoch {epoch + 1}/{num_epochs}")
+        #criterion = criterion_encoder
+        # freeze the ConvNeXt after freeze_encoder_epoch epochs
+        if epoch >= freeze_encoder_epoch:
+            #criterion = criterion_no_encoder
+            for param in model.encoder.parameters():
+                param.requires_grad = False
+            print("ConvNeXt encoder frozen")
+
+        # Train
+        train_loss, train_acc = train_epoch(
+            model, train_loader, criterion,
+            (encoder_optimizer, other_optimizer),
+            (encoder_scheduler, other_scheduler), device,
+            epoch=epoch
+        )
+
+        # Validate
+        val_loss, val_acc = validate(
+            model, val_loader, criterion, device
+        )
+
+        # Log metrics
+        wandb.log({
+            "epoch": epoch + 1,
+            "train_loss": train_loss,
+            "train_acc": train_acc,
+            "val_loss": val_loss,
+            "val_acc": val_acc,
+            "encoder_learning_rate": encoder_scheduler.get_last_lr()[0],
+            "other_learning_rate": other_scheduler.get_last_lr()[0],
+            "encoder_frozen": epoch >= freeze_encoder_epoch,
+            "encoder_lr_percentage": (encoder_scheduler.get_last_lr()[0] / encoder_lr) * 100,  # percentage of the initial LR
+            "other_lr_percentage": (other_scheduler.get_last_lr()[0] / learning_rate) * 100    # percentage of the initial LR
+        })
+
+        # Save best model
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            torch.save({
+                'epoch': epoch + 1,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': other_optimizer.state_dict(),
+                'encoder_scheduler_state_dict': encoder_scheduler.state_dict(),
+                'other_scheduler_state_dict': other_scheduler.state_dict(),
+                'val_loss': val_loss,
+                'val_acc': val_acc,
+            }, os.path.join(folder_name, f"best_mil_model.pth"))
+            print(f"Saved new best model with validation loss: {val_loss:.4f}")
+
+    # write json files with the config and the best loss into the folder
+    with open(os.path.join(folder_name, 'config.json'), 'w') as f:
+        json.dump(dict(wandb.config), f)
+    with open(os.path.join(folder_name, 'best_loss.json'), 'w') as f:
+        json.dump({'best_loss': best_val_loss}, f)
+
+    wandb.finish()
+    return model
+
+# launching a training run with the NFN model
+if __name__ == "__main__":
+    # Set device
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Using device: {device}")
+
+    # Set paths
+    data_dir = '../../rsna_challenge_data_split'
+    csv_file = '../../train.csv'
+
+    convnext_small = timm.create_model('convnext_small.fb_in22k_ft_in1k_384',
+                                       in_chans=1, pretrained=True, num_classes=0)
+
+
+    # Train model
+    model = train_model_nfn(
+        convnext_small,
+        data_dir=data_dir,
+        csv_file=csv_file,
+        num_epochs=16,
+        batch_size=4,
+        learning_rate=1e-4,
+        encoder_lr=5e-5,              # lower learning rate for the ConvNeXt
+        freeze_encoder_epoch=4,       # freeze the ConvNeXt after 4 epochs
+        encoder_cosine_epochs=12,     # the ConvNeXt reaches its minimum LR in 12 epochs
+        other_cosine_epochs=12,       # the rest of the model reaches its minimum in 12 epochs
+        eta_min_factor_encoder=0.05,  # the encoder LR decays to 5% of its initial value
+        eta_min_factor_other=0.05,    # the other LR decays to 5% of its initial value
+        num_layers=2,
+        device=device
+    )
\ No newline at end of file
diff --git a/MIL_training/train_mil_sas.py b/MIL_training/train_mil_sas.py
new file mode 100644
index 0000000..5d30ab2
--- /dev/null
+++ b/MIL_training/train_mil_sas.py
@@ -0,0 +1,205 @@
+''' train the SAS MIL model '''
+
+import os
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, ConcatDataset
+from training_utils import CosineAnnealingStabilizeLR, weight_challenge, train_epoch, validate, visualize_batch
+import wandb
+from tqdm import tqdm
+from prepare_sas import prepare_data_sas
+from mil_definition import MILmodel
+import numpy as np
+import matplotlib.pyplot as plt
+import random
+import json
+import math
+import timm
+
+
+# main function to train the SAS MIL model
+def train_model_sas(
+    encoder,
+    data_dir,
+    csv_file,
+    num_epochs=20,
+    batch_size=8,
+    learning_rate=1e-4,
+    encoder_lr=1e-5,             # learning rate specific to the ConvNeXt encoder
+    freeze_encoder_epoch=5,      # epoch from which the ConvNeXt is frozen
+    encoder_cosine_epochs=3,     # epochs for the encoder cosine schedule to reach its minimum
+    other_cosine_epochs=6,       # epochs for the rest of the model to reach its minimum
+    eta_min_factor_encoder=0.04, # factor giving the encoder's eta_min (relative to encoder_lr)
+    eta_min_factor_other=0.04,   # factor giving the rest's eta_min (relative to learning_rate)
+    num_layers=1,
+    device='cuda'
+):
+    # Initialize wandb
+    wandb.init(
+        project="lumbar-mil-sas",
+        config={
+            "epochs": num_epochs,
+            "batch_size": batch_size,
+            "learning_rate": learning_rate,
+            "encoder_lr": encoder_lr,
+            "freeze_encoder_epoch": freeze_encoder_epoch,
+            "encoder_cosine_epochs": encoder_cosine_epochs,
+            "other_cosine_epochs": other_cosine_epochs,
+            "eta_min_factor_encoder": eta_min_factor_encoder,
+            "eta_min_factor_other": eta_min_factor_other,
+            "scheduler": "CosineAnnealing",
+            "architecture": "ConvNeXt-Small-MIL",
+            "num_layers": num_layers
+        }
+    )
+
+    # create the checkpoint folder in the current directory
+    folder_name = f"mil_model_sas"
+    os.makedirs(folder_name, exist_ok=True)
+
+    # Prepare data
+    train_dir = os.path.join(data_dir, 'training')
+    val_dir = os.path.join(data_dir, 'validation')
+
+    # Create datasets
+    train_data = prepare_data_sas(train_dir, csv_file, random=True)
+    val_data = prepare_data_sas(val_dir, csv_file, random=False)
+
+    # Create dataloaders
+    train_loader = DataLoader(train_data, batch_size=batch_size,
+                              shuffle=True, num_workers=0)
+    val_loader = DataLoader(val_data, batch_size=batch_size,
+                            shuffle=False, num_workers=0)
+
+    # Initialize the model with the encoder passed as argument
+    model = MILmodel(encoder=encoder, num_layers=num_layers).to(device)
+
+    # Loss function - CrossEntropyLoss with the challenge class weights
+    #criterion_encoder = nn.CrossEntropyLoss()
+    #criterion_no_encoder = nn.CrossEntropyLoss(weight=weight_challenge)
+    criterion = nn.CrossEntropyLoss(weight=weight_challenge)
+
+    # separate the ConvNeXt parameters from the rest of the model
+    encoder_params = model.encoder.parameters()
+    other_params = [p for n, p in model.named_parameters() if not n.startswith('encoder')]
+
+    encoder_optimizer = optim.AdamW(encoder_params, lr=encoder_lr, weight_decay=0.01)
+    other_optimizer = optim.AdamW(other_params, lr=learning_rate, weight_decay=0.01)
+
+    encoder_scheduler = CosineAnnealingStabilizeLR(
+        encoder_optimizer,                                # pass the entire optimizer
+        T_max=len(train_loader) * encoder_cosine_epochs,  # schedule length specific to the encoder
+        eta_min=encoder_lr * eta_min_factor_encoder       # minimum learning rate for the encoder
+    )
+    other_scheduler = CosineAnnealingStabilizeLR(
+        other_optimizer,                                  # pass the entire optimizer
+        T_max=len(train_loader) * other_cosine_epochs,    # schedule length for the rest of the model
+        eta_min=learning_rate * eta_min_factor_other      # minimum learning rate for the rest of the model
+    )
+
+    # Log initial learning rates and minimum values
+    wandb.log({
+        "initial_encoder_lr": encoder_lr,
+        "initial_other_lr": learning_rate,
+        "min_encoder_lr": encoder_lr * eta_min_factor_encoder,
+        "min_other_lr": learning_rate * eta_min_factor_other,
+        "encoder_cosine_epochs": encoder_cosine_epochs,
+        "other_cosine_epochs": other_cosine_epochs,
+        "eta_min_factor_encoder": eta_min_factor_encoder,
+        "eta_min_factor_other": eta_min_factor_other
+    })
+
+    # Training loop
+    best_val_loss = float('inf')
+    for epoch in range(num_epochs):
+        print(f"\nEpoch {epoch + 1}/{num_epochs}")
+        #criterion = criterion_encoder
+        # freeze the ConvNeXt after freeze_encoder_epoch epochs
+        if epoch >= freeze_encoder_epoch:
+            #criterion = criterion_no_encoder
+            for param in model.encoder.parameters():
+                param.requires_grad = False
+            print("ConvNeXt encoder frozen")
+
+        # Train
+        train_loss, train_acc = train_epoch(
+            model, train_loader, criterion,
+            (encoder_optimizer, other_optimizer),
+            (encoder_scheduler, other_scheduler), device,
+            epoch=epoch
+        )
+
+        # Validate
+        val_loss, val_acc = validate(
+            model, val_loader, criterion, device
+        )
+
+        # Log metrics
+        wandb.log({
+            "epoch": epoch + 1,
+            "train_loss": train_loss,
+            "train_acc": train_acc,
+            "val_loss": val_loss,
+            "val_acc": val_acc,
+            "encoder_learning_rate": encoder_scheduler.get_last_lr()[0],
+            "other_learning_rate": other_scheduler.get_last_lr()[0],
+            "encoder_frozen": epoch >= freeze_encoder_epoch,
+            "encoder_lr_percentage": (encoder_scheduler.get_last_lr()[0] / encoder_lr) * 100,  # percentage of the initial LR
+            "other_lr_percentage": (other_scheduler.get_last_lr()[0] / learning_rate) * 100    # percentage of the initial LR
+        })
+
+        # Save best model
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            torch.save({
+                'epoch': epoch + 1,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': other_optimizer.state_dict(),
+                'encoder_scheduler_state_dict': encoder_scheduler.state_dict(),
+                'other_scheduler_state_dict': other_scheduler.state_dict(),
+                'val_loss': val_loss,
+                'val_acc': val_acc,
+            }, os.path.join(folder_name, f"best_mil_model.pth"))
+            print(f"Saved new best model with validation loss: {val_loss:.4f}")
+
+    # write json files with the config and the best loss into the folder
+    with open(os.path.join(folder_name, 'config.json'), 'w') as f:
+        json.dump(dict(wandb.config), f)
+    with open(os.path.join(folder_name, 'best_loss.json'), 'w') as f:
+        json.dump({'best_loss': best_val_loss}, f)
+
+    wandb.finish()
+    return model
+
+# launching a training run with the SAS model
+if __name__ == "__main__":
+    # Set device
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Using device: {device}")
+
+    # Set paths
+    data_dir = '../../rsna_challenge_data_split'
+    csv_file = '../../train.csv'
+
+    convnext_small = timm.create_model('convnext_small.fb_in22k_ft_in1k_384',
+                                       in_chans=1, pretrained=True, num_classes=0)
+
+
+    # Train model
+    model = train_model_sas(
+        convnext_small,
+        data_dir=data_dir,
+        csv_file=csv_file,
+        num_epochs=16,
+        batch_size=2,
+        learning_rate=0.00005,
+        encoder_lr=0.00005,           # learning rate for the ConvNeXt (same as the rest here)
+        freeze_encoder_epoch=4,       # freeze the ConvNeXt after 4 epochs
+        encoder_cosine_epochs=12,     # the ConvNeXt reaches its minimum LR in 12 epochs
+        other_cosine_epochs=12,       # the rest of the model reaches its minimum in 12 epochs
+        eta_min_factor_encoder=0.05,  # the encoder LR decays to 5% of its initial value
+        eta_min_factor_other=0.05,    # the other LR decays to 5% of its initial value
+        num_layers=2,
+        device=device
+    )
\ No newline at end of file
diff --git a/MIL_training/train_mil_scs.py b/MIL_training/train_mil_scs.py
new file mode 100644
index 0000000..70d146c
--- /dev/null
+++ b/MIL_training/train_mil_scs.py
@@ -0,0 +1,205 @@
+''' train the SCS MIL model '''
+
+import os
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, ConcatDataset
+from training_utils import CosineAnnealingStabilizeLR, weight_challenge, train_epoch, validate, visualize_batch
+import wandb
+from tqdm import tqdm
+from prepare_scs import prepare_data_scs
+from mil_definition import MILmodel
+import numpy as np
+import matplotlib.pyplot as plt
+import random
+import json
+import math
+import timm
+
+
+# main function to train the SCS MIL model
+def train_model_scs(
+    encoder,
+    data_dir,
+    csv_file,
+    num_epochs=20,
+    batch_size=8,
+    learning_rate=1e-4,
+    encoder_lr=1e-5,             # learning rate specific to the ConvNeXt encoder
+    freeze_encoder_epoch=5,      # epoch from which the ConvNeXt is frozen
+    encoder_cosine_epochs=3,     # epochs for the encoder cosine schedule to reach its minimum
+    other_cosine_epochs=6,       # epochs for the rest of the model to reach its minimum
+    eta_min_factor_encoder=0.04, # factor giving the encoder's eta_min (relative to encoder_lr)
+    eta_min_factor_other=0.04,   # factor giving the rest's eta_min (relative to learning_rate)
+    num_layers=1,
+    device='cuda'
+):
+    # Initialize wandb
+    wandb.init(
+        project="lumbar-mil-scs",
+        config={
+            "epochs": num_epochs,
+            "batch_size": batch_size,
+            "learning_rate": learning_rate,
+            "encoder_lr": encoder_lr,
+            "freeze_encoder_epoch": freeze_encoder_epoch,
+            "encoder_cosine_epochs": encoder_cosine_epochs,
+            "other_cosine_epochs": other_cosine_epochs,
+            "eta_min_factor_encoder": eta_min_factor_encoder,
+            "eta_min_factor_other": eta_min_factor_other,
+            "scheduler": "CosineAnnealing",
+            "architecture": "ConvNeXt-Small-MIL",
+            "num_layers": num_layers
+        }
+    )
+
+    # create the checkpoint folder in the current directory
+    folder_name = f"mil_model_scs"
+    os.makedirs(folder_name, exist_ok=True)
+
+    # Prepare data
+    train_dir = os.path.join(data_dir, 'training')
+    val_dir = os.path.join(data_dir, 'validation')
+
+    # Create datasets
+    train_data = prepare_data_scs(train_dir, csv_file, random=True)
+    val_data = prepare_data_scs(val_dir, csv_file, random=False)
+
+    # Create dataloaders
+    train_loader = DataLoader(train_data, batch_size=batch_size,
+                              shuffle=True, num_workers=0)
+    val_loader = DataLoader(val_data, batch_size=batch_size,
+                            shuffle=False, num_workers=0)
+
+    # Initialize the model with the encoder passed as argument
+    model = MILmodel(encoder=encoder, num_layers=num_layers).to(device)
+
+    # Loss function - CrossEntropyLoss with the challenge class weights
+    #criterion_encoder = nn.CrossEntropyLoss()
+    #criterion_no_encoder = nn.CrossEntropyLoss(weight=weight_challenge)
+    criterion = nn.CrossEntropyLoss(weight=weight_challenge)
+
+    # separate the ConvNeXt parameters from the rest of the model
+    encoder_params = model.encoder.parameters()
+    other_params = [p for n, p in model.named_parameters() if not n.startswith('encoder')]
+
+    encoder_optimizer = optim.AdamW(encoder_params, lr=encoder_lr, weight_decay=0.01)
+    other_optimizer = optim.AdamW(other_params, lr=learning_rate, weight_decay=0.01)
+
+    encoder_scheduler = CosineAnnealingStabilizeLR(
+        encoder_optimizer,                                # pass the entire optimizer
+        T_max=len(train_loader) * encoder_cosine_epochs,  # schedule length specific to the encoder
+        eta_min=encoder_lr * eta_min_factor_encoder       # minimum learning rate for the encoder
+    )
+    other_scheduler = CosineAnnealingStabilizeLR(
+        other_optimizer,                                  # pass the entire optimizer
+        T_max=len(train_loader) * other_cosine_epochs,    # schedule length for the rest of the model
+        eta_min=learning_rate * eta_min_factor_other      # minimum learning rate for the rest of the model
+    )
+
+    # Log initial learning rates and minimum values
+    wandb.log({
+        "initial_encoder_lr": encoder_lr,
+        "initial_other_lr": learning_rate,
+        "min_encoder_lr": encoder_lr * eta_min_factor_encoder,
+        "min_other_lr": learning_rate * eta_min_factor_other,
+        "encoder_cosine_epochs": encoder_cosine_epochs,
+        "other_cosine_epochs": other_cosine_epochs,
+        "eta_min_factor_encoder": eta_min_factor_encoder,
+        "eta_min_factor_other": eta_min_factor_other
+    })
+
+    # Training loop
+    best_val_loss = float('inf')
+    for epoch in range(num_epochs):
+        print(f"\nEpoch {epoch + 1}/{num_epochs}")
+        #criterion = criterion_encoder
+        # freeze the ConvNeXt after freeze_encoder_epoch epochs
+        if epoch >= freeze_encoder_epoch:
+            #criterion = criterion_no_encoder
+            for param in model.encoder.parameters():
+                param.requires_grad = False
+            print("ConvNeXt encoder frozen")
+
+        # Train
+        train_loss, train_acc = train_epoch(
+            model, train_loader, criterion,
+            (encoder_optimizer, other_optimizer),
+            (encoder_scheduler, other_scheduler), device,
+            epoch=epoch
+        )
+
+        # Validate
+        val_loss, val_acc = validate(
+            model, val_loader, criterion, device
+        )
+
+        # Log metrics
+        wandb.log({
+            "epoch": epoch + 1,
+            "train_loss": train_loss,
+            "train_acc": train_acc,
+            "val_loss": val_loss,
+            "val_acc": val_acc,
+            "encoder_learning_rate": encoder_scheduler.get_last_lr()[0],
+            "other_learning_rate": other_scheduler.get_last_lr()[0],
+            "encoder_frozen": epoch >= freeze_encoder_epoch,
+            "encoder_lr_percentage": (encoder_scheduler.get_last_lr()[0] / encoder_lr) * 100,  # percentage of the initial LR
+            "other_lr_percentage": (other_scheduler.get_last_lr()[0] / learning_rate) * 100    # percentage of the initial LR
+        })
+
+        # Save best model
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            torch.save({
+                'epoch': epoch + 1,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': other_optimizer.state_dict(),
+                'encoder_scheduler_state_dict': encoder_scheduler.state_dict(),
+                'other_scheduler_state_dict': other_scheduler.state_dict(),
+                'val_loss': val_loss,
+                'val_acc': val_acc,
+            }, os.path.join(folder_name, f"best_mil_model.pth"))
+            print(f"Saved new best model with validation loss: {val_loss:.4f}")
+
+    # write json files with the config and the best loss into the folder
+    with open(os.path.join(folder_name, 'config.json'), 'w') as f:
+        json.dump(dict(wandb.config), f)
+    with open(os.path.join(folder_name, 'best_loss.json'), 'w') as f:
+        json.dump({'best_loss': best_val_loss}, f)
+
+    wandb.finish()
+    return model
+
+# launching a training run with the SCS model
+if __name__ == "__main__":
+    # Set device
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Using device: {device}")
+
+    # Set paths
+    data_dir = '../../rsna_challenge_data_split'
+    csv_file = '../../train.csv'
+
+    convnext_small = timm.create_model('convnext_small.fb_in22k_ft_in1k_384',
+                                       in_chans=1, pretrained=True, num_classes=0)
+
+
+    # Train model
+    model = train_model_scs(
+        convnext_small,
+        data_dir=data_dir,
+        csv_file=csv_file,
+        num_epochs=16,
+        batch_size=16,
+        learning_rate=5e-4,
+        encoder_lr=5e-7,              # much lower learning rate for the ConvNeXt
+        freeze_encoder_epoch=4,       # freeze the ConvNeXt after 4 epochs
+        encoder_cosine_epochs=12,     # the ConvNeXt reaches its minimum LR in 12 epochs
+        other_cosine_epochs=12,       # the rest of the model reaches its minimum in 12 epochs
+        eta_min_factor_encoder=0.05,  # the encoder LR decays to 5% of its initial value
+        eta_min_factor_other=0.05,    # the other LR decays to 5% of its initial value
+        num_layers=2,
+        device=device
+    )
\ No newline at end of file
diff --git a/MIL_training/training_utils.py b/MIL_training/training_utils.py
new file mode 100644
index 0000000..4b92157
--- /dev/null
+++ b/MIL_training/training_utils.py
@@ -0,0 +1,167 @@
+''' utility functions for training the models '''
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim.lr_scheduler import _LRScheduler
+from tqdm import tqdm
+import math
+import matplotlib.pyplot as plt +import wandb + + +# use the challenge's loss function : weighted cross entropy +# with weights 1, 2, 4 for the 3 classes +weight_challenge = torch.tensor([1.0, 2.0, 4.0]).cuda() + +# custom learning rate scheduler +class CosineAnnealingStabilizeLR(_LRScheduler): + def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1): + self.T_max = T_max + self.eta_min = eta_min + super(CosineAnnealingStabilizeLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch >= self.T_max: + return [self.eta_min for _ in self.base_lrs] + + return [self.eta_min + (base_lr - self.eta_min) * + (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 + for base_lr in self.base_lrs] + +# function to plot the first batch of each epoch on wandb +def visualize_batch(batch, epoch): + """ + Visualize a batch of images and save them to wandb + batch: dictionary containing 'bag' tensor of shape [B, 6, 1, 384, 384] and 'label' + epoch: current epoch number + """ + try: + # Get the first batch and ensure it's on CPU + images = batch['bag'].cpu().detach() # Shape: [B, 6, 1, 384, 384] + labels = batch['label'].cpu().detach() + + # Take only the first 4 samples to avoid too large figures + n_samples = min(4, images.shape[0]) + + # Create a figure with subplots for each sample and its 6 slices + fig, axes = plt.subplots(n_samples, 6, figsize=(20, 4*n_samples)) + if n_samples == 1: + axes = axes[None, :] # Add dimension for consistent indexing + + for i in range(n_samples): + for j in range(6): + # Get the image slice and ensure it's a valid image + img = images[i, j, 0].numpy() + + # Normalize the image for better visualization + img = (img - img.min()) / (img.max() - img.min() + 1e-8) + + # Plot the image + axes[i, j].imshow(img, cmap='gray') + axes[i, j].axis('off') + + # Add title only to the first row + if i == 0: + axes[i, j].set_title(f'Slice {j+1}') + + # Add label information on the left + axes[i, 0].set_ylabel(f'Sample {i+1}\nLabel: {labels[i].item()}') + + plt.tight_layout() + + # Log to wandb + wandb.log({f"batch_visualization_epoch_{epoch}": wandb.Image(fig)}) + plt.close(fig) + except Exception as e: + print(f"Warning: Could not visualize batch: {str(e)}") + plt.close('all') # Ensure all figures are closed in case of error + + +def train_epoch( + model, + train_loader, + criterion, + optimizers, # Tuple: (encoder_optimizer, other_optimizer) + schedulers, # Tuple: (encoder_scheduler, other_scheduler) + device, + epoch=None +): + model.train() + running_loss = 0.0 + correct = 0 + total = 0 + + encoder_optimizer, other_optimizer = optimizers + encoder_scheduler, other_scheduler = schedulers + + pbar = tqdm(train_loader, desc='Training') + for i, batch in enumerate(pbar): + bags = batch['bag'].to(device) + labels = batch['label'].to(device) + + if i == 0 and epoch is not None: + visualize_batch({k: v.cpu() for k, v in batch.items()}, epoch) + + encoder_optimizer.zero_grad() + other_optimizer.zero_grad() + + main_output = model(bags) + loss = criterion(main_output, labels) + + loss.backward() + + encoder_optimizer.step() + other_optimizer.step() + + _, predicted = torch.max(main_output, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + running_loss += loss.item() + + pbar.set_postfix({ + 'loss': f'{loss.item():.4f}', + 'acc': f'{100 * correct / total:.2f}%' + }) + + encoder_scheduler.step() + other_scheduler.step() + + epoch_loss = running_loss / len(train_loader) + acc = 100 * correct / total + + return epoch_loss, acc + + 
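+# --- Hedged example (illustrative only, not called by the training pipeline) ---
+# A minimal sketch of how CosineAnnealingStabilizeLR above behaves: the LR follows
+# a cosine decay for T_max steps, then stays pinned at eta_min instead of restarting.
+def _demo_cosine_stabilize_lr(T_max=10, base_lr=1e-4, eta_min=1e-5, steps=15):
+    dummy_opt = optim.AdamW([nn.Parameter(torch.zeros(1))], lr=base_lr)
+    sched = CosineAnnealingStabilizeLR(dummy_opt, T_max=T_max, eta_min=eta_min)
+    lrs = []
+    for _ in range(steps):
+        lrs.append(sched.get_last_lr()[0])
+        dummy_opt.step()
+        sched.step()
+    # lrs[0] == base_lr, and every value from index T_max onwards equals eta_min
+    return lrs
+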
+
+# Function to validate the model
+@torch.no_grad()
+def validate(model, val_loader, criterion, device):
+    model.eval()
+    running_loss = 0.0
+    correct = 0
+    total = 0
+
+    for batch in tqdm(val_loader, desc='Validation'):
+        # Get data
+        bags = batch['bag'].to(device)
+        labels = batch['label'].to(device)
+
+        # Forward pass
+        main_output = model(bags)
+
+        # Calculate losses
+        loss = criterion(main_output, labels)
+
+        # Calculate accuracy
+        _, predicted = torch.max(main_output, 1)
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
+
+        running_loss += loss.item()
+
+    # Calculate epoch statistics
+    val_loss = running_loss / len(val_loader)
+    acc = 100 * correct / total
+
+    return val_loss, acc
\ No newline at end of file
diff --git a/README.md b/README.md
index 0e79930..ede3344 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,18 @@
 # Team Neuropoly: RSNA 2024 Lumbar Spine Degenerative Classification Challenge
-TODO
+This branch of the repo uses the data produced by the preprocessing pipeline to train classification models that predict the severity of the pathologies.
+We first preprocess the data for each pathology in the corresponding "prepare" file as follows:
+- we start from a 3D patch of constant physical size
+- we resample it and apply various random transforms for data augmentation
+- finally we build a bag of 2D slices for the MIL architecture:
+
+![Preprocessing example](images/patch2bag.png "Example of data preprocessing")
+
+Then we train a model for each pathology using the corresponding "train" file; the training functions live in training_utils.py and the model definition in mil_definition.py.
+
+The models have the following architecture:
+
+![Model architecture](images/mil.png "MIL model architecture")
+
+Images are encoded into latent vectors, which are then processed sequentially by a bidirectional RNN. Finally, an attention layer outputs a weight for each slice, the vectors are summed using those normalized weights, and the pooled vector yields a severity prediction.
\ No newline at end of file
diff --git a/images/mil.png b/images/mil.png
new file mode 100644
index 0000000..d1a79c5
Binary files /dev/null and b/images/mil.png differ
diff --git a/images/patch2bag.png b/images/patch2bag.png
new file mode 100644
index 0000000..24eb402
Binary files /dev/null and b/images/patch2bag.png differ
diff --git a/inference_mil.py b/inference_mil.py
new file mode 100644
index 0000000..01c24c2
--- /dev/null
+++ b/inference_mil.py
@@ -0,0 +1,181 @@
+''' Run inference on the validation set for the SAS and SCS models '''
+# TODO: inference should also be added for NFN.
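+
+# --- Editor's note: minimal sketch of the MIL forward pass, not part of this
+# script. It mirrors the architecture described in the README, starting from
+# per-slice encoder features: bidirectional GRU -> attention-weighted sum ->
+# severity head. The real definition is MILmodel in mil_definition.py; the
+# RNN type, feature and hidden sizes here are assumptions.
+import torch as _torch
+import torch.nn as _nn
+
+class _MILSketch(_nn.Module):
+    def __init__(self, feat_dim=768, hidden=256, num_classes=3):
+        super().__init__()
+        self.rnn = _nn.GRU(feat_dim, hidden, bidirectional=True, batch_first=True)
+        self.attn = _nn.Linear(2 * hidden, 1)
+        self.head = _nn.Linear(2 * hidden, num_classes)
+
+    def forward(self, feats):                      # feats: [B, n_slices, feat_dim]
+        h, _ = self.rnn(feats)                     # [B, n_slices, 2*hidden]
+        w = _torch.softmax(self.attn(h), dim=1)    # normalized per-slice weights
+        pooled = (w * h).sum(dim=1)                # attention-weighted sum
+        return self.head(pooled)                   # severity logits
+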
+
+import os
+import torch
+import json
+from torch.utils.data import DataLoader, ConcatDataset
+# prepare_data_sas_severe is used below and is assumed to live in prepare_data_mil
+from prepare_data_mil import prepare_data_scs, prepare_data_sas, prepare_data_sas_option, prepare_data_nfn, prepare_data_sas_severe
+from mil_definition import MILmodel, convnext_small
+from train_mil import weight_challenge, run_inference_on_validation_set
+import numpy as np
+from sklearn.metrics import confusion_matrix
+import seaborn as sns
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import csv
+import torch.nn as nn
+
+
+def plot_confusion_matrices(y_true, y_pred, save_dir):
+    """
+    Plot two confusion matrices: one with raw counts and one with percentages
+    """
+    # Build the confusion matrix with raw counts
+    cm = confusion_matrix(y_true, y_pred)
+
+    # Create the figure with two subplots
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
+
+    # First subplot: raw counts
+    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax1)
+    ax1.set_title('Confusion Matrix (Raw Counts)')
+    ax1.set_xlabel('Predicted')
+    ax1.set_ylabel('True')
+
+    # Second subplot: row-normalized percentages
+    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
+    sns.heatmap(cm_percent, annot=True, fmt='.1f', cmap='Blues', ax=ax2)
+    ax2.set_title('Confusion Matrix (Percentages)')
+    ax2.set_xlabel('Predicted')
+    ax2.set_ylabel('True')
+
+    plt.tight_layout()
+    plt.savefig(os.path.join(save_dir, 'confusion_matrices.png'))
+    plt.close()
+
+def save_predicted_values(predictions, labels, save_dir):
+    """
+    Save the predicted probabilities and labels to a CSV file
+    """
+    # apply a softmax so the raw logits become normalized probabilities
+    predictions = np.array(predictions)
+    predictions = np.exp(predictions) / np.sum(np.exp(predictions), axis=1, keepdims=True)
+
+    with open(os.path.join(save_dir, 'predictions.csv'), 'w') as f:
+        writer = csv.writer(f)
+        writer.writerow(['prediction', 'label'])
+        # each row stores the full probability vector and the true label
+        for pred, label in zip(predictions, labels):
+            writer.writerow([pred, label])
+
+
+@torch.no_grad()
+def run_inference(model, val_loader, device):
+    """
+    Run inference on the validation set
+    """
+    model.eval()
+    all_preds = []
+    all_probs = []
+    all_labels = []
+
+    for batch in tqdm(val_loader, desc='Inference'):
+        # Get data
+        bags = batch['bag'].to(device)
+        labels = batch['label'].to(device)
+
+        # Forward pass (the model is assumed here to return a
+        # (logits, attention weights) tuple, unlike in validate())
+        main_output, _ = model(bags)
+
+        # Get predictions
+        _, predicted = torch.max(main_output, 1)
+
+        # Store predictions and labels
+        all_preds.extend(predicted.cpu().numpy())
+        all_labels.extend(labels.cpu().numpy())
+        all_probs.extend(main_output.cpu().numpy())
+    return np.array(all_preds), np.array(all_labels), np.array(all_probs)
+
+def run_inference_on_severe_cases(model_path, data_dir, csv_file, device='cuda'):
+    """
+    Run inference on severe cases using a pretrained model and calculate CV.
+
+    Args:
+        model_path (str): Path to the pretrained model (.pth file)
+        data_dir (str): Directory containing the data
+        csv_file (str): Path to the CSV file with labels
+        device (str): Device to run inference on ('cuda' or 'cpu')
+
+    Returns:
+        float: Coefficient of variation for severe cases
+    """
+    # Load model
+    model = MILmodel(encoder=convnext_small, num_layers=2).to(device)
+    checkpoint = torch.load(model_path, map_location=device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    print(f"Loaded model from epoch {checkpoint['epoch']}")
+
+    # Prepare data
+    val_dir = os.path.join(data_dir, 'validation')
+    val_data_left, val_data_right = prepare_data_sas_severe(val_dir, csv_file, random=False)
+    val_data = ConcatDataset([val_data_left, val_data_right])
+
+    severe_loader = DataLoader(val_data,
+                               batch_size=2,
+                               shuffle=False,
+                               num_workers=8)
+
+    criterion = nn.CrossEntropyLoss(weight=weight_challenge)
+
+    # Run inference and calculate CV
+    cv, mean_loss, var_loss = run_inference_on_validation_set(model, severe_loader, device, criterion)
+    print(f"Coefficient of variation for severe cases: {cv:.4f}, mean loss: {mean_loss:.4f}, var loss: {var_loss:.4f}")
+
+    return cv
+
+def main():
+    # Configuration
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Using device: {device}")
+
+    # Paths (the NFN paths are kept for reference but overridden by the SAS
+    # paths below)
+    # model_dir = 'mil_model_nfn711201'
+    # data_dir = '../../duke/public/rsna_challenge/20250408nii_data'
+    # csv_file = '../../duke/public/rsna_challenge/dcom_data/train.csv'
+    model_dir = '/home/ge.polymtl.ca/p121315/rsna_git/lumbar-classification-rsna-challenge-2024/mil_model_sas247141'
+    data_dir = '/home/ge.polymtl.ca/p121315/duke/public/rsna_challenge/20250212nii_data_splits'
+    csv_file = '/home/ge.polymtl.ca/p121315/duke/public/rsna_challenge/dcom_data/train.csv'
+
+    # Load configuration
+    with open(os.path.join(model_dir, 'config.json'), 'r') as f:
+        config = json.load(f)
+
+    # Create validation dataset and dataloader
+    # NOTE: the data is prepared with prepare_data_nfn while the weights come
+    # from a SAS run; double-check the model/data pairing before relying on it
+    val_dir = os.path.join(data_dir, 'validation')
+    val_data = prepare_data_nfn(val_dir, csv_file, random=False)
+    val_loader = DataLoader(val_data,
+                            batch_size=config['batch_size'],
+                            shuffle=False,
+                            num_workers=4)
+
+    # Initialize model
+    model = MILmodel(encoder=convnext_small, num_layers=config['num_layers']).to(device)
+
+    # Load model weights
+    checkpoint = torch.load(os.path.join(model_dir, 'best_mil_model.pth'), map_location=device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    print(f"Loaded model from epoch {checkpoint['epoch']} with validation loss {checkpoint['val_loss']:.4f}")
+
+    # Run inference
+    predictions, labels, probs = run_inference(model, val_loader, device)
+
+    # Plot and save confusion matrices
+    plot_confusion_matrices(labels, predictions, model_dir)
+    print(f"Confusion matrices saved to {model_dir}/confusion_matrices.png")
+
+    # Save predicted values
+    save_predicted_values(probs, labels, model_dir)
+
+    # Calculate and print accuracy
+    accuracy = (predictions == labels).mean() * 100
+    print(f"\nValidation Accuracy: {accuracy:.2f}%")
+
+
+if __name__ == "__main__":
+    # Example usage
+    model_path = "/home/ge.polymtl.ca/p121315/rsna_git/lumbar-classification-rsna-challenge-2024/models/mil_models/mil_model_sas566773/best_mil_model.pth"
+    data_dir = '../../duke/public/rsna_challenge/20250212nii_data_splits'
+    csv_file = '../../duke/public/rsna_challenge/dcom_data/train.csv'
+
+    run_inference_on_severe_cases(model_path, data_dir, csv_file)
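+
+# --- Editor's note: illustrative sketch, not part of the original script. ---
+# run_inference_on_validation_set (imported from train_mil) returns a
+# coefficient of variation together with the mean and variance of the loss.
+# The CV of a set of values is their standard deviation divided by their
+# mean; a minimal version of that statistic, assuming it is computed over
+# per-sample losses:
+def cv_of_losses_sketch(losses):
+    losses = np.asarray(losses, dtype=np.float64)
+    mean_loss = losses.mean()
+    var_loss = losses.var()
+    cv = np.sqrt(var_loss) / mean_loss   # std / mean
+    return cv, mean_loss, var_loss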