3333 𝓛 = 𝔼_q[ln p(𝒙)] = ∑ₙᵢ q(𝑥ₙ = 𝑖) ln ∑ⱼ 𝐻ᵢⱼ (𝜇 ∘ 𝜙)ₙⱼ
3434"""
3535import torch
36- from nitorch .core import linalg , utils , py
36+ from nitorch .core import linalg , utils , py , math
3737from nitorch import spatial , io
3838from .utils import jg , jhj , affine_grid_backward
3939import nitorch .plot as niplt
@@ -154,6 +154,18 @@ def align_tpm(dat, tpm=None, weights=None, spacing=(8, 4), device=None,
154154 # ------------------------------------------------------------------
155155 dat = discretize (dat , nbins = bins , mask = weights )
156156
157+ # ------------------------------------------------------------------
158+ # PREFILTER TPM
159+ # ------------------------------------------------------------------
160+ logtpm = tpm .clone ()
161+ # ensure normalized
162+ logtpm = logtpm .clamp (tiny , 1 - tiny ).div_ (logtpm .sum (0 , keepdim = True ))
163+ # transform to logits
164+ logtpm = logtpm .add_ (tiny ).log_ ()
165+ # spline prefilter
166+ splineopt = dict (interpolation = 2 , bound = 'replicate' )
167+ logtpm = spatial .spline_coeff_nd (logtpm , dim = 3 , inplace = True , ** splineopt )
168+
157169 # ------------------------------------------------------------------
158170 # OPTIONS
159171 # ------------------------------------------------------------------
@@ -175,8 +187,8 @@ def do_spacing(sp):
175187 if not sp :
176188 return dat0 , affine_dat0 , weights0
177189 sp = [max (1 , int (pymath .floor (sp / vx1 ))) for vx1 in vx ]
178- sp = [slice (None , None , sp1 ) for sp1 in sp ]
179- affine_dat , _ = spatial .affine_sub (affine_dat0 , dat0 .shape [- dim :], tuple ( sp ) )
190+ sp = tuple ( [slice (None , None , sp1 ) for sp1 in sp ])
191+ affine_dat , _ = spatial .affine_sub (affine_dat0 , dat0 .shape [- dim :], sp )
180192 dat = dat0 [(Ellipsis , * sp )]
181193 if weights0 is not None :
182194 weights = weights0 [(Ellipsis , * sp )]
@@ -234,7 +246,7 @@ def do_spacing(sp):
234246 if reorient is not None :
235247 affine_dat = reorient .matmul (affine_dat )
236248
237- mi , aff , prm = fit_affine_tpm (dat , tpm , affine_dat , affine_tpm ,
249+ mi , aff , prm = fit_affine_tpm (dat , logtpm , affine_dat , affine_tpm ,
238250 weights , ** opt , prm = prm )
239251
240252 if reorient is not None :
@@ -263,7 +275,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
263275 affine_tpm : (4, 4) tensor
264276 weights : (*spatial) tensor
265277 basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
266- fwhm : float, default=J/32
278+ fwhm : float, default=J/64
267279 max_iter_gn : int, default=100
268280 max_iter_em : int, default=32
269281 max_line_search : int, default=12
@@ -276,6 +288,8 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
276288 prm : (F) tensor
277289
278290 """
291+ # !!! NOTE: `tpm` must contain spline-prefiltered log-probabilities
292+
279293 dim = tpm .dim () - 1
280294
281295 # ------------------------------------------------------------------
@@ -326,7 +340,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
326340 affine_tpm = affine_tpm .to (** utils .backend (tpm ))
327341 shape = dat .shape [- dim :]
328342
329- tpm = tpm .to (dat .device ). clamp ( tiny , 1 - tiny )
343+ tpm = tpm .to (dat .device )
330344 basis = make_basis (basis , dim , ** utils .backend (tpm ))
331345 F = len (basis )
332346
@@ -337,7 +351,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
337351 em_opt = dict (fwhm = fwhm , max_iter = max_iter_em , weights = weights ,
338352 verbose = verbose - 2 )
339353 drv_opt = dict (weights = weights )
340- pull_opt = dict (bound = 'replicate' , extrapolate = True )
354+ pull_opt = dict (bound = 'replicate' , extrapolate = True , interpolation = 2 )
341355
342356 # ------------------------------------------------------------------
343357 # OPTIMIZE
@@ -365,6 +379,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
365379
366380 # --- warp TPM ---------------------------------------------
367381 mov = spatial .grid_pull (tpm , phi , ** pull_opt )
382+ mov = math .softmax (mov , dim = 1 )
368383
369384 # --- mutual info ------------------------------------------
370385 mi , Nm , prior = em_prior (mov , dat , prior0 , ** em_opt )
@@ -399,8 +414,8 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
399414 end = '\n ' if verbose >= 2 else '\r '
400415 print (f'({ basis_name [:6 ]} ){ space } | { n_iter :02d} | { mi .mean ():12.6g} ' , end = end )
401416
402- if mi .mean () - mi0 .mean () < 1e-4 :
403- # print('converged', mi.mean() - mi0.mean())
417+ if mi .mean () - mi0 .mean () < 0 : # 1e-4:
418+ print ('converged' , mi .mean () - mi0 .mean ())
404419 break
405420
406421 # --------------------------------------------------------------
@@ -412,16 +427,22 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
412427 g = g .sum (0 )
413428 h = h .sum (0 )
414429
415- # --- chain rule -----------------------------------------------
430+ # --- spatial derivatives --------------------------------------
431+ mov = mov .unsqueeze (- 1 )
416432 gmov = spatial .grid_grad (tpm , phi , ** pull_opt )
433+ gmov = mov * (gmov - (mov * gmov ).sum (1 , keepdim = True ))
434+ mov = mov .squeeze (- 1 )
435+
436+ # --- chain rule -----------------------------------------------
417437 gaff = lmdiv (affine_tpm , mm (gaff , affine ))
418438 g , h = chain_rule (g , h , gmov , gaff , maj = False )
419439 del gmov
420440
421441 # --- Gauss-Newton ---------------------------------------------
422442 h .diagonal (0 , - 1 , - 2 ).add_ (h .diagonal (0 , - 1 , - 2 ).abs ().max () * 1e-5 )
423443 delta = lmdiv (h , g .unsqueeze (- 1 )).squeeze (- 1 )
424- foo = 0
444+
445+ plot_registration (dat , mov , f'{ basis_name } | { n_iter } ' )
425446
426447 if verbose == 1 :
427448 print ('' )
@@ -898,7 +919,8 @@ def discretize(dat, nbins=256, mask=None):
898919
899920def get_spm_prior (** backend ):
900921 """Download the SPM prior"""
901- url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
922+ # url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
923+ url = 'https://github.com/spm/spm12/raw/refs/heads/main/tpm/TPM.nii'
902924 fname = os .path .join (cache_dir , 'SPM12_TPM.nii' )
903925 if not os .path .exists (fname ):
904926 os .makedirs (cache_dir , exist_ok = True )
0 commit comments