Skip to content

Commit e149d82

Browse files
committed
Fix(align_tpm): use quadratic interpolation in logit space instead of linear interpolation in prob space (mimcs spm_maff8)
1 parent 782bcad commit e149d82

File tree

2 files changed

+94
-26
lines changed

2 files changed

+94
-26
lines changed

nitorch/tools/registration/affine_tpm.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
𝓛 = 𝔼_q[ln p(𝒙)] = ∑ₙᵢ q(𝑥ₙ = 𝑖) ln ∑ⱼ 𝐻ᵢⱼ (𝜇 ∘ 𝜙)ₙⱼ
3434
"""
3535
import torch
36-
from nitorch.core import linalg, utils, py
36+
from nitorch.core import linalg, utils, py, math
3737
from nitorch import spatial, io
3838
from .utils import jg, jhj, affine_grid_backward
3939
import nitorch.plot as niplt
@@ -154,6 +154,18 @@ def align_tpm(dat, tpm=None, weights=None, spacing=(8, 4), device=None,
154154
# ------------------------------------------------------------------
155155
dat = discretize(dat, nbins=bins, mask=weights)
156156

157+
# ------------------------------------------------------------------
158+
# PREFILTER TPM
159+
# ------------------------------------------------------------------
160+
logtpm = tpm.clone()
161+
# ensure normalized
162+
logtpm = logtpm.clamp(tiny, 1-tiny).div_(logtpm.sum(0, keepdim=True))
163+
# transform to logits
164+
logtpm = logtpm.add_(tiny).log_()
165+
# spline prefilter
166+
splineopt = dict(interpolation=2, bound='replicate')
167+
logtpm = spatial.spline_coeff_nd(logtpm, dim=3, inplace=True, **splineopt)
168+
157169
# ------------------------------------------------------------------
158170
# OPTIONS
159171
# ------------------------------------------------------------------
@@ -175,8 +187,8 @@ def do_spacing(sp):
175187
if not sp:
176188
return dat0, affine_dat0, weights0
177189
sp = [max(1, int(pymath.floor(sp / vx1))) for vx1 in vx]
178-
sp = [slice(None, None, sp1) for sp1 in sp]
179-
affine_dat, _ = spatial.affine_sub(affine_dat0, dat0.shape[-dim:], tuple(sp))
190+
sp = tuple([slice(None, None, sp1) for sp1 in sp])
191+
affine_dat, _ = spatial.affine_sub(affine_dat0, dat0.shape[-dim:], sp)
180192
dat = dat0[(Ellipsis, *sp)]
181193
if weights0 is not None:
182194
weights = weights0[(Ellipsis, *sp)]
@@ -234,7 +246,7 @@ def do_spacing(sp):
234246
if reorient is not None:
235247
affine_dat = reorient.matmul(affine_dat)
236248

237-
mi, aff, prm = fit_affine_tpm(dat, tpm, affine_dat, affine_tpm,
249+
mi, aff, prm = fit_affine_tpm(dat, logtpm, affine_dat, affine_tpm,
238250
weights, **opt, prm=prm)
239251

240252
if reorient is not None:
@@ -263,7 +275,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
263275
affine_tpm : (4, 4) tensor
264276
weights : (*spatial) tensor
265277
basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
266-
fwhm : float, default=J/32
278+
fwhm : float, default=J/64
267279
max_iter_gn : int, default=100
268280
max_iter_em : int, default=32
269281
max_line_search : int, default=12
@@ -276,6 +288,8 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
276288
prm : (F) tensor
277289
278290
"""
291+
# !!! NOTE: `tpm` must contain spline-prefiltered log-probabilities
292+
279293
dim = tpm.dim() - 1
280294

281295
# ------------------------------------------------------------------
@@ -326,7 +340,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
326340
affine_tpm = affine_tpm.to(**utils.backend(tpm))
327341
shape = dat.shape[-dim:]
328342

329-
tpm = tpm.to(dat.device).clamp(tiny, 1-tiny)
343+
tpm = tpm.to(dat.device)
330344
basis = make_basis(basis, dim, **utils.backend(tpm))
331345
F = len(basis)
332346

