Skip to content

Commit 1ce7516

Browse files
committed
FEAT: optimal graph-based spatial normalization
1 parent 45e7da3 commit 1ce7516

File tree

30 files changed

+1417
-178
lines changed

30 files changed

+1417
-178
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from . import affreg
2+
from . import affopt
23
from . import autoreg
34
from . import compose
45
from . import orient
6+
from . import meanspace
57
from . import register
68
from . import reorient
79
from . import reslice
810
from . import resize
11+
from . import vopt
912
from . import vexp
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import cli as _
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import sys
2+
3+
from nitorch.cli.cli import commands
4+
from nitorch.core.cli import ParseError
5+
from nitorch.io.transforms import loadf, savef
6+
from nitorch.spatial import optimal_affine
7+
8+
9+
def cli(args=None):
10+
f"""Command-line interface for `affopt`
11+
12+
{help}
13+
14+
"""
15+
16+
# Exceptions are dealt with here
17+
try:
18+
_cli(args)
19+
except AskForHelp:
20+
print(help)
21+
return
22+
except ParseError as e:
23+
print(help)
24+
print('[ERROR]', e)
25+
return
26+
# except Exception as e:
27+
# print(f'[ERROR] {str(e)}', file=sys.stderr)
28+
29+
30+
commands['affopt'] = cli
31+
32+
help = r"""[nitorch] affopt
33+
34+
Compute optimal "template-to-subject" matrices from "subject-to-subject" pairs.
35+
36+
notes:
37+
In Leung et al, all possible pairs of images are registered (in both
38+
forward and backward directions), such that the "subject-to-template"
39+
transforms can easily be computed as T[i,tpl] = expm(mean_j(logm(T[i,j]))).
40+
41+
Our implementation differs in several aspects:
42+
- We allow some transformation pairs to be missing, at the cost of
43+
introducing bias in the mean space estimation. This bias can be
44+
overcome in the statistical sense if the number of subjects is large
45+
and evaluated pairs are randomly sampled.
46+
- Instead of first symmetrizing pairwise transforms, we fit the mean
47+
space from all possible forward and backward transformations.
48+
- Instead of minimizing the L2 norm in the matrix Lie algebra
49+
(which is done implicitly by Leung et al's method), we add
50+
the possibility to minimize the L2 norm in the embedding space (i.e.,
51+
the Frobenius norm of affine matrices). This method is more accurate
52+
when pairwise transformations are large, in which case affine
53+
composition is badly approximated by log-matrix addition.
54+
55+
usage:
56+
nitorch affopt --input <fix> <mov> <path> [--input ...]
57+
58+
arguments:
59+
-i, --input Affine transform for one pair of images
60+
<fix> Index (or label) of fixed image
61+
<mov> Index (or label) of moving image
62+
<path> Path to an LTA file that warps <mov> to <fix>
63+
-o, --output Path to output transforms (default: {label}_optimal.lta)
64+
-l, --log Minimize L2 in Lie algebra (default: L2 in matrix space)
65+
-a, --affine Assume transforms are all affine (default)
66+
-s, --similitude Assume transforms are all similitude
67+
-r, --rigid Assume transforms are all rigid
68+
69+
example:
70+
nitorch affopt \
71+
-i mtw pdw mtw_to_pdw.lta \
72+
-i mtw t1w mtw_to_t1w.lta \
73+
-i pdw mtw pdw_to_mtw.lta \
74+
-i pdw t1w pdw_to_t1w.lta \
75+
-i t1w mtw t1w_to_mtw.lta \
76+
-i t1w pdw t1w_to_pdw.lta \
77+
-o out/{label}_to_mean.lta
78+
79+
references:
80+
"Consistent multi-time-point brain atrophy estimation from the
81+
boundary shift integral"
82+
Leung, Ridgway, Ourselin, Fox
83+
NeuroImage (2011)
84+
85+
"Symmetric Diffeomorphic Modeling of Longitudinal Structural MRI"
86+
Ashburner, Ridgway
87+
Front Neurosci. (2012)
88+
"""
89+
90+
91+
class AskForHelp(Exception):
92+
pass
93+
94+
95+
def parse(args):
96+
97+
args = list(args)
98+
if not args:
99+
raise ParseError('No arguments')
100+
101+
inputs = {}
102+
output = log = affine = similitude = rigid = None
103+
104+
tags = (
105+
'-i', '--input',
106+
'-o', '--output',
107+
'-l', '--lie', '--log',
108+
'-a', '--aff', '--affine',
109+
'-s', '--sim', '--similitude',
110+
'-r', '--rig', '--rigid'
111+
)
112+
113+
while args:
114+
tag = args.pop(0)
115+
if tag in ('-h', '--help'):
116+
raise AskForHelp
117+
elif tag in ('-i', '--input'):
118+
fix = args.pop(0)
119+
if fix in tags:
120+
raise ParseError(f'Expected <fix> <mov> <path> after {tag}')
121+
mov = args.pop(0)
122+
if mov in tags:
123+
raise ParseError(f'Expected <fix> <mov> <path> after {tag}')
124+
path = args.pop(0)
125+
if path in tags:
126+
raise ParseError(f'Expected <fix> <mov> <path> after {tag}')
127+
inputs[(fix, mov)] = path
128+
elif tag in ('-o', '--output'):
129+
out = args.pop(0)
130+
if out in tags:
131+
raise ParseError(f'Expected <path> after {tag}')
132+
if output is not None:
133+
raise ParseError(f'Max one {tag} accepted')
134+
output = out
135+
elif tag in ('-l', '--log', '--lie'):
136+
if log is not None:
137+
raise ParseError(f'Max one {tag} accepted')
138+
log = True
139+
elif tag in ('-a', '--aff', '--affine'):
140+
if affine is not None:
141+
raise ParseError(f'Max one {tag} accepted')
142+
affine = True
143+
elif tag in ('-s', '--sim', '--similitude'):
144+
if similitude is not None:
145+
raise ParseError(f'Max one {tag} accepted')
146+
similitude = True
147+
elif tag in ('-r', '--rig', '--rigid'):
148+
if rigid is not None:
149+
raise ParseError(f'Max one {tag} accepted')
150+
rigid = True
151+
else:
152+
raise ParseError(f'Unknown tag {tag}')
153+
154+
output = output or '{label}_optimal.lta'
155+
affine = affine or False
156+
similitude = similitude or False
157+
rigid = rigid or False
158+
if int(rigid) + int(similitude) + int(affine) > 1:
159+
raise ParseError('Max one of --rigid, --similitude, --affine accepeted')
160+
if int(rigid) + int(similitude) + int(affine) == 0:
161+
affine = True
162+
if affine:
163+
basis = 'Aff+'
164+
elif similitude:
165+
basis = 'CSO'
166+
else:
167+
basis = 'SE'
168+
169+
return inputs, output, log, basis
170+
171+
172+
def _cli(args):
173+
174+
inputs, output, log, basis = parse(args)
175+
176+
# LinearTransformArray(v).matrix() returns array with shape
177+
# [N, 4, 4], although N is always 1 in our case.
178+
inputs = {k: loadf(v).squeeze()
179+
for k, v in inputs.items()}
180+
optimal = optimal_affine(inputs, basis=basis,
181+
loss='log' if log else 'exp')
182+
183+
labels = list(set([label for pair in inputs for label in pair]))
184+
for i, label in enumerate(labels):
185+
savef(optimal[i], output.format(label=label), type='ras')
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import cli as _
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import sys
2+
import torch
3+
import os.path as op
4+
from tempfile import NamedTemporaryFile
5+
6+
from nitorch.cli.cli import commands
7+
from nitorch.core.cli import ParseError
8+
from nitorch import io
9+
from nitorch.cli.registration.reslice.main import reslice
10+
from nitorch.spatial import mean_space
11+
from nitorch.core.py import make_list, fileparts
12+
from .parser import parser
13+
14+
15+
def cli(args=None):
16+
f"""Command-line interface for `meanspace`
17+
18+
{help}
19+
20+
"""
21+
22+
# Exceptions are dealt with here
23+
try:
24+
_cli(args)
25+
except ParseError as e:
26+
print(help)
27+
print('[ERROR]', e)
28+
return
29+
# except Exception as e:
30+
# print(f'[ERROR] {str(e)}', file=sys.stderr)
31+
32+
33+
commands['meanspace'] = cli
34+
35+
help = r"""
36+
Compute an average "oriented voxel space" from a series of oriented images.
37+
38+
usage:
39+
nitorch meanspace --input <path> [--input ...]
40+
41+
arguments:
42+
-i, --input Affine transform for one pair of images
43+
<fix> Index (or label) of fixed image
44+
<mov> Index (or label) of moving image
45+
<path> Path to an LTA file that warps <mov> to <fix>
46+
-o, --output Path to output transforms (default: {label}_optimal.lta)
47+
-l, --log Minimize L2 in Lie algebra (default: L2 in matrix space)
48+
-a, --affine Assume transforms are all affine (default)
49+
-s, --similitude Assume transforms are all similitude
50+
-r, --rigid Assume transforms are all rigid
51+
52+
example:
53+
optimal_affine \
54+
-i mtw pdw mtw_to_pdw.lta \
55+
-i mtw t1w mtw_to_t1w.lta \
56+
-i pdw mtw pdw_to_mtw.lta \
57+
-i pdw t1w pdw_to_t1w.lta \
58+
-i t1w mtw t1w_to_mtw.lta \
59+
-i t1w pdw t1w_to_pdw.lta \
60+
-o out/{label}_to_mean.lta
61+
62+
references:
63+
"Consistent multi-time-point brain atrophy estimation from the
64+
boundary shift integral"
65+
Leung, Ridgway, Ourselin, Fox
66+
NeuroImage (2011)
67+
68+
"Symmetric Diffeomorphic Modeling of Longitudinal Structural MRI"
69+
Ashburner, Ridgway
70+
Front Neurosci. (2012)
71+
"""
72+
73+
74+
def _cli(args):
75+
args = args or sys.argv[1:]
76+
77+
options = parser.parse(args)
78+
if not options:
79+
return
80+
if options.help:
81+
print(help)
82+
return
83+
84+
# get all shape and matrices
85+
shapes = []
86+
affines = []
87+
for path in options.input:
88+
f = io.map(path)
89+
shapes += [f.shape[:3]]
90+
affines += [f.affine]
91+
92+
ndim = max(map(len, shapes))
93+
94+
# parse voxel size
95+
voxel_size = options.voxel_size
96+
voxel_size = make_list(voxel_size or [])
97+
if voxel_size and isinstance(voxel_size[-1], str):
98+
*voxel_size, vx_unit = voxel_size
99+
else:
100+
vx_unit = 'mm'
101+
if voxel_size:
102+
voxel_size = make_list(voxel_size, ndim)
103+
else:
104+
voxel_size = None
105+
106+
# parse padding
107+
pad = options.pad
108+
pad = make_list(pad or [])
109+
if pad and isinstance(pad[-1], str):
110+
*pad, pad_unit = pad
111+
else:
112+
pad_unit = '%'
113+
if pad:
114+
pad = make_list(pad, ndim)
115+
else:
116+
pad = None
117+
118+
# compute mean space
119+
affine, shape = mean_space(
120+
affines,
121+
shapes,
122+
voxel_size=voxel_size,
123+
vx_unit=vx_unit,
124+
pad=pad,
125+
pad_unit=pad_unit
126+
)
127+
print(shape)
128+
print(affine.numpy())
129+
130+
dir, base, ext = fileparts(op.abspath(options.input[0]))
131+
132+
# write mean space image
133+
write_meanspace = options.output is not False
134+
if not write_meanspace:
135+
tmp = NamedTemporaryFile("wb", delete=False, suffix=ext)
136+
options.output = tmp.name
137+
138+
options.output = options.output.format(dir=dir, base=base, ext=ext)
139+
140+
io.savef(
141+
torch.zeros(shape),
142+
options.output,
143+
affine=affine
144+
)
145+
146+
# write resliced images
147+
if options.resliced:
148+
for input in options.input:
149+
input = op.abspath(input)
150+
dir, base, ext = fileparts(input)
151+
output = options.resliced.format(dir=dir, base=base, ext=ext)
152+
print(output)
153+
reslice([input, '-o', output, '-t', options.output])
154+
155+
if not write_meanspace:
156+
tmp.close()

0 commit comments

Comments
 (0)