
Commit 034d341: Initial commit

1 parent 976ce60 commit 034d341
20 files changed: +10849 -2 lines changed

1_prep_warp.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
import torch
import numpy as np
import pandas as pd
import nibabel as nib
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path
from niftiai import TensorImage3d
from src.utils import ALIGN, create_grid, gauss_smoothing, LinearElasticity

TEMPLATE_SIZE = (168, 204, 168)
TEMPLATE_SHAPE = (113, 137, 113)
TEMPLATE_ORIGIN = (72, 120, 84)
data_path = 'data'


def compose_affine(translation, rotation, zoom, shear):
    # build a 4x4 homogeneous affine from translation, rotation, zoom and shear parameters
    affine = torch.eye(4, device=translation.device)
    ZS = torch.diag(zoom)
    ZS[0, -2:] = shear[:2]
    ZS[1, 2] = shear[2]
    affine[:3, :3] = torch.mm(rotation, ZS)
    affine[:3, 3] = translation
    return affine


def fill_nans(iy_disp, grid, nans):
    # replace NaN / out-of-field displacements with an elasticity-regularized fill, then smooth the filled region
    iy_disp_org = iy_disp.clone()
    nans_in_template = F.grid_sample(nans[None, None].float(), grid, align_corners=ALIGN)[0, 0]
    nans_in_template = nans_in_template > .0001
    empty_mask = (grid[0, ..., 0].abs() > 1) + (grid[0, ..., 1].abs() > 1) + (grid[0, ..., 2].abs() > 1) > 0
    nan_mask = nans_in_template | empty_mask
    elast = LinearElasticity()
    fill = torch.nn.Parameter(torch.zeros(int(nan_mask.sum()), 3, device=iy_disp.device), requires_grad=True)
    opt = torch.optim.Adam([fill], lr=1e-3)
    for i in range(100):
        opt.zero_grad()
        iy_disp[nan_mask] = fill
        loss = elast(iy_disp[None])
        loss.backward(retain_graph=True)
        opt.step()
    iy_disp = iy_disp.detach()
    iy_disp = gauss_smoothing(iy_disp.permute(3, 0, 1, 2)[None])[0].permute(1, 2, 3, 0)
    iy_disp_org[nan_mask] = iy_disp[nan_mask]
    return iy_disp_org


def decompose_iy(iy, mask, nans):
    # split the CAT12 inverse deformation (iy_) into a fitted affine and a residual displacement in template space
    iy = (iy + torch.tensor(TEMPLATE_ORIGIN, device=iy.device)) / torch.tensor(TEMPLATE_SIZE, device=iy.device)
    iy = (iy * 2) - 1
    translation = torch.nn.Parameter(torch.zeros(3).cuda(), requires_grad=True)
    rotation = torch.nn.Parameter(torch.eye(3).cuda(), requires_grad=True)
    zoom = torch.nn.Parameter(torch.ones(3).cuda(), requires_grad=True)
    shear = torch.nn.Parameter(torch.zeros(3).cuda(), requires_grad=True)
    opt = torch.optim.Adam([translation, rotation, zoom, shear], lr=3e-2)
    pbar = tqdm(range(100), disable=True)
    for _ in pbar:
        opt.zero_grad()
        affine = compose_affine(translation, rotation, zoom, shear)
        affine_grid = F.affine_grid(affine[None, :3], [1, 3, *iy.shape[:3]], align_corners=ALIGN)
        loss = ((iy[None, mask, :] - affine_grid[:, mask, :]) ** 2).mean()
        pbar.set_description(f'{loss.item()}')
        loss.backward()
        opt.step()
    iy_disp = iy.detach() - affine_grid.detach()[0]
    inv_affine = torch.linalg.inv(affine.detach())
    inv_affine_grid = F.affine_grid(inv_affine[None, :3], [1, 3, *TEMPLATE_SHAPE], align_corners=ALIGN)
    iy_disp = iy_disp.permute(3, 0, 1, 2)[None]
    iy_disp_in_template_space = F.grid_sample(iy_disp, inv_affine_grid, align_corners=ALIGN)[0].permute(1, 2, 3, 0)
    iy_disp_in_template_space = fill_nans(iy_disp_in_template_space, inv_affine_grid, nans)
    return iy_disp_in_template_space.detach(), affine


def mse(x, y):
    return ((x - y) ** 2).mean()


class SyN:
    # SyN-style symmetric diffeomorphic parameterization: two velocity fields integrated by scaling and squaring
    def __init__(self, time_steps=7, factor_diffeo=.1, sim_func=mse, mu=2., lam=1., optimizer=torch.optim.Adam):
        self.time_steps = time_steps
        self.factor_diffeo = factor_diffeo
        self.sim_func = sim_func
        self.reg_func = LinearElasticity(mu, lam, refresh_id_grid=True)
        self.optimizer = optimizer
        self.grid = None

    def fit_xy(self, targ_f_yx, iterations, learning_rate):
        x = 0 * targ_f_yx[:, :1]
        y = 0 * targ_f_yx[:, :1]
        self.grid = create_grid(x.shape[2:], x.device, dtype=x.dtype)
        v_xy = torch.zeros((x.shape[0], 3, *x.shape[2:]), device=x.device, dtype=x.dtype)
        v_yx = torch.zeros((x.shape[0], 3, *x.shape[2:]), device=x.device, dtype=x.dtype)
        v_xy = torch.nn.Parameter(v_xy, requires_grad=True)
        v_yx = torch.nn.Parameter(v_yx, requires_grad=True)
        optimizer = self.optimizer([v_xy, v_yx], learning_rate)
        for i in range(iterations):
            optimizer.zero_grad()
            images, flows = self.apply_flows(x, y, v_xy, v_yx)
            loss = self.sim_func(targ_f_yx, flows['yx_full'])
            loss.backward()
            optimizer.step()
        return flows['xy_full'].detach(), v_xy, v_yx, loss.detach().item()

    def apply_flows(self, x, y, v_xy, v_yx):
        half_flows = self.diffeomorphic_transform(torch.cat([v_xy, v_yx, -v_xy, -v_yx]))
        half_images = self.spatial_transform(torch.cat([x, y]), half_flows[:2])
        full_flows = self.composition_transform(half_flows[:2], half_flows[2:].flip(0))
        full_images = self.spatial_transform(torch.cat([x, y]), full_flows)
        images = {'xy_half': half_images[:1], 'yx_half': half_images[1:2],
                  'xy_full': full_images[:1], 'yx_full': full_images[1:2]}
        flows = {'xy_half': half_flows[:1], 'yx_half': half_flows[1:2],
                 'xy_full': full_flows[:1], 'yx_full': full_flows[1:2]}
        return images, flows

    def diffeomorphic_transform(self, flow):
        # scaling and squaring: integrate the velocity field by repeated self-composition
        flow = self.factor_diffeo * flow / (2 ** self.time_steps)
        for i in range(self.time_steps):
            flow = flow + self.spatial_transform(flow, flow)
        return flow

    def composition_transform(self, flow_1, flow_2):
        return flow_2 + self.spatial_transform(flow_1, flow_2)

    def spatial_transform(self, x, flow):
        return F.grid_sample(x.type(torch.float32), self.grid.type(torch.float32) + flow.permute(0, 2, 3, 4, 1),
                             align_corners=ALIGN, padding_mode='border')


def preprocess_cat12_registration(p0_filepaths, iy_filepaths, y_filepaths, dest_dir=None):
    # decompose the CAT12 deformations per subject and fit SyN velocity fields to the nonlinear part
    nib_affine = np.array([[-1.5, 0, 0, 84], [0, 1.5, 0, -120], [0, 0, 1.5, -72], [0, 0, 0, 1]])
    for p0_fpath, iy_fpath, y_fpath in tqdm(zip(p0_filepaths, iy_filepaths, y_filepaths), total=len(y_filepaths)):
        p0 = TensorImage3d.create(p0_fpath)[0].cuda()
        iy = nib.load(iy_fpath)
        iy = TensorImage3d.create(iy.get_fdata(), affine=iy.affine, header=iy.header).cuda()
        y = nib.load(y_fpath)
        y = TensorImage3d.create(y.get_fdata(), affine=y.affine, header=y.header).cuda()
        y, iy = y[..., 0, :].flip(3), iy[..., 0, :].flip(3)
        brainmask = p0 > .001
        nans = torch.isnan(iy[..., 0])
        iy[nans, :] = 0
        mask = (~brainmask & ~nans)
        iy_disp, affine = decompose_iy(iy, mask, nans)
        iy_disp = iy_disp.permute(3, 0, 1, 2)
        syn = SyN()
        y_disp, v_xy, v_yx, lss = syn.fit_xy(iy_disp[None], iterations=100, learning_rate=1e-1)
        if dest_dir is not None:
            filename = p0_fpath.split('/')[-1].split('.')[0][2:]
            pd.DataFrame(affine.detach().cpu().numpy()).to_csv(f'{dest_dir}/affine/{filename}.csv', index=False)
            TensorImage3d(iy_disp, affine=nib_affine, header=iy.header).save(f'{dest_dir}/flow_yx/{filename}.nii.gz')
            TensorImage3d(y_disp[0], affine=nib_affine, header=iy.header).save(f'{dest_dir}/flow_xy/{filename}.nii.gz')
            TensorImage3d(v_yx[0], affine=nib_affine, header=iy.header).save(f'{dest_dir}/v_yx/{filename}.nii.gz')
            TensorImage3d(v_xy[0], affine=nib_affine, header=iy.header).save(f'{dest_dir}/v_xy/{filename}.nii.gz')


if __name__ == '__main__':
    df = pd.read_csv(f'{data_path}/csvs/openneuro_hd.csv')
    cat_dir = f'{data_path}/t1/CAT12.8.2'
    p0_fps = cat_dir + '/mri/p0' + df.filename + '.nii'
    iy_fps = cat_dir + '/mri/iy_' + df.filename + '.nii'
    y_fps = cat_dir + '/mri/y_' + df.filename + '.nii'
    for subdir in ['affine', 'flow_xy', 'flow_yx', 'v_xy', 'v_yx']:
        Path(f'{data_path}/{subdir}').mkdir(exist_ok=True)
    preprocess_cat12_registration(p0_fps, iy_fps, y_fps, dest_dir=data_path)
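
Note: SyN.diffeomorphic_transform above integrates a stationary velocity field by scaling and squaring. A minimal, self-contained sketch of that step (illustrative only, not part of this commit; it assumes displacement fields expressed in the normalized [-1, 1] coordinates used by F.affine_grid / F.grid_sample, with F.affine_grid standing in for create_grid):

import torch
import torch.nn.functional as F

def integrate_velocity(v, time_steps=7, align_corners=True):
    # v: (N, 3, D, H, W) stationary velocity field in normalized grid units
    n, _, d, h, w = v.shape
    theta = torch.eye(3, 4, device=v.device, dtype=v.dtype)[None].expand(n, 3, 4)
    identity = F.affine_grid(theta, (n, 3, d, h, w), align_corners=align_corners)  # (N, D, H, W, 3)
    flow = v / (2 ** time_steps)  # start from a small displacement
    for _ in range(time_steps):
        # square the transform: u <- u + u(id + u), i.e. phi <- phi o phi
        warped = F.grid_sample(flow, identity + flow.permute(0, 2, 3, 4, 1),
                               align_corners=align_corners, padding_mode='border')
        flow = flow + warped
    return flow  # displacement field; identity + flow.permute(0, 2, 3, 4, 1) is the final sampling grid

assert integrate_velocity(torch.zeros(1, 3, 8, 8, 8)).shape == (1, 3, 8, 8, 8)

This mirrors flow = flow + self.spatial_transform(flow, flow) in the class above: the 2 ** time_steps halvings keep each composition step small, which is what keeps the integrated transform smooth and invertible.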

2_prep_segment.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
import ants
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path
from niftiai import TensorImage3d
from spline_resize import resize, grid_sample
ALIGN = True
data_path = 'data'


def min_max(x, low=.005, high=.995):
    # percentile-based intensity normalization; values above the high percentile are log-compressed
    mask = x > 0
    low, high = np.percentile(x[mask].cpu(), 100 * low), np.percentile(x[mask].cpu(), 100 * high)
    x = (x - low) / (high - low)
    x[x > 1] = 1 + torch.log10(x[x > 1])
    return x.clamp(min=0)


def one_hot(o, n_classes=4):
    # soft one-hot encoding of a continuous p0 label map (linear membership between neighboring classes)
    o = o.clip(max=n_classes - 1)
    one_h = torch.zeros((n_classes, *o.shape[1:]), dtype=torch.float32, device=o.device)
    for c in range(n_classes):
        mask = o[0].gt(c - 1) & o[0].le(c + 1)
        one_h[c, mask] = 1 - (o[0][mask] - c).abs()
    return one_h


if __name__ == '__main__':
    data_subdirs = ['bet', 'bc', 'img_05mm', 'img_05mm_minmax', 'img_05mm_minmax_raw', 'img_075mm_minmax',
                    'img_075mm_minmax_raw', 'p0_05mm', 'nogm', 'p0_075mm', 'p']
    for subdir in data_subdirs:
        Path(f'{data_path}/{subdir}').mkdir(exist_ok=True)
    shape_05mm = (339, 411, 339)
    shape_075mm = (224, 256, 224)
    shape_15mm = (113, 137, 113)
    nib_affine_05mm = np.array([[.5, 0, 0, -84], [0, .5, 0, -120], [0, 0, .5, -72], [0, 0, 0, 1]])
    nib_affine_075mm = np.array([[.75, 0, 0, -84], [0, .75, 0, -120], [0, 0, .75, -72], [0, 0, 0, 1]])
    nib_affine_15mm = np.array([[1.5, 0, 0, -84], [0, 1.5, 0, -120], [0, 0, 1.5, -72], [0, 0, 0, 1]])
    fns = pd.read_csv('data/csvs/openneuro_hd.csv').filename
    for fn in tqdm(fns):
        affine = pd.read_csv(f'{data_path}/affine/{fn}.csv')
        affine = torch.linalg.inv(torch.from_numpy(affine.values).cuda()).float()
        im = TensorImage3d.create(f'{data_path}/t1/{fn}.nii.gz').cuda()
        header = im.header
        p0 = TensorImage3d.create(f'{data_path}/t1/CAT12.8.2/mri/p0{fn}.nii').cuda()
        im_bet = im.clone()
        mask = p0 <= 0
        im_bet[mask] = 0  # brain extraction
        im_bet.affine, im_bet.header = im.affine, header
        bet_fp = f'{data_path}/bet/{fn}.nii.gz'
        im_bet.save(bet_fp)
        bc_fp = f'{data_path}/bc/{fn}.nii.gz'
        ants.n4_bias_field_correction(ants.image_read(bet_fp)).to_file(bc_fp)  # bias correction
        im_bc = TensorImage3d.create(bc_fp).cuda()
        grid = F.affine_grid(affine[None, :3], [1, 3, *shape_05mm], align_corners=ALIGN)
        header.set_data_dtype(np.float32)
        mask = F.grid_sample(mask[None].float(), grid, mode='nearest')[0, :, 1:-2, 15:-12, :336]
        zeromask = mask > 0
        im = grid_sample(im[None], grid, align_corners=ALIGN)[0, :, 1:-2, 15:-12, :336]
        im_bc = grid_sample(im_bc[None], grid, align_corners=ALIGN, mask_value=0)[0, :, 1:-2, 15:-12, :336]
        im_bet = grid_sample(im_bet[None], grid, align_corners=ALIGN, mask_value=0)[0, :, 1:-2, 15:-12, :336]
        im = TensorImage3d(im, affine=nib_affine_05mm, header=header)
        im.save(f'{data_path}/img_05mm/{fn}.nii.gz')  # run CAT12 on these files
        im_bc = TensorImage3d(min_max(im_bc), affine=nib_affine_05mm, header=header)
        im_bc.save(f'{data_path}/img_05mm_minmax/{fn}.nii.gz')  # train input (brain extracted + bias corrected @ 0.5 mm)
        im_bet = TensorImage3d(min_max(im_bet), affine=nib_affine_05mm, header=header)
        im_bet.save(f'{data_path}/img_05mm_minmax_raw/{fn}.nii.gz')  # eval input (brain extracted @ 0.5 mm)
        im_bc = resize(im_bc[None], shape_075mm, align_corners=ALIGN, mask_value=0)[0]  # interpolate to 0.75 mm
        im_bet = resize(im_bet[None], shape_075mm, align_corners=ALIGN, mask_value=0)[0]  # interpolate to 0.75 mm
        TensorImage3d(im_bc, affine=nib_affine_075mm, header=header).save(f'{data_path}/img_075mm_minmax/{fn}.nii.gz')
        TensorImage3d(im_bet, affine=nib_affine_075mm, header=header).save(f'{data_path}/img_075mm_minmax_raw/{fn}.nii.gz')
    print('Run CAT12 on all "img_05mm/..." files, then run the commented code block at the end of this script')
    # for fn in tqdm(fns):
    #     affine = pd.read_csv(f'{data_path}/affine/{fn}.csv')
    #     affine = torch.linalg.inv(torch.from_numpy(affine.values).cuda()).float()
    #     im = TensorImage3d.create(f'{data_path}/t1/{fn}.nii.gz').cuda()
    #     header = im.header
    #     p0 = TensorImage3d.create(f'{data_path}/t1/CAT12.8.2/mri/p0{fn}.nii').cuda()
    #     mask = p0 <= 0
    #     grid = F.affine_grid(affine[None, :3], [1, 3, *shape_05mm], align_corners=ALIGN)
    #     header.set_data_dtype(np.float32)
    #     mask = F.grid_sample(mask[None].float(), grid, mode='nearest')[0, :, 1:-2, 15:-12, :336]
    #     zeromask = mask > 0
    #     p0 = TensorImage3d.create(f'{data_path}/img_05mm/CAT12.8.2/mri/p0{fn}.nii').cuda()
    #     p0_header = p0.header
    #     p0[zeromask] = 0
    #     p0.save(f'{data_path}/p0_05mm/{fn}.nii.gz')  # train target for patchwise brain segm.
    #     p1 = TensorImage3d.create(f'{data_path}/img_05mm/CAT12.8.2/mri/p1{fn}.nii').cuda()
    #     p1[zeromask] = 0
    #     p1_pre = one_hot(p0[None])[2]
    #     nogm = ((p1_pre - p1) > .015)
    #     TensorImage3d(nogm, affine=nib_affine_05mm, header=header).save(f'{data_path}/nogm/{fn}.nii.gz')  # train target for nogm
    #     p0 = resize(p0[None], shape_075mm, align_corners=ALIGN, mask_value=0)[0]
    #     p0 = TensorImage3d(p0, affine=nib_affine_075mm, header=p0_header)
    #     p0.save(f'{data_path}/p0_075mm/{fn}.nii.gz')  # train target for brain segm.
    #     p1 = TensorImage3d.create(f'{data_path}/t1/CAT12.8.2/mri/p1{fn}.nii').cuda()
    #     p2 = TensorImage3d.create(f'{data_path}/t1/CAT12.8.2/mri/p2{fn}.nii').cuda()
    #     p3 = TensorImage3d.create(f'{data_path}/t1/CAT12.8.2/mri/p3{fn}.nii').cuda()
    #     header = p1.header
    #     header.set_data_dtype(np.float32)
    #     grid = F.affine_grid(affine[None, :3], (1, 3, *shape_15mm), align_corners=True)
    #     p = grid_sample(torch.cat([p1, p2, p3])[None], grid, align_corners=True, mask_value=0)[0].clip(0, 1)
    #     p.affine, p.header = nib_affine_15mm, header
    #     p.save(f'{data_path}/p/{fn}.nii.gz')  # train input for syn registration
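
For reference, the slicing [1:-2, 15:-12, :336] used throughout the loop crops the resampled 0.5 mm grid (339, 411, 339) to the shape the segmentation network consumes in 3_train_segment.py. A quick check of that arithmetic (illustrative only, not part of this commit):

shape_05mm = (339, 411, 339)
cropped = (shape_05mm[0] - 1 - 2,    # 1:-2   removes 3 slices   -> 336
           shape_05mm[1] - 15 - 12,  # 15:-12 removes 27 slices  -> 384
           336)                      # :336 keeps the first 336 of 339 slices
assert cropped == (336, 384, 336)    # equals `shape` in 3_train_segment.py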

3_train_segment.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
from tqdm import tqdm
from fastai.basics import np, pd, mae, torch, set_seed, Path, Learner
from niftiai import aug_transforms3d, TensorImage3d, SegmentationDataLoaders3d
from spline_resize import resize
from src.augment import Blur3d, ScaledChiNoise3d
from src.models import Unet3d, StepActivation
from src.transforms import StoreZeroMask, ApplyZeroMask, ScaleIntensity
set_seed(1)
path = '.'  # if data_path is absolute (i.e. starts with "/"), set path = '/'
data_path = 'data'
Path(f'{data_path}/models').mkdir(exist_ok=True)
Path(f'{data_path}/p0_05mm_pred').mkdir(exist_ok=True)
nib_affine_05mm = np.array([[.5, 0, 0, -84], [0, .5, 0, -120], [0, 0, .5, -72], [0, 0, 0, 1]])

shape = (336, 384, 336)
df = pd.read_csv(f'{data_path}/csvs/openneuro_hd.csv')
df['img'] = f'{data_path}/img_075mm_minmax/' + df.filename + '.nii.gz'
df['mask'] = f'{data_path}/p0_075mm/' + df.filename + '.nii.gz'
header = TensorImage3d.create(df['mask'].iloc[0]).header
batch_tfms = aug_transforms3d(max_warp=0, max_zoom=0, max_rotate=0, max_shear=0, max_translate=.02, p_affine=.2,
                              max_ghost=.5, max_spike=2., max_bias=.2, max_motion=.5, max_noise=.0, max_down=2,
                              max_ring=1., max_contrast=.1, max_dof_noise=3, image_mode='nearest',
                              dims_ghost=(0, 1, 2), n_ghosts=2, p_spike=.1, freq_spike=.5, dims_ring=(0, 1, 2))
batch_tfms += [StoreZeroMask(), ScaledChiNoise3d(.1, p=.1), Blur3d(.5, p=.1), ApplyZeroMask(), ScaleIntensity()]
model = torch.nn.Sequential(Unet3d(n_out=1), StepActivation())
# train on the full openneuro-hd dataset
df_total = df.copy()
df_total.loc[len(df_total)] = df_total.loc[0]  # duplicate one row as a dummy validation sample
df_total['is_valid'] = (len(df_total) - 1) * [0] + [1]
dls = SegmentationDataLoaders3d.from_df(df_total, path=path, fn_col='img', label_col='mask',
                                        valid_col='is_valid', bs=1, batch_tfms=batch_tfms)
learn = Learner(dls, model=model, loss_func=mae)
learn.model = learn.model.cuda()
learn.fit_one_cycle(60, 1e-3)
torch.save(learn.model.state_dict(), f'{data_path}/models/segmentation_model.pth')
# learn.model.load_state_dict(torch.load(f'{data_path}/models/segmentation_model.pth'))  # load model
for fp in tqdm(df_total.img[:-1]):
    filename = Path(fp).stem.split('.')[0]
    x = TensorImage3d.create(fp.replace('_minmax', '_minmax_raw')).cuda()
    with torch.no_grad():
        p = learn.model(x[None])
    p = resize(p, shape, align_corners=True, mask_value=0)[0]
    p = TensorImage3d(p, affine=nib_affine_05mm, header=header)
    p.header.set_data_dtype(np.uint8)
    p.save(f'{data_path}/p0_05mm_pred/{filename}.nii.gz')
# train the 5-fold cross-validation models
for fold in range(5):
    set_seed(1)
    df['is_valid'] = df.fold == fold
    dls = SegmentationDataLoaders3d.from_df(df, path=path, fn_col='img', label_col='mask',
                                            valid_col='is_valid', bs=1, batch_tfms=batch_tfms)
    model = torch.nn.Sequential(Unet3d(n_out=1), StepActivation())  # re-initialize the model for each fold
    learn = Learner(dls, model=model, loss_func=mae)
    learn.model = learn.model.cuda()
    learn.fit_one_cycle(60, 1e-3)
    torch.save(learn.model.state_dict(), f'{data_path}/models/segmentation_model_fold{fold}.pth')
    # learn.model.load_state_dict(torch.load(f'{data_path}/models/segmentation_model_fold{fold}.pth'))  # load model
    for fp in tqdm(df.img):
        filename = Path(fp).stem.split('.')[0]
        x = TensorImage3d.create(fp.replace('_minmax', '_minmax_raw')).cuda()
        with torch.no_grad():
            p = learn.model(x[None])
        p = resize(p, shape, align_corners=True, prefilter=True, mask_value=0)[0]
        p = TensorImage3d(p, affine=nib_affine_05mm, header=header)
        p.header.set_data_dtype(np.uint8)
        p.save(f'{data_path}/p0_05mm_pred/{filename}_fold{fold}.nii.gz')
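
A possible way to reload one of the saved fold checkpoints for standalone inference later on; this is a sketch under the assumption that Unet3d takes single-channel volumes at the 0.75 mm input shape (224, 256, 224) used in the prediction loops above, and it is not part of this commit:

import torch
from src.models import Unet3d, StepActivation

model = torch.nn.Sequential(Unet3d(n_out=1), StepActivation()).cuda().eval()
model.load_state_dict(torch.load('data/models/segmentation_model_fold0.pth', map_location='cuda'))
with torch.no_grad():
    x = torch.zeros(1, 1, 224, 256, 224, device='cuda')  # replace with a real img_075mm_minmax_raw volume
    p0_pred = model(x)  # same role as p in the prediction loops above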
