
Commit 36eb36f

[FIX] fix progressive affine when an initial guess is provided
1 parent b2861df commit 36eb36f

2 files changed (+37 / -40 lines)

nitorch/tools/registration/objects.py  (21 additions, 33 deletions)
@@ -2043,27 +2043,24 @@ def __init__(self, model: "AffineModel"):
         optim = property(lambda self: self.model.optim)
         basis = property(lambda self: self.model.basis_name)
 
-        def add_(self, value, **kwargs):
+        def _switch(self, value):
             if self.optim not in (None, self.basis):
                 value = LogAffine._switch_basis(value, self.optim, self.basis, self.model.dat.dim)
-            self.dat.add_(value, **kwargs)
+            return value
+
+        def add_(self, value, **kwargs):
+            self.dat.add_(self._switch(value), **kwargs)
             return self
 
         def add(self, value, **kwargs):
-            if self.optim not in (None, self.basis):
-                value = LogAffine._switch_basis(value, self.optim, self.basis, self.model.dat.dim)
-            return self.dat.add(value, **kwargs)
+            return self.dat.add(self._switch(value), **kwargs)
 
         def mul_(self, value, **kwargs):
-            if self.optim not in (None, self.basis):
-                value = LogAffine._switch_basis(value, self.optim, self.basis, self.model.dat.dim)
-            self.dat.mul_(value, **kwargs)
+            self.dat.mul_(self._switch(value), **kwargs)
             return self
 
         def mul(self, value, **kwargs):
-            if self.optim not in (None, self.basis):
-                value = LogAffine._switch_basis(value, self.optim, self.basis, self.model.dat.dim)
-            return self.dat.mul(value, **kwargs)
+            return self.dat.mul(self._switch(value), **kwargs)
 
         def __add__(self, other):
             return self.add(other)
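
The hunk above collapses four copies of the same basis-switching guard into a single _switch helper. The standalone sketch below mirrors that pattern; everything apart from the _switch/add_ structure is a made-up stand-in, and the zero-padding switch_basis is a deliberately crude substitute for LogAffine._switch_basis (the real conversion between Lie bases is more involved):

    def switch_basis(value, src_basis, dst_basis):
        # Stand-in for LogAffine._switch_basis: pad a short parameter vector
        # with zeros so it matches the size of the destination basis.
        sizes = {'T': 3, 'SE': 6, 'CSO': 7, 'affine': 12}
        return list(value) + [0.0] * (sizes[dst_basis] - sizes[src_basis])

    class Parameters:
        def __init__(self, dat, basis, optim=None):
            self.dat, self.basis, self.optim = list(dat), basis, optim

        def _switch(self, value):
            # Only convert when a progressive optimizer works in a different basis.
            if self.optim not in (None, self.basis):
                value = switch_basis(value, self.optim, self.basis)
            return value

        def add_(self, value):
            value = self._switch(value)
            self.dat = [d + v for d, v in zip(self.dat, value)]
            return self

    p = Parameters([0.0] * 6, basis='SE', optim='T')
    p.add_([1.0, 2.0, 3.0])   # translation-only step applied to a rigid (SE) model
    print(p.dat)              # [1.0, 2.0, 3.0, 0.0, 0.0, 0.0]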
@@ -2247,6 +2244,8 @@ def __repr__(self):
         else:
             s += ['<uninitialized>']
         s += [f'basis={self._basis}']
+        if self.optim is not None:
+            s += [f'optim={self.optim}']
         s += [f'factor={self.factor}']
         s += [f'position={self.position}']
         if self.penalty is not None:
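
With the extra branch, the currently active progressive basis is reported by the model's repr. A rough sketch of the effect follows; how the `s` fragments are ultimately joined is not visible in this hunk, so the exact formatting here is a guess:

    s = ['basis=affine']
    optim = 'rigid'               # progressive stage currently being optimized
    if optim is not None:
        s += [f'optim={optim}']
    s += ['factor=1', 'position=symmetric']
    print('AffineModel(' + ', '.join(s) + ')')
    # AffineModel(basis=affine, optim=rigid, factor=1, position=symmetric)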
@@ -2260,27 +2259,10 @@ class Affine2dModel(AffineModel):
 
     class Parameters(AffineModel.Parameters):
 
-        def add_(self, value, **kwargs):
-            if self.optim not in (None, self.basis):
-                value = LogAffine2d._switch_basis(value, self.optim, self.basis)
-            self.dat.add_(value, **kwargs)
-            return self
-
-        def add(self, value, **kwargs):
-            if self.optim not in (None, self.basis):
-                value = LogAffine._switch_basis(value, self.optim, self.basis, self.model.dat.dim)
-            return self.dat.add(value, **kwargs)
-
-        def mul_(self, value, **kwargs):
+        def _switch(self, value):
             if self.optim not in (None, self.basis):
                 value = LogAffine2d._switch_basis(value, self.optim, self.basis)
-            self.dat.mul_(value, **kwargs)
-            return self
-
-        def mul(self, value, **kwargs):
-            if self.optim not in (None, self.basis):
-                value = LogAffine._switch_basis(value, self.optim, self.basis, self.model.dat.dim)
-            return self.dat.mul(value, **kwargs)
+            return value
 
     def __init__(self, basis, plane, ref_affine=None, factor=1, penalty=None,
                  dat=None, position='symmetric'):
@@ -2307,7 +2289,13 @@ def plane_name(self):
         return self._plane
 
     def set_dat(self, dat=None, dim=None, **backend):
-        self.dat = LogAffine2d(dat, basis=self._basis, dim=dim,
-                               rotation=self._rotation,
-                               plane=self._plane, **backend)
+        self.dat = LogAffine2d(
+            dat,
+            basis=self._basis,
+            optim=self._optim,
+            dim=dim,
+            rotation=self._rotation,
+            plane=self._plane,
+            **backend
+        )
         return self
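
Forwarding optim=self._optim here is presumably the heart of the fix: when an initial guess is supplied, set_dat rebuilds the LogAffine2d, and before this change the progressive basis was dropped in the process. A toy illustration of that failure mode, using hypothetical stand-in classes rather than nitorch's real ones:

    class LogAffine2dSketch:
        def __init__(self, dat=None, basis='affine', optim=None):
            self.dat = dat        # initial guess for the log-parameters
            self.basis = basis
            self.optim = optim    # basis the progressive optimizer works in

    class ModelSketch:
        def __init__(self, basis='affine', optim='rigid'):
            self._basis, self._optim = basis, optim

        def set_dat_old(self, dat=None):
            # Before the fix: the progressive basis was silently dropped,
            # so providing an initial guess reset `optim` to None.
            self.dat = LogAffine2dSketch(dat, basis=self._basis)
            return self

        def set_dat_new(self, dat=None):
            # After the fix: the progressive basis is forwarded as well.
            self.dat = LogAffine2dSketch(dat, basis=self._basis, optim=self._optim)
            return self

    m = ModelSketch()
    print(m.set_dat_old([0.1, 0.2]).dat.optim)   # None
    print(m.set_dat_new([0.1, 0.2]).dat.optim)   # rigid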

nitorch/tools/registration/pairwise_run.py  (16 additions, 7 deletions)
@@ -106,8 +106,7 @@ def run(
     # --- progressive affine initialization ---
     if affine and not affine.frozen and progressive:
         affine, figure = run_progressive_init(
-            runner, losses, affine, affine_optim, line_size=line_size,
-            **plotopt)
+            runner, losses, affine, affine_optim, **plotopt)
         plotopt["figure"] = figure
 
     # --- full affine ---
@@ -185,16 +184,16 @@ def run_progressive_init(
         return affine, None
 
     if verbose:
-        print('-' * line_size)
+        print('=' * line_size)
         print(' PROGRESSIVE INITIALIZATION')
-        print('-' * line_size)
+        print('=' * line_size)
 
     affine.optim = names.pop(0)
     while names:
         if verbose:
             line_pad = line_size - len(affine.optim) - 5
-            print(f'--- {affine.optim} ', end='')
-            print('-' * max(0, line_pad))
+            print(f'*** {affine.optim} ', end='')
+            print('*' * max(0, line_pad))
         affine_optim.reset_state()
         torch.cuda.empty_cache()
 
@@ -244,11 +243,21 @@ def run_pyramid(
         optim,
         verbose=True,
         framerate=1,
-        figure=None
+        figure=None,
 ):
     """Run sequential pyramid registration"""
     line_size = 89 if nonlin else 74
 
+    if verbose:
+        print('=' * line_size)
+        if affine and not nonlin:
+            print(' AFFINE')
+        elif nonlin and not affine:
+            print(' NONLINEAR')
+        else:
+            print(' AFFINE & NONLINEAR')
+        print('=' * line_size)
+
     nonlin_optim = None
     if nonlin and not getattr(nonlin, 'frozen', False):
         if affine and not getattr(affine, 'frozen', False):
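
The banner added to run_pyramid depends only on which of the two models is enabled. The snippet below reproduces the same logic outside nitorch, with plain booleans standing in for the affine and nonlin model objects:

    def banner(affine, nonlin):
        line_size = 89 if nonlin else 74
        print('=' * line_size)
        if affine and not nonlin:
            print(' AFFINE')
        elif nonlin and not affine:
            print(' NONLINEAR')
        else:
            print(' AFFINE & NONLINEAR')
        print('=' * line_size)

    banner(affine=True, nonlin=False)
    # prints a 74-character '=' rule, then ' AFFINE', then another rule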