@@ -337,7 +351,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
337351
em_opt = dict(fwhm=fwhm, max_iter=max_iter_em, weights=weights,
338352
verbose=verbose-2)
339353
drv_opt = dict(weights=weights)
340-
pull_opt = dict(bound='replicate', extrapolate=True)
354+
pull_opt = dict(bound='replicate', extrapolate=True, interpolation=2)
341355

342356
# ------------------------------------------------------------------
343357
# OPTIMIZE
@@ -365,6 +379,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
365379

366380
# --- warp TPM ---------------------------------------------
367381
mov = spatial.grid_pull(tpm, phi, **pull_opt)
382+
mov = math.softmax(mov, dim=1)
368383

369384
# --- mutual info ------------------------------------------
370385
mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
@@ -399,8 +414,8 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
399414
end = '\n' if verbose >= 2 else '\r'
400415
print(f'({basis_name[:6]}){space} | {n_iter:02d} | {mi.mean():12.6g}', end=end)
401416

402-
if mi.mean() - mi0.mean() < 1e-4:
403-
# print('converged', mi.mean() - mi0.mean())
417+
if mi.mean() - mi0.mean() < 0: #1e-4:
418+
print('converged', mi.mean() - mi0.mean())
404419
break
405420

406421
# --------------------------------------------------------------
@@ -412,16 +427,22 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
412427
g = g.sum(0)
413428
h = h.sum(0)
414429

415-
# --- chain rule -----------------------------------------------
430+
# --- spatial derivatives --------------------------------------
431+
mov = mov.unsqueeze(-1)
416432
gmov = spatial.grid_grad(tpm, phi, **pull_opt)
433+
gmov = mov * (gmov - (mov * gmov).sum(1, keepdim=True))
434+
mov = mov.squeeze(-1)
435+
436+
# --- chain rule -----------------------------------------------
417437
gaff = lmdiv(affine_tpm, mm(gaff, affine))
418438
g, h = chain_rule(g, h, gmov, gaff, maj=False)
419439
del gmov
420440

421441
# --- Gauss-Newton ---------------------------------------------
422442
h.diagonal(0, -1, -2).add_(h.diagonal(0, -1, -2).abs().max() * 1e-5)
423443
delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)
424-
foo = 0
444+
445+
plot_registration(dat, mov, f'{basis_name} | {n_iter}')
425446

426447
if verbose == 1:
427448
print('')
@@ -898,7 +919,8 @@ def discretize(dat, nbins=256, mask=None):
898919

899920
def get_spm_prior(**backend):
900921
"""Download the SPM prior"""
901-
url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
922+
# url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
923+
url = 'https://github.com/spm/spm12/raw/refs/heads/main/tpm/TPM.nii'
902924
fname = os.path.join(cache_dir, 'SPM12_TPM.nii')
903925
if not os.path.exists(fname):
904926
os.makedirs(cache_dir, exist_ok=True)

nitorch/tools/registration/pairwise_preproc.py

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
def preproc_image(input, mask=None, label=False, missing=0,
1414
world=None, affine=None, rescale=.95,
15-
pad=None, bound='zero', fwhm=None,
15+
pad=None, bound='zero', fwhm=None, channels=None,
1616
dim=None, device=None, **kwargs):
1717
"""Load an image and preprocess it as required
1818
@@ -43,6 +43,8 @@ def preproc_image(input, mask=None, label=False, missing=0,
4343
fwhm : [sequence of] float
4444
Smooth the volume with a Gaussian kernel of that FWHM.
4545
If last element is "mm", values are in mm and converted to voxels.
46+
channels : [sequence of] int or range or slice
47+
Channels to load
4648
dim : int, optional
4749
Number of spatial dimensions
4850
device : torch.device
@@ -58,14 +60,30 @@ def preproc_image(input, mask=None, label=False, missing=0,
5860
Orientation matrix
5961
6062
"""
61-
if not torch.is_tensor(input):
62-
dat, mask0, affine0 = load_image(input, dim=dim, device=device,
63-
label=label, missing=missing)
64-
else:
65-
dat = input
66-
mask0 = torch.isfinite(dat)
67-
dat = dat.masked_fill(~mask0, 0)
68-
affine0 = spatial.affine_default(dat.shape[1:])
63+
dat, mask0, affine0 = load_image(input, dim=dim, device=device,
64+
label=label, missing=missing,
65+
channels=channels)
66+
67+
# if not torch.is_tensor(input):
68+
# dat, mask0, affine0 = load_image(input, dim=dim, device=device,
69+
# label=label, missing=missing,
70+
# channels=channels)
71+
# else:
72+
# dat = input
73+
# if channels is not None:
74+
# channels = make_list(channels)
75+
# channels = [
76+
# list(c) if isinstance(c, range) else
77+
# list(range(len(dat)))[c] if isinstance(c, slice) else
78+
# c for c in channels
79+
# ]
80+
# if not all([isinstance(c, int) for c in channels]):
81+
# raise ValueError('Channel list should be a list of integers')
82+
# dat = dat[channels]
83+
# mask0 = torch.isfinite(dat)
84+
# dat = dat.masked_fill(~mask0, 0)
85+
# affine0 = spatial.affine_default(dat.shape[1:])
86+
6987
dim = dat.dim() - 1
7088

