Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
6a68533
add a gitignore
Nov 21, 2024
69357c0
remove of undesired files
Nov 21, 2024
4778473
training code for sbs added
Nov 26, 2024
1cd1ecc
add a code to train for nfn with both left and right, rnfn with only …
Nov 27, 2024
9cb7783
Realised the crop and resampling were not done on the good coordinate…
Nov 28, 2024
3132064
tried to change a few things for nfn
Nov 30, 2024
f10b2ca
Modification of nfn to take into account the fact the images are not …
Dec 1, 2024
0b930b0
minor adjustment
Dec 2, 2024
bbed1e6
the nfn takes only T1 scans
Dec 2, 2024
3d7dfb1
Tested a few variations with data augmentation in the transform function
Dec 3, 2024
5be05ea
added a weight decay for training to reduce overfitting
Dec 3, 2024
346ffe9
created a first version of the code for sas and modifications on nfn
Dec 4, 2024
ffb619b
try to change from Adam optimizer to SGD for scs
Dec 4, 2024
ee44d98
Changed the classic block to bottelneck
Dec 4, 2024
39e032e
before trying to do patch extraction in preprocessing
Dec 5, 2024
158c9c1
add a file to train using volume and data augmentation
abelsalm Dec 10, 2024
249a73d
added train,ing with data augmentation, see issue associated
Dec 11, 2024
cea2e05
first data augmentation plots added in models (.pth still in .gitignore)
Dec 13, 2024
36f8ad4
gitignore changed
Dec 13, 2024
0dafdff
nfn training with volumes and data aug
abelsalm Dec 14, 2024
b994273
nfn training with data aug works, also added the curves of the trainings
Dec 14, 2024
719f67a
nfn data aug working
Dec 14, 2024
fccfbfb
sas training with volumes and data aug
Dec 15, 2024
84e6dc1
training with splits among subjects, and data aug
Jan 15, 2025
9106ce9
added a code to train nfn based on the patient split
Jan 26, 2025
be313fb
The code for train_nfn_patient_split didn t work but I fixed it
Jan 27, 2025
379760c
wandb configured! see in the weight and biases folder...
Jan 27, 2025
90685ae
Merge remote-tracking branch 'origin/ResNetTraining' into ResNetTraining
Jan 27, 2025
75456bd
modif split
Jan 27, 2025
6470608
Merge branch 'ResNetTraining' of https://github.com/ivadomed/lumbar-c…
Jan 27, 2025
830ad27
Cleaning WIP
Jan 27, 2025
972a060
cutmix on scs
Jan 28, 2025
49cd566
Merge branch 'ResNetTraining' of https://github.com/ivadomed/lumbar-c…
Jan 28, 2025
2087ad8
corrected nfn
Jan 28, 2025
8831550
tests with cutmix
Jan 28, 2025
bc0726e
Merge branch 'ResNetTraining' of github.com:ivadomed/lumbar-classific…
Jan 28, 2025
9f01fde
cutmix manually
Jan 29, 2025
dfd5673
modifications on images for train_nfn_patient
Jan 29, 2025
189a39c
Merge branch 'ResNetTraining' of https://github.com/ivadomed/lumbar-c…
Jan 29, 2025
e64c52e
cutmix and mixup working
Jan 29, 2025
0f2fbff
Merge branch 'ResNetTraining' of github.com:ivadomed/lumbar-classific…
Jan 29, 2025
6f41b62
cutmix corrected
Jan 30, 2025
c598da8
from adam optimizer to adamw
Jan 31, 2025
12ea4d2
made a code to compute inference on validation split and compute conf…
Jan 31, 2025
ec1ce35
supression of wb files
abelsalm Feb 5, 2025
96ffa8a
minor changes
Feb 5, 2025
fcbfc1d
Merge branch 'ResNetTraining' of github.com:ivadomed/lumbar-classific…
Feb 5, 2025
b59e9e5
working version before removing the split part if I need to come back…
Feb 8, 2025
3e5bccf
Merge branch 'ResNetTraining' of https://github.com/ivadomed/lumbar-c…
Feb 8, 2025
b06af32
first commit, training function done
abelsalm Feb 14, 2025
62b4b69
training working
Feb 14, 2025
01ab7c0
ConvNextAxial
abelsalm Feb 16, 2025
0b9b9c9
prepared for training
abelsalm Feb 16, 2025
55f1078
prepare data works, model an training still to be tested !
abelsalm Mar 16, 2025
581226d
maj gitignore
Mar 17, 2025
abb76fd
adjustment
Mar 17, 2025
e482a74
adding model mil
abelsalm Mar 17, 2025
0a694ed
training works ! also added a bit more data augmentation...
Mar 17, 2025
a43b024
tested data aug, works well !
abelsalm Mar 17, 2025
492ed64
gitignore
Mar 17, 2025
f508204
version with RNN working
Mar 17, 2025
480cc4d
adding different schedule for encoder, including freezing it...
Mar 18, 2025
71b2aa1
up to date
Apr 3, 2025
2364237
removed CustomDataset
Apr 8, 2025
5e24670
insertion of regular_transforms in the gat_transforms and added the n…
Apr 8, 2025
371b60e
added the training for nfn
Apr 8, 2025
dc419e8
added the versions of training for nfn
Apr 9, 2025
8f231b9
minor modif
Apr 10, 2025
38294e7
made a loss without weights for training the encoder
Apr 24, 2025
6595a9d
final push before checkout
May 5, 2025
ce91439
last changes on MIL, to be cleaned still
May 12, 2025
acab887
Merge branch 'MILtraining' of github.com:ivadomed/lumbar-classificati…
abelsalm May 12, 2025
4cbd427
nfn training
abelsalm May 12, 2025
c2f562c
cleaning bien commencé, réorganisation et readme, plus retirer des fo…
abelsalm May 12, 2025
f4519f3
rdme
abelsalm May 12, 2025
490541c
rdme
abelsalm May 12, 2025
23f63d4
everything works
abelsalm May 16, 2025
c3e1fcf
removed old files
Jun 25, 2025
1b5a2a6
put everything related to mil training in MIL_training. Also standard…
Jun 25, 2025
b16a454
removed duplicated files
tomDag25 Jun 30, 2025
977f128
removed the distinction between left and right
tomDag25 Jun 30, 2025
19bbe23
harmonization of the code between the three pathologies
tomDag25 Jun 30, 2025
d32f0a1
Code cleaning
tomDag25 Jun 30, 2025
d0cd246
changed the padding moment to not have only noise slices
Jul 7, 2025
25a19b4
test
Jul 7, 2025
4c601d4
validate merge
Jul 7, 2025
d210b81
standardization of the training code based on nfn
Jul 7, 2025
579031e
hyperparameters modification
Jul 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
totalspineseg/
__pycache__
output.csv
dataset_split.csv
model*
*.pth
20*
models/
**/wandb/
**/*run*/
myenv/
saved_batches/
*.csv
mil_model_*
224 changes: 224 additions & 0 deletions MIL_training/augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import sys, argparse, textwrap
import multiprocessing as mp
from functools import partial
from tqdm.contrib.concurrent import process_map
from pathlib import Path
import nibabel as nib
import numpy as np
import torchio as tio
import gryds
import scipy.ndimage as ndi
from scipy.stats import norm
import warnings
from augment import *



