1010The command line function `nitorch register` calls this model under the hood.
1111
1212"""
13+ from threading import Thread , Lock
1314from nitorch import spatial
1415from nitorch .core import utils , linalg
1516import torch
@@ -245,19 +246,30 @@ def __init__(
245246 self .framerate = framerate
246247 self .figure = figure
247248 self ._last_plot = 0
249+ self ._plot_lock = Lock ()
248250 if self .verbose > 2 and not self .figure :
249251 import matplotlib .pyplot as plt
250252 self .figure = plt .figure ()
253+ self ._plot_pending = None
251254
252- def mov2fix (self , fixed , moving , warped , vel = None , title = None , lam = None ):
255+ def mov2fix (self , * args , ** kwargs ):
256+ thread = Thread (target = (lambda : self ._mov2fix (* args , ** kwargs )))
257+ if self ._plot_lock .locked ():
258+ self ._plot_pending = thread
259+ else :
260+ thread .run ()
261+
262+ def _mov2fix (self , fixed , moving , warped , vel = None , title = None , lam = None ):
253263 """Plot registration live"""
254264
265+ self ._plot_lock .acquire ()
266+ self ._plot_pending = None
267+
255268 import time
256- tic = self ._last_plot
257- toc = time .time ()
258- if toc - tic < 1 / self .framerate :
259- return
260- self ._last_plot = toc
269+ # tic = self._last_plot
270+ # toc = time.time()
271+ # if toc - tic < 1/self.framerate:
272+ # return
261273
262274 import matplotlib .pyplot as plt
263275
@@ -277,7 +289,9 @@ def mov2fix(self, fixed, moving, warped, vel=None, title=None, lam=None):
277289 def rescale2d (x ):
278290 if not x .dtype .is_floating_point :
279291 x = x .float ()
280- mn , mx = utils .quantile (x , [0.005 , 0.995 ],
292+ mask = (x == 0 ).any (dim = list (range (x .ndim - 2 )))
293+ xmasked = x [~ mask ]
294+ mn , mx = utils .quantile (xmasked , [0.005 , 0.995 ],
281295 dim = range (- 2 , 0 ), bins = 1024 ).unbind (- 1 )
282296 mx = mx .max (mn + 1e-8 )
283297 mn , mx = mn [..., None , None ], mx [..., None , None ]
@@ -314,6 +328,11 @@ def rescale2d(x):
314328 vel = [vel .square ().sum (- 1 ).sqrt ()] if vel is not None else []
315329 lam = [rescale2d (lam )] if lam is not None else []
316330
331+ if vel :
332+ vmax = max (v .max ().item () for v in vel )
333+ if lam :
334+ lmax = max (x .max ().item () for x in lam )
335+
317336 checker = []
318337 for f , w in zip (fixed , warped ):
319338 patch = max ([s // 8 for s in f .shape ])
@@ -340,8 +359,24 @@ def rescale2d(x):
340359
341360 imkwargs = dict (interpolation = 'nearest' )
342361
362+ def bbox2rect (text ):
363+ (left , bottom , width , height ) = text .get_tightbbox ().bounds
364+ (_ , _ , figwidth , figheight ) = text .figure .bbox .bounds
365+ return (left / figwidth , bottom / figheight , width / figwidth , height / figheight )
366+
367+ def add_title_axis (fig , title ):
368+ _title = fig .suptitle (title )
369+ ax : plt .Axes = fig .add_axes (bbox2rect (_title ))
370+ ax .axis ("off" )
371+ ax .text (0 , 0 , title , animated = True , font = _title .get_fontproperties ())
372+ fig .suptitle ("" )
373+ return ax
374+
375+ all_ll = torch .stack (self .all_ll ).cpu () if self .all_ll else []
376+ lldata = (range (1 , len (all_ll )+ 1 ), all_ll )
377+
343378 fig = self .figure
344- replot = len (getattr (self , 'axes_saved' , [])) != (nb_rows - 1 ) * nb_cols + 1
379+ replot = len (getattr (self , 'axes_saved' , [])) != (nb_rows - 1 ) * nb_cols + 1 + bool ( title )
345380 if replot :
346381 fig .clf ()
347382
@@ -352,95 +387,144 @@ def rescale2d(x):
352387 ax .imshow (moving [k ].cpu (), ** imkwargs )
353388 if k == 0 :
354389 ax .set_title ('moving' )
355- ax .axis ('off' )
390+ ax .axis ('image' )
391+ ax .set_xticks ([])
392+ ax .set_yticks ([])
356393 ax = fig .add_subplot (nb_rows , nb_cols , k * nb_cols + 2 )
357394 axes += [ax ]
358395 ax .imshow (warped [k ].cpu (), ** imkwargs )
359396 if k == 0 :
360397 ax .set_title ('moved' )
361- ax .axis ('off' )
398+ ax .axis ('image' )
399+ ax .set_xticks ([])
400+ ax .set_yticks ([])
362401 ax = fig .add_subplot (nb_rows , nb_cols , k * nb_cols + 3 )
363402 axes += [ax ]
364403 ax .imshow (checker [k ].cpu (), ** imkwargs )
365404 if k == 0 :
366405 ax .set_title ('checker' )
367- ax .axis ('off' )
406+ ax .axis ('image' )
407+ ax .set_xticks ([])
408+ ax .set_yticks ([])
368409 ax = fig .add_subplot (nb_rows , nb_cols , k * nb_cols + 4 )
369410 axes += [ax ]
370411 ax .imshow (fixed [k ].cpu (), ** imkwargs )
371412 if k == 0 :
372413 ax .set_title ('fixed' )
373- ax .axis ('off' )
414+ ax .axis ('image' )
415+ ax .set_xticks ([])
416+ ax .set_yticks ([])
374417 if vel :
375418 ax = fig .add_subplot (nb_rows , nb_cols , k * nb_cols + 5 )
376419 axes += [ax ]
377- d = ax .imshow (vel [k ].cpu (), ** imkwargs )
420+ d = ax .imshow (vel [k ].cpu (), ** imkwargs , vmin = 0 , vmax = vmax )
378421 if k == 0 :
379422 ax .set_title ('displacement' )
380- ax .axis ('off' )
381- fig .colorbar (d , None , ax )
423+ ax .axis ('image' )
424+ ax .set_xticks ([])
425+ ax .set_yticks ([])
426+ if k == 0 :
427+ v0 = d
382428 if lam :
383429 ax = fig .add_subplot (nb_rows , nb_cols , k * nb_cols + 5 + bool (vel ))
384430 axes += [ax ]
385- d = ax .imshow (lam [k ].cpu (), ** imkwargs )
431+ d = ax .imshow (lam [k ].cpu (), ** imkwargs , vmin = 0 , vmax = lmax )
386432 if k == 0 :
387433 ax .set_title ('precision' )
388- ax .axis ('off' )
389- fig .colorbar (d , None , ax )
434+ ax .axis ('image' )
435+ ax .set_xticks ([])
436+ ax .set_yticks ([])
437+ if k == 0 :
438+ l0 = d
439+
440+ if title :
441+ axtitle = add_title_axis (fig , title )
442+ axes += [axtitle ]
443+
390444 ax = fig .add_subplot (nb_rows , 1 , nb_rows )
391445 axes += [ax ]
392- all_ll = torch .stack (self .all_ll ).cpu () if self .all_ll else []
393- # ax.plot(range(1, len(all_ll)+1), all_ll)
394- ax .plot ([])
446+ ax .plot (* lldata , animated = True )
395447 ax .set_ylabel ('NLL' )
396448 ax .set_xlabel ('iteration' )
397- if title :
398- fig .suptitle (title )
449+ ax .set_xticks ([])
450+ ax .set_yticks ([])
451+
452+ if vel or lam :
453+ fig .subplots_adjust (right = 0.9 )
454+ if vel and not lam :
455+ vbar_ax = fig .add_axes ([0.92 , 0.15 , 0.025 , 0.7 ])
456+ elif lam and not vel :
457+ lbar_ax = fig .add_axes ([0.92 , 0.15 , 0.025 , 0.7 ])
458+ else :
459+ vbar_ax = fig .add_axes ([0.92 , 0.10 , 0.025 , 0.35 ])
460+ lbar_ax = fig .add_axes ([0.92 , 0.55 , 0.025 , 0.35 ])
461+ if vel :
462+ self ._vbar = fig .colorbar (v0 , cax = vbar_ax )
463+ self ._vbar_ax = vbar_ax
464+ if lam :
465+ self ._lbar = fig .colorbar (l0 , cax = lbar_ax )
466+ self ._lbar_ax = lbar_ax
399467
400- fig .canvas .draw ()
401- self .plt_saved = [fig .canvas .copy_from_bbox (ax .bbox )
402- for ax in axes ]
403- self .axes_saved = axes
404- fig .canvas .flush_events ()
468+ # fig.canvas.draw()
405469 plt .show (block = False )
470+ self ._plt_background = fig .canvas .copy_from_bbox (fig .bbox )
471+ self .axes_saved = axes
406472
407- lldata = (range (1 , len (all_ll )+ 1 ), all_ll )
408- axes [- 1 ].lines [0 ].set_data (lldata )
409- axes [- 1 ].draw_artist (axes [- 1 ].lines [0 ])
410- fig .canvas .blit (ax .bbox )
411- fig .canvas .flush_events ()
412- else :
413- for elem in self .plt_saved :
414- fig .canvas .restore_region (elem )
473+ # --- Actually plot the stuff by blitting ---
474+ fig .canvas .restore_region (self ._plt_background )
475+
476+ for k in range (kdim ):
477+ j = k * nb_cols
478+ self .axes_saved [j ].images [0 ].set_data (moving [k ].cpu ())
479+ self .axes_saved [j + 1 ].images [0 ].set_data (warped [k ].cpu ())
480+ self .axes_saved [j + 2 ].images [0 ].set_data (checker [k ].cpu ())
481+ self .axes_saved [j + 3 ].images [0 ].set_data (fixed [k ].cpu ())
482+ if vel :
483+ self .axes_saved [j + 4 ].images [0 ].set_data (vel [k ].cpu ())
484+ self .axes_saved [j + 4 ].images [0 ].set_clim (0 , vmax )
485+ j += 1
486+ if lam :
487+ self .axes_saved [j + 4 ].images [0 ].set_data (lam [k ].cpu ())
488+ self .axes_saved [j + 4 ].images [0 ].set_clim (0 , lmax )
489+
490+ self .axes_saved [- 1 ].lines [0 ].set_data (lldata )
491+ self .axes_saved [- 1 ].relim ()
492+ self .axes_saved [- 1 ].autoscale_view ()
493+
494+ if title :
495+ self .axes_saved [- 2 ].texts [0 ].set_text (title )
496+
497+ if vel or lam :
498+ if vel :
499+ self ._vbar .update_ticks ()
500+ self ._vbar .update_normal ()
501+ if lam :
502+ self ._lbar .update_ticks ()
503+ self ._lbar .update_normal ()
504+ fig .canvas .draw ()
415505
416- for k in range (kdim ):
417- j = k * nb_cols
418- self .axes_saved [j ].images [0 ].set_data (moving [k ].cpu ())
419- self .axes_saved [j + 1 ].images [0 ].set_data (warped [k ].cpu ())
420- self .axes_saved [j + 2 ].images [0 ].set_data (checker [k ].cpu ())
421- self .axes_saved [j + 3 ].images [0 ].set_data (fixed [k ].cpu ())
422- if vel :
423- self .axes_saved [j + 4 ].images [0 ].set_data (vel [k ].cpu ())
424- j += 1
425- if lam :
426- self .axes_saved [j + 4 ].images [0 ].set_data (lam [k ].cpu ())
427- all_ll = torch .stack (self .all_ll ).cpu () if self .all_ll else []
428- lldata = (range (1 , len (all_ll )+ 1 ), all_ll )
429- self .axes_saved [- 1 ].lines [0 ].set_data (lldata )
430- self .axes_saved [- 1 ].relim ()
431- self .axes_saved [- 1 ].autoscale_view ()
432- if title :
433- fig ._suptitle .set_text (title )
506+ for ax in self .axes_saved :
507+ ax .draw_artist (ax .patch )
508+ if not ax .texts :
509+ for spine in ax .spines .values ():
510+ ax .draw_artist (spine )
511+ for elem in (ax .images or []):
512+ ax .draw_artist (elem )
513+ for elem in (ax .lines or []):
514+ ax .draw_artist (elem )
515+ for elem in (ax .texts or []):
516+ ax .draw_artist (elem )
517+ fig .canvas .blit (ax .bbox )
434518
435- for ax in self .axes_saved :
436- for elem in (ax .images or []):
437- ax .draw_artist (elem )
438- for elem in (ax .lines or []):
439- ax .draw_artist (elem )
440- fig .canvas .blit (ax .bbox )
441- fig .canvas .flush_events ()
519+ fig .canvas .flush_events ()
442520
443521 self .figure = fig
522+ self ._last_plot = time .time ()
523+
524+ # Run next ploting thread if there is one
525+ self ._plot_lock .release ()
526+ if self ._plot_pending :
527+ self ._plot_pending .run ()
444528
445529 def print (self , step , ll , lla = None , llv = None , in_line_search = False ):
446530
0 commit comments