Skip to content

Commit 45e7da3

Browse files
committed
ENH(register): better partial affine optimization
1 parent 43cc7fb commit 45e7da3

File tree

7 files changed

+309
-147
lines changed

7 files changed

+309
-147
lines changed

nitorch/cli/registration/register/cli.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,8 @@ def build_losses(options, pyramids, device):
278278
image_dict[loss.mov.name or loss.mov.files[0]] = mov
279279

280280
# Forward loss
281-
factor = loss.factor / (2 if loss.symmetric else 1)
281+
loss.factor = py.ensure_list(loss.factor, len(pyramid))[::-1]
282+
factor = loss.factor[-1] / (2 if loss.symmetric else 1)
282283
sumloss = sumloss + factor * objects.Similarity(lossobj, mov, fix)
283284

284285
# Backward loss
@@ -296,6 +297,40 @@ def build_losses(options, pyramids, device):
296297
pyramid_fn = (concurrent_pyramid if options.pyramid.concurrent else
297298
sequential_pyramid)
298299
loss_list = pyramid_fn(sumloss)
300+
301+
if options.pyramid.concurrent:
302+
if not isinstance(loss_list, objects.SumSimilarity):
303+
loss_list = objects.SumSimilarity([loss_list])
304+
loss_list = list(loss_list)
305+
new_list = []
306+
nb_levels = len(loss_list) // len(options.loss)
307+
nb_losses = len(options.loss)
308+
for i in range(nb_levels):
309+
for j in range(nb_losses):
310+
rel_factor = (
311+
options.loss[j].factor[i] /
312+
options.loss[j].factor[-1]
313+
)
314+
new_list.append(loss_list[i*nb_losses+j] * rel_factor)
315+
loss_list = objects.SumSimilarity.sum(new_list)
316+
else:
317+
new_list = []
318+
for i, level in enumerate(loss_list):
319+
if not isinstance(level, objects.SumSimilarity):
320+
level = objects.SumSimilarity([level])
321+
level = list(level)
322+
new_level = []
323+
for j, loss_elem in enumerate(level):
324+
i = min(i, len(options.loss[j].factor) - 1)
325+
rel_factor = (
326+
options.loss[j].factor[i] /
327+
options.loss[j].factor[-1]
328+
)
329+
new_level.append(loss_elem * rel_factor)
330+
new_level = objects.SumSimilarity.sum(new_level)
331+
new_list.append(new_level)
332+
loss_list = new_list
333+
299334
return loss_list, image_dict
300335

301336

nitorch/cli/registration/register/parser.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@
9595
@pyramid [NAME] ...
9696
9797
@loss options:
98-
FACTOR must be a scalar value [1]
9998
NAME can take values (with specific sub-options):
10099
[mi, nmi] Mutual information (can be normalized)
101100
-m, --norm NAME Normalization: [studholme], arithmetic, geometric, no
@@ -126,6 +125,7 @@
126125
cat, cce Categorical cross-entropy
127126
dice, f1 Dice coefficient
128127
-w, --weight *VAL Weight per class [1]
128+
FACTOR must be a scalar value [1]
129129
Common options:
130130
-s, --symmetric Make loss symmetric [False]
131131
-z, --slicewise [AXIS=-1] Make loss slice-wise [False]
@@ -153,13 +153,13 @@
153153
-c, --channels *C Channels to load. Can be a range start:stop:step [:]
154154
155155
@affine options:
156-
FACTOR must be a scalar value [1] and is a global penalty factor
157156
NAME can take values:
158157
t, translation Translations only
159158
o, rotation Rotations only
160159
[r, rigid] Translations + Rotations
161160
s, similitude Translations + Rotations + Iso zoom
162161
a, affine Full affine
162+
FACTOR must be a scalar value [1] and is a global penalty factor
163163
Common options:
164164
-p, --position Position of the affine: [sym], mov, fix
165165
-g, --progressive Progressive optimization (t -> r -> s -> a) [false]
@@ -168,11 +168,11 @@
168168
-2d [AXIS=2] Force transform to be 2d about AXIS
169169
170170
@nonlin options:
171-
FACTOR must be a scalar value [1] and is a global penalty factor
172171
NAME can take values:
173172
[v, svf] Stationary velocity field
174173
g, shoot Geodesic shooting
175174
d, smalldef Dense deformation field
175+
FACTOR must be a scalar value [1] and is a global penalty factor
176176
Common options:
177177
-i, --input Path to initial transform
178178
-o, --output Path to the output transform: [{dir}/{name}.nii.gz]
@@ -336,7 +336,7 @@ def _convert(x):
336336
help='Name of the image loss')
337337
loss = cli.NamedGroup('loss', loss_choices, '@loss', n='+',
338338
help='A loss between two images')
339-
loss.add_positional('factor', nargs='?', default=1., convert=float,
339+
loss.add_positional('factor', nargs='*', default=[1.], convert=float,
340340
help='Weight it this component in the global loss')
341341
loss.add_option('symmetric', ('-s', '--symmetric'), nargs=0, default=False,
342342
help='Make the loss symmetric')
@@ -529,7 +529,7 @@ def _convert(x):
529529
convert=lambda x: affine_aliases.get(x, x),
530530
help='Name of the affine transform')
531531
affine = cli.NamedGroup('affine', affine_choices, '@affine', n='?', make_default=False)
532-
affine.add_positional('factor', nargs='?', default=1., convert=float,
532+
affine.add_positional('factor', nargs='*', default=[1.], convert=float,
533533
help='Penalty factor')
534534
affine.add_option('position', ('-p', '--position'), default='sym', nargs=1,
535535
validation=cli.Validations.choice(['sym', 'mov', 'fix']),
@@ -556,7 +556,7 @@ def _convert(x):
556556
convert=lambda x: nonlin_aliases.get(x, x),
557557
help='Name of the nonlinear transform')
558558
nonlin = cli.NamedGroup('nonlin', nonlin_choices, '@nonlin', n='?', make_default=False)
559-
nonlin.add_positional('factor', nargs='?', default=1., convert=float,
559+
nonlin.add_positional('factor', nargs='*', default=[1.], convert=float,
560560
help='Penalty factor')
561561
nonlin.add_option('steps', ('-s', '--steps'), nargs=1, default=8, convert=int,
562562
help='Number of integration steps')

0 commit comments

Comments
 (0)