diff --git a/src/imreg_dft/imreg.py b/src/imreg_dft/imreg.py index 7c2ee1a..b826aca 100644 --- a/src/imreg_dft/imreg.py +++ b/src/imreg_dft/imreg.py @@ -40,6 +40,7 @@ import math import numpy as np + try: import pyfftw.interfaces.numpy_fft as fft except ImportError: @@ -438,7 +439,7 @@ def _translation(im0, im1, filter_pcorr=0, constraints=None, reports=None): return ret, succ -def _phase_correlation(im0, im1, callback=None, * args): +def _phase_correlation(im0, im1, callback=None, *args): """ Computes phase correlation between im0 and im1 @@ -466,7 +467,7 @@ def _phase_correlation(im0, im1, callback=None, * args): # scps = shifted cps scps = fft.fftshift(cps) - (t0, t1), success = callback(scps, * args) + (t0, t1), success = callback(scps, *args) ret = np.array((t0, t1)) # _compensate_fftshift is not appropriate here, this is OK. @@ -477,7 +478,7 @@ def _phase_correlation(im0, im1, callback=None, * args): return ret, success -def transform_img_dict(img, tdict, bgval=None, order=1, invert=False): +def transform_img_dict(img, tdict, mode="constant", bgval=None, order=1, invert=False): """ Wrapper of :func:`transform_img`, works well with the :func:`similarity` output. @@ -486,6 +487,9 @@ def transform_img_dict(img, tdict, bgval=None, order=1, invert=False): img tdict (dictionary): Transformation dictionary --- supposed to contain keys "scale", "angle" and "tvec" + + mode (string): The transformation mode (refer to e.g. + :func:`scipy.ndimage.shift` and its kwarg ``mode``). bgval order invert (bool): Whether to perform inverse transformation --- doesn't @@ -501,7 +505,7 @@ def transform_img_dict(img, tdict, bgval=None, order=1, invert=False): scale = 1.0 / scale angle *= -1 tvec *= -1 - res = transform_img(img, scale, angle, tvec, bgval=bgval, order=order) + res = transform_img(img, scale, angle, tvec, bgval=bgval, mode=mode, order=order) return res @@ -552,22 +556,18 @@ def transform_img(img, scale=1.0, angle=0.0, tvec=(0, 0), if bgval is None: bgval = utils.get_borderval(img) - bigshape = np.round(np.array(img.shape) * 1.2).astype(int) - bg = np.zeros(bigshape, img.dtype) + bgval + dest0 = utils._to_shape(img.copy(), img.shape, mode=mode, bgval=bgval) - dest0 = utils.embed_to(bg, img.copy()) # TODO: We have problems with complex numbers # that are not supported by zoom(), rotate() or shift() if scale != 1.0: dest0 = ndii.zoom(dest0, scale, order=order, mode=mode, cval=bgval) if angle != 0.0: dest0 = ndii.rotate(dest0, angle, order=order, mode=mode, cval=bgval) - if tvec[0] != 0 or tvec[1] != 0: dest0 = ndii.shift(dest0, tvec, order=order, mode=mode, cval=bgval) - bg = np.zeros_like(img) + bgval - dest = utils.embed_to(bg, dest0) + dest = utils._to_shape(dest0, img.shape, mode=mode, bgval=bgval) return dest @@ -595,6 +595,8 @@ def similarity_matrix(scale, angle, vector): EXCESS_CONST = 1.1 + + def _get_log_base(shape, new_r): """ Basically common functionality of :func:`_logpolar` @@ -709,13 +711,13 @@ def imshow(im0, im1, im2, cmap=None, fig=None, **kwargs): pl0.imshow(im0.real, cmap, **kwargs) pl0.grid() share = dict(sharex=pl0, sharey=pl0) - pl = fig.add_subplot(222, ** share) + pl = fig.add_subplot(222, **share) pl.imshow(im1.real, cmap, **kwargs) pl.grid() - pl = fig.add_subplot(223, ** share) + pl = fig.add_subplot(223, **share) pl.imshow(im3, cmap, **kwargs) pl.grid() - pl = fig.add_subplot(224, ** share) + pl = fig.add_subplot(224, **share) pl.imshow(im2.real, cmap, **kwargs) pl.grid() return fig diff --git a/src/imreg_dft/utils.py b/src/imreg_dft/utils.py index 507679d..a8e9595 100644 --- a/src/imreg_dft/utils.py +++ b/src/imreg_dft/utils.py @@ -37,6 +37,53 @@ import scipy.ndimage as ndi + + +def _get_pads_and_slices(shape_src, shape_dest): + """ returns correct pads, slices for transforming shape_src to shape_dest""" + + diff = tuple(s1 - s2 for s1, s2 in zip(shape_dest, shape_src)) + slices = tuple(slice((-d) // 2, -(((-d) + 1) // 2)) if d < 0 else slice(None, None) for d in diff) + pads = tuple((d // 2, d - (d // 2)) if d > 0 else (0, 0) for d in diff) + return pads, slices + + +def _to_shape(img_src, shape_dest, mode="constant", bgval=0): + """ pad/crops img_src to shape_dest + + + Parameters + ---------- + img_src: ndarray + the input image + shape_dest: tuple + desired output shape + mode: str or function + same as numpy.pad, e.g. "constant", "reflect", ... + + bgval: + value to use when mode== "constant" + + Returns + ------- + padded/cropped image + """ + + if img_src.ndim != 2: + raise ValueError("img 2d (but is %sd )" % img_src.ndim) + + if img_src.ndim != len(shape_dest): + raise ValueError("im and shape should be same dimension (%s != %s)" % (img_src.ndim, len(shape_dest))) + + shape_src = img_src.shape + if shape_src == shape_dest: + return img_src + pads, slices = _get_pads_and_slices(shape_src, shape_dest) + kwargs = {} + kwargs.update({"constant_values": bgval} if mode == "constant" else {}) + return np.pad(img_src, pads, mode=mode, **kwargs)[slices] + + def wrap_angle(angles, ceil=2 * np.pi): """ Args: diff --git a/tests/unittests/transform_modes.py b/tests/unittests/transform_modes.py new file mode 100644 index 0000000..e325cab --- /dev/null +++ b/tests/unittests/transform_modes.py @@ -0,0 +1,45 @@ +from __future__ import print_function, division +import unittest as ut + +import imreg_dft.imreg as imreg +import imreg_dft.utils as utils +import numpy.testing as npt + +from scipy.misc import ascent +import scipy.ndimage.interpolation as ndii + + +class TestTransformModes(ut.TestCase): + def testModes(self): + im = ascent() + tvec = [200, 40] + angle = 42 + + for mode in ["constant", "reflect", "wrap"]: + + out1 = imreg.transform_img(im, tvec=tvec, angle=angle, mode=mode, bgval = 0.) + out2 = utils._to_shape(ndii.shift(ndii.rotate(im, angle, order=1, mode=mode), + tvec, mode=mode, order=1), + out1.shape) + + npt.assert_allclose(out1,out2) + + def testModesDict(self): + im = ascent() + tvec = [-67, 20] + angle = 37 + scale = 1. + + for mode in ["constant", "reflect", "wrap"]: + + tdict = {"tvec":tvec,"angle":angle, "scale":scale} + out1 = imreg.transform_img_dict(im, tdict, mode=mode, bgval = 0.) + out2 = utils._to_shape(ndii.shift(ndii.rotate(im, angle, order=1, mode=mode), + tvec, mode=mode, order=1), + out1.shape) + + npt.assert_allclose(out1,out2) + + +if __name__ == '__main__': + ut.main()