warnings.filterwarnings("ignore")


def aug_histogram_equalization(image1, seg, image2):
img_min1, img_max1 = image1.min(), image1.max()
img_min2, img_max2 = image2.min(), image2.max()

image1_flattened = image1.flatten()
hist1, bins1 = np.histogram(image1_flattened, bins=256, range=[image1_flattened.min(), image1_flattened.max()])
cdf1 = hist1.cumsum()
cdf_normalized1 = cdf1 * (hist1.max() / cdf1.max())
image1 = np.interp(image1_flattened, bins1[:-1], cdf_normalized1).reshape(image1.shape)
image1 = np.interp(image1, (image1.min(), image1.max()), (img_min1, img_max1))

image2_flattened = image2.flatten()
hist2, bins2 = np.histogram(image2_flattened, bins=256, range=[image2_flattened.min(), image2_flattened.max()])
cdf2 = hist2.cumsum()
cdf_normalized2 = cdf2 * (hist2.max() / cdf2.max())
image2 = np.interp(image2_flattened, bins2[:-1], cdf_normalized2).reshape(image2.shape)
image2 = np.interp(image2, (image2.min(), image2.max()), (img_min2, img_max2))

return image1, seg, image2

def aug_transform(image1, transform):
img_min1, img_max1 = image1.min(), image1.max()


