Skip to content

Commit 8067d60

Browse files
committed
[ENH] register: option for FOV/COM initialization
1 parent 497807b commit 8067d60

File tree

4 files changed

+60
-5
lines changed

4 files changed

+60
-5
lines changed

nitorch/cli/registration/register/cli.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,43 @@ def _main(options):
9797
# ------------------------------------------------------------------
9898
affine = None
9999
if options.affine:
100+
101+
# initialization
102+
affine_init = options.affine.init
103+
if isinstance(affine_init, str):
104+
if affine_init.lower() in ('fov', 'com'):
105+
if affine_init.lower() == 'fov':
106+
get_center = 'get_center'
107+
else:
108+
get_center = 'get_center_of_mass'
109+
loss_list = losses if isinstance(losses, list) else [losses]
110+
num = den = 0
111+
for loss_sum in loss_list:
112+
for loss in loss_sum:
113+
fix = getattr(loss.fixed, get_center)()
114+
mov = getattr(loss.moving, get_center)()
115+
wgt = 0.5 * (
116+
torch.Size(fix.shape[1:]).numel() +
117+
torch.Size(mov.shape[1:]).numel()
118+
)
119+
num += (mov - fix) * wgt
120+
den += wgt
121+
shift = num / den
122+
affine_init = torch.eye(4)
123+
affine_init[:-1, -1] = shift
124+
125+
# build affine object
100126
if options.affine.is2d is not False:
101127
if isinstance(losses[0], objects.SumSimilarity):
102128
affine0 = losses[0][0].fixed.affine
103129
else:
104130
affine0 = losses[0].fixed.affine
105131
affine = make_affine_2d(options.affine.is2d, affine0,
106132
options.affine.name, options.affine.position,
107-
init=options.affine.init)
133+
init=affine_init)
108134
else:
109135
affine = make_affine(options.affine.name, options.affine.position,
110-
init=options.affine.init)
136+
init=affine_init)
111137

112138
# ------------------------------------------------------------------
113139
# BUILD DENSE

nitorch/cli/registration/register/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@
163163
Common options:
164164
-p, --position Position of the affine: [sym], mov, fix
165165
-g, --progressive Progressive optimization (t -> r -> s -> a) [false]
166-
-i, --input Path to initial transform
166+
-i, --input Path to initial transform (fov, com, or file path)
167167
-o, --output Path to the output transform: [{dir}/{name}.lta]
168168
-2d [AXIS=2] Force transform to be 2d about AXIS
169169

nitorch/tools/registration/objects.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def cpu_(self):
334334

335335
voxel_size = property(lambda self: spatial.voxel_size(self.affine))
336336
shape = property(lambda self: self.dat.shape[-self.dim:])
337+
ndim = property(lambda self: len(self.shape))
337338
dtype = property(lambda self: self.dat.dtype)
338339
device = property(lambda self: self.dat.device)
339340

@@ -430,13 +431,13 @@ def pull(self, grid, dat=True, mask=False, preview=False):
430431
if mask:
431432
msk = None
432433
if self.masked:
433-
msk = self.mask.to(self.dat.dtype)
434+
msk = self.mask.to(self.dtype)
434435
msk = regutils.smart_pull(msk, grid, bound=self.bound,
435436
extrapolate=self.extrapolate)
436437
out += [msk]
437438
if preview:
438439
if self.previewed:
439-
prv = self.preview.to(self.dat.dtype)
440+
prv = self.preview.to(self.dtype)
440441
prv = regutils.smart_pull(prv, grid, bound=self.bound,
441442
extrapolate=self.extrapolate)
442443
else:
@@ -481,6 +482,30 @@ def grad(self):
481482
return spatial.diff(self.dat, dim=list(range(-self.dim, 0)),
482483
bound=self.bound)
483484

485+
def get_center(self):
486+
"""Compute the RAS coordinate of the center of the field of view."""
487+
ndim = self.ndim
488+
backend = dict(dtype=self.dtype, device=self.device)
489+
affine = self.affine.to(**backend)
490+
shape = torch.as_tensor(self.shape, **backend)
491+
center = affine[:ndim, :ndim].matmul((shape[:, None] - 1)*0.5)
492+
center += affine[:ndim, -1:]
493+
return center.squeeze(-1)
494+
495+
def get_center_of_mass(self, masked=True, **backend):
496+
"""Compute the RAS coordinate of the center of mass."""
497+
backend = dict(dtype=self.dtype, device=self.device)
498+
grid = spatial.identity_grid(self.shape, **backend)
499+
dat = self.dat
500+
if masked and self.masked:
501+
msk = self.mask.to(**backend)
502+
dat = dat * msk
503+
den = msk.sum()
504+
if len(msk) == 1:
505+
den *= len(dat)
506+
num = (dat[..., None] * grid).sum(list(range(dat.ndim)))
507+
return num / den
508+
484509
def _prm_as_str(self):
485510
s = []
486511
if self.bound != 'dct2':

nitorch/tools/registration/pairwise_makeobj.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ def make_affine(basis='rigid', position='symmetric', penalty=None, init=None):
109109
If 'symmetric', both images are rotated by the transformation and
110110
its inverse, towards a mean space; thereby making the model fully
111111
symmetric.
112+
penalty : None
113+
Not implemented
114+
init : str | MappedAffine | (4, 4) tensor
115+
Initial affine
112116
113117
Returns
114118
-------

0 commit comments

Comments
 (0)