Skip to content

Commit d0cfc42

Browse files
committed
ENH(register): plot in background thread
1 parent 36eb36f commit d0cfc42

File tree

1 file changed

+145
-61
lines changed

1 file changed

+145
-61
lines changed

nitorch/tools/registration/pairwise_impl.py

Lines changed: 145 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
The command line function `nitorch register` calls this model under the hood.
1111
1212
"""
13+
from threading import Thread, Lock
1314
from nitorch import spatial
1415
from nitorch.core import utils, linalg
1516
import torch
@@ -245,19 +246,30 @@ def __init__(
245246
self.framerate = framerate
246247
self.figure = figure
247248
self._last_plot = 0
249+
self._plot_lock = Lock()
248250
if self.verbose > 2 and not self.figure:
249251
import matplotlib.pyplot as plt
250252
self.figure = plt.figure()
253+
self._plot_pending = None
251254

252-
def mov2fix(self, fixed, moving, warped, vel=None, title=None, lam=None):
255+
def mov2fix(self, *args, **kwargs):
256+
thread = Thread(target=(lambda: self._mov2fix(*args, **kwargs)))
257+
if self._plot_lock.locked():
258+
self._plot_pending = thread
259+
else:
260+
thread.run()
261+
262+
def _mov2fix(self, fixed, moving, warped, vel=None, title=None, lam=None):
253263
"""Plot registration live"""
254264

265+
self._plot_lock.acquire()
266+
self._plot_pending = None
267+
255268
import time
256-
tic = self._last_plot
257-
toc = time.time()
258-
if toc - tic < 1/self.framerate:
259-
return
260-
self._last_plot = toc
269+
# tic = self._last_plot
270+
# toc = time.time()
271+
# if toc - tic < 1/self.framerate:
272+
# return
261273

262274
import matplotlib.pyplot as plt
263275

@@ -277,7 +289,9 @@ def mov2fix(self, fixed, moving, warped, vel=None, title=None, lam=None):
277289
def rescale2d(x):
278290
if not x.dtype.is_floating_point:
279291
x = x.float()
280-
mn, mx = utils.quantile(x, [0.005, 0.995],
292+
mask = (x == 0).any(dim=list(range(x.ndim-2)))
293+
xmasked = x[~mask]
294+
mn, mx = utils.quantile(xmasked, [0.005, 0.995],
281295
dim=range(-2, 0), bins=1024).unbind(-1)
282296
mx = mx.max(mn + 1e-8)
283297
mn, mx = mn[..., None, None], mx[..., None, None]
@@ -314,6 +328,11 @@ def rescale2d(x):
314328
vel = [vel.square().sum(-1).sqrt()] if vel is not None else []
315329
lam = [rescale2d(lam)] if lam is not None else []
316330

331+
if vel:
332+
vmax = max(v.max().item() for v in vel)
333+
if lam:
334+
lmax = max(x.max().item() for x in lam)
335+
317336
checker = []
318337
for f, w in zip(fixed, warped):
319338
patch = max([s // 8 for s in f.shape])
@@ -340,8 +359,24 @@ def rescale2d(x):
340359

341360
imkwargs = dict(interpolation='nearest')
342361

362+
def bbox2rect(text):
363+
(left, bottom, width, height) = text.get_tightbbox().bounds
364+
(_, _, figwidth, figheight) = text.figure.bbox.bounds
365+
return (left / figwidth, bottom / figheight, width / figwidth, height / figheight)
366+
367+
def add_title_axis(fig, title):
368+
_title = fig.suptitle(title)
369+
ax: plt.Axes = fig.add_axes(bbox2rect(_title))
370+
ax.axis("off")
371+
ax.text(0, 0, title, animated=True, font=_title.get_fontproperties())
372+
fig.suptitle("")
373+
return ax
374+
375+
all_ll = torch.stack(self.all_ll).cpu() if self.all_ll else []
376+
lldata = (range(1, len(all_ll)+1), all_ll)
377+
343378
fig = self.figure
344-
replot = len(getattr(self, 'axes_saved', [])) != (nb_rows - 1) * nb_cols + 1
379+
replot = len(getattr(self, 'axes_saved', [])) != (nb_rows - 1) * nb_cols + 1 + bool(title)
345380
if replot:
346381
fig.clf()
347382

@@ -352,95 +387,144 @@ def rescale2d(x):
352387
ax.imshow(moving[k].cpu(), **imkwargs)
353388
if k == 0:
354389
ax.set_title('moving')
355-
ax.axis('off')
390+
ax.axis('image')
391+
ax.set_xticks([])
392+
ax.set_yticks([])
356393
ax = fig.add_subplot(nb_rows, nb_cols, k * nb_cols + 2)
357394
axes += [ax]
358395
ax.imshow(warped[k].cpu(), **imkwargs)
359396
if k == 0:
360397
ax.set_title('moved')
361-
ax.axis('off')
398+
ax.axis('image')
399+
ax.set_xticks([])
400+
ax.set_yticks([])
362401
ax = fig.add_subplot(nb_rows, nb_cols, k * nb_cols + 3)
363402
axes += [ax]
364403
ax.imshow(checker[k].cpu(), **imkwargs)
365404
if k == 0:
366405
ax.set_title('checker')
367-
ax.axis('off')
406+
ax.axis('image')
407+
ax.set_xticks([])
408+
ax.set_yticks([])
368409
ax = fig.add_subplot(nb_rows, nb_cols, k * nb_cols + 4)
369410
axes += [ax]
370411
ax.imshow(fixed[k].cpu(), **imkwargs)
371412
if k == 0:
372413
ax.set_title('fixed')
373-
ax.axis('off')
414+
ax.axis('image')
415+
ax.set_xticks([])
416+
ax.set_yticks([])
374417
if vel:
375418
ax = fig.add_subplot(nb_rows, nb_cols, k * nb_cols + 5)
376419
axes += [ax]
377-
d = ax.imshow(vel[k].cpu(), **imkwargs)
420+
d = ax.imshow(vel[k].cpu(), **imkwargs, vmin=0, vmax=vmax)
378421
if k == 0:
379422
ax.set_title('displacement')
380-
ax.axis('off')
381-
fig.colorbar(d, None, ax)
423+
ax.axis('image')
424+
ax.set_xticks([])
425+
ax.set_yticks([])
426+
if k == 0:
427+
v0 = d
382428
if lam:
383429
ax = fig.add_subplot(nb_rows, nb_cols, k * nb_cols + 5 + bool(vel))
384430
axes += [ax]
385-
d = ax.imshow(lam[k].cpu(), **imkwargs)
431+
d = ax.imshow(lam[k].cpu(), **imkwargs, vmin=0, vmax=lmax)
386432
if k == 0:
387433
ax.set_title('precision')
388-
ax.axis('off')
389-
fig.colorbar(d, None, ax)
434+
ax.axis('image')
435+
ax.set_xticks([])
436+
ax.set_yticks([])
437+
if k == 0:
438+
l0 = d
439+
440+
if title:
441+
axtitle = add_title_axis(fig, title)
442+
axes += [axtitle]
443+
390444
ax = fig.add_subplot(nb_rows, 1, nb_rows)
391445
axes += [ax]
392-
all_ll = torch.stack(self.all_ll).cpu() if self.all_ll else []
393-
# ax.plot(range(1, len(all_ll)+1), all_ll)
394-
ax.plot([])
446+
ax.plot(*lldata, animated=True)
395447
ax.set_ylabel('NLL')
396448
ax.set_xlabel('iteration')
397-
if title:
398-
fig.suptitle(title)
449+
ax.set_xticks([])
450+
ax.set_yticks([])
451+
452+
if vel or lam:
453+
fig.subplots_adjust(right=0.9)
454+
if vel and not lam:
455+
vbar_ax = fig.add_axes([0.92, 0.15, 0.025, 0.7])
456+
elif lam and not vel:
457+
lbar_ax = fig.add_axes([0.92, 0.15, 0.025, 0.7])
458+
else:
459+
vbar_ax = fig.add_axes([0.92, 0.10, 0.025, 0.35])
460+
lbar_ax = fig.add_axes([0.92, 0.55, 0.025, 0.35])
461+
if vel:
462+
self._vbar = fig.colorbar(v0, cax=vbar_ax)
463+
self._vbar_ax = vbar_ax
464+
if lam:
465+
self._lbar = fig.colorbar(l0, cax=lbar_ax)
466+
self._lbar_ax = lbar_ax
399467

400-
fig.canvas.draw()
401-
self.plt_saved = [fig.canvas.copy_from_bbox(ax.bbox)
402-
for ax in axes]
403-
self.axes_saved = axes
404-
fig.canvas.flush_events()
468+
# fig.canvas.draw()
405469
plt.show(block=False)
470+
self._plt_background = fig.canvas.copy_from_bbox(fig.bbox)
471+
self.axes_saved = axes
406472

407-
lldata = (range(1, len(all_ll)+1), all_ll)
408-
axes[-1].lines[0].set_data(lldata)
409-
axes[-1].draw_artist(axes[-1].lines[0])
410-
fig.canvas.blit(ax.bbox)
411-
fig.canvas.flush_events()
412-
else:
413-
for elem in self.plt_saved:
414-
fig.canvas.restore_region(elem)
473+
# --- Actually plot the stuff by blitting ---
474+
fig.canvas.restore_region(self._plt_background)
475+
476+
for k in range(kdim):
477+
j = k * nb_cols
478+
self.axes_saved[j].images[0].set_data(moving[k].cpu())
479+
self.axes_saved[j+1].images[0].set_data(warped[k].cpu())
480+
self.axes_saved[j+2].images[0].set_data(checker[k].cpu())
481+
self.axes_saved[j+3].images[0].set_data(fixed[k].cpu())
482+
if vel:
483+
self.axes_saved[j+4].images[0].set_data(vel[k].cpu())
484+
self.axes_saved[j+4].images[0].set_clim(0, vmax)
485+
j += 1
486+
if lam:
487+
self.axes_saved[j+4].images[0].set_data(lam[k].cpu())
488+
self.axes_saved[j+4].images[0].set_clim(0, lmax)
489+
490+
self.axes_saved[-1].lines[0].set_data(lldata)
491+
self.axes_saved[-1].relim()
492+
self.axes_saved[-1].autoscale_view()
493+
494+
if title:
495+
self.axes_saved[-2].texts[0].set_text(title)
496+
497+
if vel or lam:
498+
if vel:
499+
self._vbar.update_ticks()
500+
self._vbar.update_normal()
501+
if lam:
502+
self._lbar.update_ticks()
503+
self._lbar.update_normal()
504+
fig.canvas.draw()
415505

416-
for k in range(kdim):
417-
j = k * nb_cols
418-
self.axes_saved[j].images[0].set_data(moving[k].cpu())
419-
self.axes_saved[j+1].images[0].set_data(warped[k].cpu())
420-
self.axes_saved[j+2].images[0].set_data(checker[k].cpu())
421-
self.axes_saved[j+3].images[0].set_data(fixed[k].cpu())
422-
if vel:
423-
self.axes_saved[j+4].images[0].set_data(vel[k].cpu())
424-
j += 1
425-
if lam:
426-
self.axes_saved[j+4].images[0].set_data(lam[k].cpu())
427-
all_ll = torch.stack(self.all_ll).cpu() if self.all_ll else []
428-
lldata = (range(1, len(all_ll)+1), all_ll)
429-
self.axes_saved[-1].lines[0].set_data(lldata)
430-
self.axes_saved[-1].relim()
431-
self.axes_saved[-1].autoscale_view()
432-
if title:
433-
fig._suptitle.set_text(title)
506+
for ax in self.axes_saved:
507+
ax.draw_artist(ax.patch)
508+
if not ax.texts:
509+
for spine in ax.spines.values():
510+
ax.draw_artist(spine)
511+
for elem in (ax.images or []):
512+
ax.draw_artist(elem)
513+
for elem in (ax.lines or []):
514+
ax.draw_artist(elem)
515+
for elem in (ax.texts or []):
516+
ax.draw_artist(elem)
517+
fig.canvas.blit(ax.bbox)
434518

435-
for ax in self.axes_saved:
436-
for elem in (ax.images or []):
437-
ax.draw_artist(elem)
438-
for elem in (ax.lines or []):
439-
ax.draw_artist(elem)
440-
fig.canvas.blit(ax.bbox)
441-
fig.canvas.flush_events()
519+
fig.canvas.flush_events()
442520

443521
self.figure = fig
522+
self._last_plot = time.time()
523+
524+
# Run next ploting thread if there is one
525+
self._plot_lock.release()
526+
if self._plot_pending:
527+
self._plot_pending.run()
444528

445529
def print(self, step, ll, lla=None, llv=None, in_line_search=False):
446530

0 commit comments

Comments
 (0)