image1 = (image1 - image1.mean()) / image1.std()
image1 = np.interp(image1, (image1.min(), image1.max()), (0, 1))

image1 = transform(image1)


image1 = np.interp(image1, (image1.min(), image1.max()), (img_min1, img_max1))


return image1

def aug_log(image1):
return aug_transform(image1, lambda x: np.log(1 + x))

def aug_sqrt(image1 ):
return aug_transform(image1, np.sqrt)

def aug_sin(image1):
return aug_transform(image1, np.sin)

def aug_exp(image1):
return aug_transform(image1, np.exp)

def aug_sig(image1):
return aug_transform(image1, lambda x: 1 / (1 + np.exp(-x)))

def aug_laplace(image1):
return aug_transform(image1, lambda x: np.abs(ndi.laplace(x)))

def aug_inverse(image1):
image1 = image1.min() + image1.max() - image1
return image1

def aug_bspline(image1, seg, image2):
grid = rs.rand(3, 3, 3, 3)
bspline = gryds.BSplineTransformation((grid - .5) / 5)
grid[:, 0] += ((grid[:, 0] > 0) * 2 - 1) * .9
image1 = gryds.Interpolator(image1).transform(bspline).astype(np.float64)
image2 = gryds.Interpolator(image2).transform(bspline).astype(np.float64)
seg = gryds.Interpolator(seg, order=0).transform(bspline).astype(np.uint8)
return image1, seg, image2