7189
# load user-defined mask
@@ -199,7 +217,7 @@ def prepare_pyramid_levels(images, levels, dim=None, **opt):
199217
return pyrutils.pyramid_levels(vxs, shapes, levels, **opt)
200218

201219

202-
def map_image(fnames, dim=None):
220+
def map_image(fnames, dim=None, channels=None):
203221
"""Map an ND image from disk
204222
205223
Parameters
@@ -229,7 +247,6 @@ def map_image(fnames, dim=None):
229247
affine = img.affine
230248
if dim is None:
231249
dim = img.affine.shape[-1] - 1
232-
# img = img.fdata(rand=True, device=device)
233250
if img.dim > dim:
234251
img = img.movedim(-1, 0)
235252
else:
@@ -241,10 +258,24 @@ def map_image(fnames, dim=None):
241258
imgs.append(img)
242259
del img
243260
imgs = io.cat(imgs, dim=0)
261+
262+
# select a subset of channels
263+
if channels is not None:
264+
channels = make_list(channels)
265+
channels = [
266+
list(c) if isinstance(c, range) else
267+
list(range(len(imgs)))[c] if isinstance(c, slice) else
268+
c for c in channels
269+
]
270+
if not all([isinstance(c, int) for c in channels]):
271+
raise ValueError('Channel list should be a list of integers')
272+
imgs = io.stack([imgs[c] for c in channels])
273+
244274
return imgs, affine
245275

246276

247-
def load_image(input, dim=None, device=None, label=False, missing=0):
277+
def load_image(input, dim=None, device=None, label=False, missing=0,
278+
channels=None):
248279
"""
249280
Load a N-D image from disk
250281
@@ -272,15 +303,30 @@ def load_image(input, dim=None, device=None, label=False, missing=0):
272303
Orientation matrix
273304
"""
274305
if not torch.is_tensor(input):
275-
dat, affine = map_image(input, dim)
306+
dat, affine = map_image(input, dim, channels=channels)
276307
else:
277308
dat, affine = input, spatial.affine_default(input.shape[1:])
309+
310+
if channels is not None:
311+
channels = make_list(channels)
312+
channels = [
313+
list(c) if isinstance(c, range) else
314+
list(range(len(dat)))[c] if isinstance(c, slice) else
315+
c for c in channels
316+
]
317+
if not all([isinstance(c, int) for c in channels]):
318+
raise ValueError('Channel list should be a list of integers')
319+
dat = dat[channels]
320+
278321
if label:
279322
dtype = dat.dtype
280323
if isinstance(dtype, (list, tuple)):
281324
dtype = dtype[0]
282325
dtype = dtypes.as_torch(dtype, upcast=True)
283-
dat0 = dat.data(device=device, dtype=dtype)[0] # assume single channel
326+
if torch.is_tensor(dat):
327+
dat0 = dat[0]
328+
else:
329+
dat0 = dat.data(device=device, dtype=dtype)[0] # assume single channel
284330
if label is True:
285331
label = dat0.unique(sorted=True)
286332
label = label[label != 0].tolist()

0 commit comments

Comments
 (0)