Skip to content

Commit 445bf4b

Browse files
authored
Merge pull request #35 from dhritimandas/add/synthstrip
Add/synthstrip
2 parents 5c79c95 + ed98edc commit 445bf4b

5 files changed

Lines changed: 329 additions & 0 deletions

File tree

DDIG/SynthStrip/1.0.0/Dockerfile

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# syntax=docker/dockerfile:1
2+
3+
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
4+
5+
RUN apt-get update \
6+
&& apt-get install --yes --quiet --no-install-recommends \
7+
ca-certificates \
8+
git \
9+
libgomp1 \
10+
gcc \
11+
&& rm -rf /var/lib/apt/lists/*
12+
13+
ENV LC_ALL=C.UTF-8 \
14+
LANG=C.UTF-8
15+
16+
# python packages
17+
COPY requirements.txt requirements.txt
18+
RUN pip install --no-cache-dir -r requirements.txt
19+
20+
# clean up
21+
RUN rm -rf /root/.cache/pip
22+
23+
WORKDIR /work
24+
LABEL maintainer="Hoda Rajaei <rajaei.hoda@gmail.com>"

DDIG/SynthStrip/1.0.0/predict.py

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
"""This code is adapted from FreeSurfer mri_synthstrip.py to be compatible for Nobrainer-zoo.
2+
3+
4+
If you use this code, please cite the SynthStrip paper:
5+
SynthStrip: Skull-Stripping for Any Brain Image.
6+
A Hoopes, JS Mora, AV Dalca, B Fischl, M Hoffmann.
7+
8+
https://github.com/freesurfer/freesurfer/blob/dev/mri_synthstrip/
9+
10+
11+
Copyright 2022 A Hoopes
12+
13+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
14+
compliance with the License. You may obtain a copy of the License at
15+
http://www.apache.org/licenses/LICENSE-2.0
16+
Unless required by applicable law or agreed to in writing, software distributed under the License is
17+
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
18+
implied. See the License for the specific language governing permissions and limitations under the
19+
License.
20+
"""
21+
22+
#!/usr/bin/env python
23+
24+
import os
25+
import sys
26+
import torch
27+
import torch.nn as nn
28+
import numpy as np
29+
import argparse
30+
import surfa as sf
31+
import scipy.ndimage
32+
33+
description = '''
34+
Robust, universal skull-stripping for brain images of any
35+
type. If you use SynthStrip in your analysis, please cite:
36+
37+
SynthStrip: Skull-Stripping for Any Brain Image.
38+
A Hoopes, JS Mora, AV Dalca, B Fischl, M Hoffmann.
39+
'''
40+
41+
# parse command line
42+
parser = argparse.ArgumentParser(description=description)
43+
parser.add_argument('-i', '--image', metavar='file', required=True, help='Input image to skullstrip.')
44+
parser.add_argument('-o', '--out', metavar='file', help='Save stripped image to path.')
45+
parser.add_argument('-m', '--mask', metavar='file', help='Save binary brain mask to path.')
46+
parser.add_argument('-g', '--gpu', action='store_true', help='Use the GPU.')
47+
parser.add_argument('-b', '--border', default=1, type=int, help='Mask border threshold in mm. Default is 1.')
48+
parser.add_argument('--model', metavar='file', help='Alternative model weights.')
49+
if len(sys.argv) == 1:
50+
parser.print_help()
51+
exit(1)
52+
args = parser.parse_args()
53+
54+
# sanity check on the inputs
55+
if not args.out and not args.mask:
56+
sf.system.fatal('Must provide at least --out or --mask output flags.')
57+
58+
# necessary for speed gains (I think)
59+
torch.backends.cudnn.benchmark = True
60+
torch.backends.cudnn.deterministic = True
61+
62+
# configure GPU device
63+
if args.gpu:
64+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
65+
device = torch.device('cuda')
66+
device_name = 'GPU'
67+
else:
68+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
69+
device = torch.device('cpu')
70+
device_name = 'CPU'
71+
72+
# configure model
73+
print(f'Configuring model on the {device_name}')
74+
75+
class StripModel(nn.Module):
76+
77+
def __init__(self,
78+
nb_features=16,
79+
nb_levels=7,
80+
feat_mult=2,
81+
max_features=64,
82+
nb_conv_per_level=2,
83+
max_pool=2,
84+
return_mask=False):
85+
86+
super().__init__()
87+
88+
# dimensionality
89+
ndims = 3
90+
91+
# build feature list automatically
92+
if isinstance(nb_features, int):
93+
if nb_levels is None:
94+
raise ValueError('must provide unet nb_levels if nb_features is an integer')
95+
feats = np.round(nb_features * feat_mult ** np.arange(nb_levels)).astype(int)
96+
feats = np.clip(feats, 1, max_features)
97+
nb_features = [
98+
np.repeat(feats[:-1], nb_conv_per_level),
99+
np.repeat(np.flip(feats), nb_conv_per_level)
100+
]
101+
elif nb_levels is not None:
102+
raise ValueError('cannot use nb_levels if nb_features is not an integer')
103+
104+
# extract any surplus (full resolution) decoder convolutions
105+
enc_nf, dec_nf = nb_features
106+
nb_dec_convs = len(enc_nf)
107+
final_convs = dec_nf[nb_dec_convs:]
108+
dec_nf = dec_nf[:nb_dec_convs]
109+
self.nb_levels = int(nb_dec_convs / nb_conv_per_level) + 1
110+
111+
if isinstance(max_pool, int):
112+
max_pool = [max_pool] * self.nb_levels
113+
114+
# cache downsampling / upsampling operations
115+
MaxPooling = getattr(nn, 'MaxPool%dd' % ndims)
116+
self.pooling = [MaxPooling(s) for s in max_pool]
117+
self.upsampling = [nn.Upsample(scale_factor=s, mode='nearest') for s in max_pool]
118+
119+
# configure encoder (down-sampling path)
120+
prev_nf = 1
121+
encoder_nfs = [prev_nf]
122+
self.encoder = nn.ModuleList()
123+
for level in range(self.nb_levels - 1):
124+
convs = nn.ModuleList()
125+
for conv in range(nb_conv_per_level):
126+
nf = enc_nf[level * nb_conv_per_level + conv]
127+
convs.append(ConvBlock(ndims, prev_nf, nf))
128+
prev_nf = nf
129+
self.encoder.append(convs)
130+
encoder_nfs.append(prev_nf)
131+
132+
# configure decoder (up-sampling path)
133+
encoder_nfs = np.flip(encoder_nfs)
134+
self.decoder = nn.ModuleList()
135+
for level in range(self.nb_levels - 1):
136+
convs = nn.ModuleList()
137+
for conv in range(nb_conv_per_level):
138+
nf = dec_nf[level * nb_conv_per_level + conv]
139+
convs.append(ConvBlock(ndims, prev_nf, nf))
140+
prev_nf = nf
141+
self.decoder.append(convs)
142+
if level < (self.nb_levels - 1):
143+
prev_nf += encoder_nfs[level]
144+
145+
# now we take care of any remaining convolutions
146+
self.remaining = nn.ModuleList()
147+
for num, nf in enumerate(final_convs):
148+
self.remaining.append(ConvBlock(ndims, prev_nf, nf))
149+
prev_nf = nf
150+
151+
# final convolutions
152+
if return_mask:
153+
self.remaining.append(ConvBlock(ndims, prev_nf, 2, activation=None))
154+
self.remaining.append(nn.Softmax(dim=1))
155+
else:
156+
self.remaining.append(ConvBlock(ndims, prev_nf, 1, activation=None))
157+
158+
def forward(self, x):
159+
160+
# encoder forward pass
161+
x_history = [x]
162+
for level, convs in enumerate(self.encoder):
163+
for conv in convs:
164+
x = conv(x)
165+
x_history.append(x)
166+
x = self.pooling[level](x)
167+
168+
# decoder forward pass with upsampling and concatenation
169+
for level, convs in enumerate(self.decoder):
170+
for conv in convs:
171+
x = conv(x)
172+
if level < (self.nb_levels - 1):
173+
x = self.upsampling[level](x)
174+
x = torch.cat([x, x_history.pop()], dim=1)
175+
176+
# remaining convs at full resolution
177+
for conv in self.remaining:
178+
x = conv(x)
179+
180+
return x
181+
182+
class ConvBlock(nn.Module):
183+
"""
184+
Specific convolutional block followed by leakyrelu for unet.
185+
"""
186+
187+
def __init__(self, ndims, in_channels, out_channels, stride=1, activation='leaky'):
188+
super().__init__()
189+
190+
Conv = getattr(nn, 'Conv%dd' % ndims)
191+
self.conv = Conv(in_channels, out_channels, 3, stride, 1)
192+
if activation == 'leaky':
193+
self.activation = nn.LeakyReLU(0.2)
194+
elif activation == None:
195+
self.activation = None
196+
else:
197+
raise ValueError(f'Unknown activation: {activation}')
198+
199+
def forward(self, x):
200+
out = self.conv(x)
201+
if self.activation is not None:
202+
out = self.activation(out)
203+
return out
204+
205+
with torch.no_grad():
206+
model = StripModel()
207+
model.to(device)
208+
model.eval()
209+
210+
# load model weights
211+
if args.model is not None:
212+
modelfile = args.model
213+
print('Using custom model weights')
214+
else:
215+
version = '1'
216+
print(f'Running SynthStrip model version {version}')
217+
fshome = os.environ.get('FREESURFER_HOME')
218+
if fshome is None:
219+
sf.system.fatal('FREESURFER_HOME env variable must be set! Make sure FreeSurfer is properly sourced.')
220+
modelfile = os.path.join(fshome, 'models', f'synthstrip.{version}.pt')
221+
checkpoint = torch.load(modelfile, map_location=device)
222+
model.load_state_dict(checkpoint['model_state_dict'])
223+
224+
# load input volume
225+
image = sf.load_volume(args.image)
226+
print(f'Input image read from: {args.image}')
227+
228+
# frame check
229+
if image.nframes > 1:
230+
sf.system.fatal('Input image cannot have more than 1 frame')
231+
232+
# conform image and fit to shape with factors of 64
233+
conformed = image.conform(voxsize=1.0, dtype='float32', method='nearest', orientation='LIA').crop_to_bbox()
234+
target_shape = np.clip(np.ceil(np.array(conformed.shape[:3]) / 64).astype(int) * 64, 192, 320)
235+
conformed = conformed.reshape(target_shape)
236+
237+
# normalize intensities
238+
conformed -= conformed.min()
239+
conformed = (conformed / conformed.percentile(99)).clip(0, 1)
240+
241+
# predict the surface distance transform
242+
with torch.no_grad():
243+
input_tensor = torch.from_numpy(conformed.data[np.newaxis, np.newaxis]).to(device)
244+
sdt = model(input_tensor).cpu().numpy().squeeze()
245+
246+
# unconform the sdt and extract mask
247+
sdt = conformed.new(sdt).resample_like(image, fill=100)
248+
249+
# find largest CC (just do this to be safe for now)
250+
components = scipy.ndimage.label(sdt.data < args.border)[0]
251+
bincount = np.bincount(components.flatten())[1:]
252+
mask = (components == (np.argmax(bincount) + 1))
253+
mask = scipy.ndimage.binary_fill_holes(mask)
254+
255+
# write the masked output
256+
if args.out:
257+
image[mask == 0] = np.min([0, image.min()])
258+
image.save(args.out)
259+
print(f'Masked image saved to: {args.out}')
260+
261+
# write the brain mask
262+
if args.mask:
263+
image.new(mask).save(args.mask)
264+
print(f'Binary brain mask saved to: {args.mask}')
265+
266+
print('If you use SynthStrip in your analysis, please cite:')
267+
print('----------------------------------------------------')
268+
print('SynthStrip: Skull-Stripping for Any Brain Image.')
269+
print('A Hoopes, JS Mora, AV Dalca, B Fischl, M Hoffmann.')
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# pytorch==1.10.2
2+
scipy==1.8.1
3+
surfa==0.2.0
4+
PyYAML

DDIG/SynthStrip/1.0.0/spec.yaml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#### container info
2+
image:
3+
singularity: nobrainer-zoo_ddig_torch1.11.0.sif
4+
docker: neuronets/nobrainer-zoo:ddig_torch1.11.0
5+
6+
#### repository info
7+
repository:
8+
repo_url: "https://github.com/freesurfer/freesurfer/tree/dev/mri_synthstrip"
9+
committish: "e935059a"
10+
repo_download: False
11+
repo_download_location: "None"
12+
13+
#### required fields for prediction
14+
inference:
15+
prediction_script: "trained-models/DDIG/SynthStrip/1.0.0/predict.py"
16+
command: f"python3 {MODELS_PATH}/{model}/predict.py --model {model_path} -i {infile} -o {outfile}"
17+
18+
options:
19+
mask: {mandatory: False, argstr: "-m", type: "str", help: "Save binary brain mask to path."}
20+
gpu: {mandatory: False, argstr: "-g", is_flag: true, help: "Use the GPU."}
21+
border: {mandatory: False, argstr: "-b", type: "int", default: 1, help: "Mask border threshold in mm. Default is 1."}
22+
#### input data characteristics
23+
data_spec:
24+
infile: {n_files: 1}
25+
outfile: {n_files: 1}
26+
27+
#### required fields for model training
28+
train:
29+
#### TODO: Add the train spec here
30+
31+
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../../../.git/annex/objects/vx/qw/URL--https&c%%drive.google.com%file%d-d49b348e59ad5db895cc20bfaa87e1d9/URL--https&c%%drive.google.com%file%d-d49b348e59ad5db895cc20bfaa87e1d9

0 commit comments

Comments
 (0)