def aug_flip(image1, seg, image2):
subject = tio.RandomFlip(axes=('LR',))(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_aff(image1, seg, image2):
subject = tio.RandomAffine()(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_elastic(image1, seg, image2):
subject = tio.RandomElasticDeformation(max_displacement=40)(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_anisotropy(image1, seg, image2, downsampling=7):
subject = tio.RandomAnisotropy(downsampling=downsampling)(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_motion(image1, seg, image2):
subject = tio.RandomMotion()(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_ghosting(image1, seg, image2):
subject = tio.RandomGhosting()(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_spike(image1, seg, image2):
subject = tio.RandomSpike(intensity=(1, 2))(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_bias_field(image1, image2, seg):
subject = tio.RandomBiasField()(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_blur(image1, seg, image2):
subject = tio.RandomBlur()(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_noise(image1, seg, image2):
original_mean1, original_std1 = np.mean(image1), np.std(image1)
original_mean2, original_std2 = np.mean(image2), np.std(image2)

image1 = (image1 - original_mean1) / original_std1
image2 = (image2 - original_mean2) / original_std2

subject = tio.RandomNoise()(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
image1 = image1 * original_std1 + original_mean1
image2 = image2 * original_std2 + original_mean2

return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_swap(image1, seg, image2):
subject = tio.RandomSwap()(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)

def aug_labels2image(image1, seg, image2, leave_background=0.5, classes=None):
_seg = seg
if classes:
_seg = combine_classes(seg, classes)
subject = tio.RandomLabelsToImage(label_key="seg", image_key="image")(tio.Subject(
seg=tio.LabelMap(tensor=np.expand_dims(_seg, axis=0))
))
new_img = subject.image.data.squeeze().numpy().astype(np.float64)

if rs.rand() < leave_background:
img_min1, img_max1 = np.min(image1), np.max(image1)
_image1 = (image1 - img_min1) / (img_max1 - img_min1)

new_img_min, new_img_max = np.min(new_img), np.max(new_img)
new_img = (new_img - new_img_min) / (new_img_max - new_img_min)
new_img[_seg == 0] = _image1[_seg == 0]
new_img = np.interp(new_img, (new_img.min(), new_img.max()), (img_min1, img_max1))

return new_img, seg, image2

def aug_gamma(image1, seg, image2):
subject = tio.RandomGamma()(tio.Subject(
image=tio.ScalarImage(tensor=np.expand_dims(image1, axis=0)),
seg=tio.LabelMap(tensor=np.expand_dims(seg, axis=0)),
image2=tio.ScalarImage(tensor=np.expand_dims(image2, axis=0))
))
return subject.image.data.squeeze().numpy().astype(np.float64), subject.seg.data.squeeze().numpy().astype(np.uint8), subject.image2.data.squeeze().numpy().astype(np.float64)


def parse_class(c):
c = [_.split('-') for _ in c.split(',')]
c = tuple(__ for _ in c for __ in list(range(int(_[0]), int(_[-1]) + 1)))
return c

def combine_classes(seg, classes):
_seg = np.zeros_like(seg)
for i, c in enumerate(classes):
_seg[np.isin(seg, c)] = i + 1
return _seg

if __name__ == '__main__':
main()
101 changes: 101 additions & 0 deletions MIL_training/mil_definition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
'''
File to introduce a MIL model
Note that loads of hyperparameters could be included as arguments
It could be avg pooling size, hidden dim, etc...
Also encoder could be changed to a different model
'''

import torch
import torch.nn as nn
import timm



class MILsection(nn.Module):
def __init__(self, input_dim, hidden_dim, num_classes, num_layers=1):
super(MILsection, self).__init__()
self.num_layers = num_layers
# RNN cells, here we use a GRU
if num_layers > 0:
self.rnn = nn.GRU(input_dim, input_dim//2, num_layers=num_layers,
batch_first=True, dropout=0.1, bidirectional=True)
# attention layer, here we use a simple linear layer
self.attention = nn.Sequential(
nn.Tanh(),
nn.Linear(input_dim, 1)
)

def forward(self, bags):
"""
Args:
bags: (batch_size, num_instances, input_dim)

Returns:
logits: (batch_size, num_classes)
"""

# bags iterates in the RNN
if self.num_layers > 0:
bags_rnn, _ = self.rnn(bags)
else:
bags_rnn = bags

# main attention
attn_scores = self.attention(bags_rnn).squeeze(-1) # [batch_size, num_instances]
attn_weights = torch.softmax(attn_scores, dim=-1) # [batch_size, num_instances]
weighted_instances = torch.bmm(attn_weights.unsqueeze(1), bags_rnn).squeeze(1) # [batch_size, input_dim]

return weighted_instances


class MILmodel(nn.Module):
# attention, a new encoder instance has to be created at each MIL creation
# if not, the same encoder could be used for all MIL sections
def __init__(self, encoder, num_layers=1):
super(MILmodel, self).__init__()
# encoder
self.encoder = encoder
# flattening layer, applying pooling and flattening
# note here that we could try different pooling methods
self.flatten = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(1)
)
# size of the features after encoding
# encoder outputs n feature maps
# that are flattened into a vector of size n
self.feature_size = self.encoder.num_features

# MIL section, loads of hyperparameters here also
self.mil_section = MILsection(input_dim=self.feature_size,
hidden_dim=1024,
num_classes=3,
num_layers=num_layers)
# classifier output
# we use a final simple linear layer to output the final prediction
self.classifier = nn.Linear(self.feature_size, 3)

def forward(self, x):
# x shape: (batch_size, 6, 1, 384, 384)
batch_size, num_instances, channels, H, W = x.shape

# Reshape to process all instances through encoder
x = x.reshape(-1, channels, H, W) # shape: (batch_size * 6, 1, 384, 384)

# Pass through encoder
x = self.encoder.forward_features(x) # shape: (batch_size * 6, feature_size, h', w')

# Apply pooling and flatten
x = self.flatten(x) # shape: (batch_size * 6, feature_size)

# Reshape back to separate instances
x = x.reshape(batch_size, num_instances, self.feature_size) # shape: (batch_size, 6, feature_size)

# Pass through MIL section
weighted_instances = self.mil_section(x)
# weighted_instances: (batch_size, feature_size)

# classification output
output = self.classifier(weighted_instances) # shape: (batch_size, 3)

return output
Loading