diff --git a/examples/operator/convolution_operator.py b/examples/operator/convolution_operator.py index bf344a183e4..319c357dcc5 100644 --- a/examples/operator/convolution_operator.py +++ b/examples/operator/convolution_operator.py @@ -56,4 +56,4 @@ def adjoint(self): # Display the results using the show method kernel.show('kernel') phantom.show('phantom') -g.show('convolved phantom') +g.show('convolved phantom', force_show=True) diff --git a/examples/operator/convolution_operator_pytorch.py b/examples/operator/convolution_operator_pytorch.py new file mode 100644 index 00000000000..7c7ce21d54a --- /dev/null +++ b/examples/operator/convolution_operator_pytorch.py @@ -0,0 +1,65 @@ +"""Create a convolution operator by wrapping a library.""" + +import odl +import numpy as np +import torch + + +class Convolution(odl.Operator): + """Operator calculating the convolution of a kernel with a function. + + The operator inherits from ``odl.Operator`` to be able to be used with ODL. + """ + + def __init__(self, kernel, domain, range): + """Initialize a convolution operator with a known kernel.""" + + # Store the kernel + self.kernel = kernel + + # Initialize the Operator class by calling its __init__ method. + # This sets properties such as domain and range and allows the other + # operator convenience functions to work. + super(Convolution, self).__init__( + domain=domain, range=range, linear=True) + + def _call(self, x): + """Implement calling the operator by calling PyTorch.""" + return self.range.element(torch.conv2d( input=x.data.unsqueeze(0) + , weight=self.kernel.unsqueeze(0).unsqueeze(0) + , stride=(1,1) + , padding="same" + ).squeeze(0) + ) + + @property + def adjoint(self): + """Implement ``self.adjoint``. + + For a convolution operator, the adjoint is given by the convolution + with a kernel with flipped axes. In particular, if the kernel is + symmetric the operator is self-adjoint. 
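+        (Note: ``torch.conv2d`` computes a cross-correlation rather than a true
+        convolution; with the odd-sized kernel and zero ``'same'`` padding used
+        in ``_call``, flipping both kernel axes still gives the exact adjoint of
+        this operator.)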
+ """ + return Convolution( torch.flip(self.kernel, dims=(0,1)) + , domain=self.range, range=self.domain ) + + +# Define the space on which the problem should be solved +# Here the square [-1, 1] x [-1, 1] discretized on a 100x100 grid +space = odl.uniform_discr([-1, -1], [1, 1], [100, 100], impl='pytorch', dtype=np.float32) + +# Convolution kernel, a small centered rectangle +kernel = torch.ones((5,5)) + +# Create convolution operator +A = Convolution(kernel, domain=space, range=space) + +# Create phantom (the "unknown" solution) +phantom = odl.phantom.shepp_logan(space, modified=True) + +# Apply convolution to phantom to create data +g = A.adjoint(phantom) + +# Display the results using the show method +phantom.show('phantom') +g.show('convolved phantom', force_show=True) diff --git a/examples/solvers/deconvolution_1d.py b/examples/solvers/deconvolution_1d.py index b60d6d23d55..d6ee6bb53b4 100644 --- a/examples/solvers/deconvolution_1d.py +++ b/examples/solvers/deconvolution_1d.py @@ -27,11 +27,13 @@ def opnorm(self): # Discretization -discr_space = odl.uniform_discr(0, 10, 500, impl='numpy') +discr_space = odl.uniform_discr(-5, 5, 500, impl='numpy') # Complicated functions to check performance -kernel = discr_space.element(lambda x: np.exp(x / 2) * np.cos(x * 1.172)) -phantom = discr_space.element(lambda x: x ** 2 * np.sin(x) ** 2 * (x > 5)) +kernel = discr_space.element(lambda x: np.exp(-x**2 * 2) * np.cos(x * 1.172)) + +# phantom = discr_space.element(lambda x: (x+5) ** 2 * np.sin(x+5) ** 2 * (x > 0)) +phantom = discr_space.element(lambda x: np.cos(0*x) * (x > -1) * (x < 1)) # Create operator conv = Convolution(kernel) @@ -41,21 +43,28 @@ def opnorm(self): omega = 1 / conv.opnorm() ** 2 -# Display callback -def callback(x): - plt.plot(conv(x)) - +def test_with_plot(conv, phantom, solver, **extra_args): + fig, axs = plt.subplots(2) + fig.suptitle("CGN") + axs[0].set_title("x") + axs[1].set_title("k*x") + axs[0].plot(phantom) + axs[1].plot(conv(phantom)) + def plot_callback(x): + axs[0].plot(conv(x), '--') + axs[1].plot(conv(x), '--') + solver(conv, discr_space.zero(), phantom, iterations, callback=plot_callback, **extra_args) # Test CGN -plt.figure() -plt.plot(phantom) -odl.solvers.conjugate_gradient_normal(conv, discr_space.zero(), phantom, - iterations, callback) +test_with_plot(conv, phantom, odl.solvers.conjugate_gradient_normal) -# Landweber -plt.figure() -plt.plot(phantom) -odl.solvers.landweber(conv, discr_space.zero(), phantom, - iterations, omega, callback) +# test_with_plot(conv, phantom, odl.solvers.landweber, omega=omega) +# # Landweber +# lw_fig, lw_axs = plt.subplots(1) +# lw_fig.suptitle("Landweber") +# lw_axs.plot(phantom) +# odl.solvers.landweber(conv, discr_space.zero(), phantom, +# iterations, omega, lambda x: lw_axs.plot(conv(x))) +# plt.show() diff --git a/examples/solvers/deconvolution_1d_pytorch.py b/examples/solvers/deconvolution_1d_pytorch.py new file mode 100644 index 00000000000..c517baa2eac --- /dev/null +++ b/examples/solvers/deconvolution_1d_pytorch.py @@ -0,0 +1,91 @@ +"""Example of a deconvolution problem with different solvers (CPU).""" + +import numpy as np +import torch +import matplotlib.pyplot as plt +import scipy.signal +import odl + + +class Convolution(odl.Operator): + def __init__(self, kernel, domain, range, adjkernel=None): + self.kernel = kernel + self.adjkernel = torch.flip(kernel, dims=(0,)) if adjkernel is None else adjkernel + self.norm = float(torch.sum(torch.abs(self.kernel))) + super(Convolution, self).__init__( + domain=domain, 
range=range, linear=True) + + def _call(self, x): + return self.range.element( + torch.conv1d( input=x.data.unsqueeze(0) + , weight=self.kernel.unsqueeze(0).unsqueeze(0) + , stride=1 + , padding="same" + ).squeeze(0) + ) + + @property + def adjoint(self): + return Convolution( self.adjkernel + , domain=self.range, range=self.domain + , adjkernel = self.kernel + ) + + def opnorm(self): + return self.norm + + +resolution = 50 + +# Discretization +discr_space = odl.uniform_discr(-5, 5, resolution*10, impl='pytorch', dtype=np.float32) + +# Complicated functions to check performance +def mk_kernel(): + q = 1.172 + # Select main lobe and one side lobe on each side + r = np.ceil(3*np.pi/(2*q)) + # Quantised to resolution + nr = int(np.ceil(r*resolution)) + r = nr / resolution + x = torch.linspace(-r, r, nr*2 + 1) + return torch.exp(-x**2 * 2) * np.cos(x * q) +kernel = mk_kernel() + +phantom = discr_space.element(lambda x: np.ones_like(x) ** 2 * (x > -1) * (x < 1)) +# phantom = discr_space.element(lambda x: x ** 2 * np.sin(x) ** 2 * (x > 5)) + +# Create operator +conv = Convolution(kernel, domain=discr_space, range=discr_space) + +# Dampening parameter for landweber +iterations = 100 +omega = 1 / conv.opnorm() ** 2 + + + +def test_with_plot(conv, phantom, solver, **extra_args): + fig, axs = plt.subplots(2) + fig.suptitle("CGN") + def plot_fn(ax_id, fn, *plot_args, **plot_kwargs): + axs[ax_id].plot(fn, *plot_args, **plot_kwargs) + axs[0].set_title("x") + axs[1].set_title("k*x") + plot_fn(0, phantom) + plot_fn(1, conv(phantom)) + def plot_callback(x): + plot_fn(0, conv(x), '--') + plot_fn(1, conv(x), '--') + solver(conv, discr_space.zero(), phantom, iterations, callback=plot_callback, **extra_args) + +# Test CGN +test_with_plot(conv, phantom, odl.solvers.conjugate_gradient_normal) + +# # Landweber +# lw_fig, lw_axs = plt.subplots(1) +# lw_fig.suptitle("Landweber") +# lw_axs.plot(phantom) +# odl.solvers.landweber(conv, discr_space.zero(), phantom, +# iterations, omega, lambda x: lw_axs.plot(conv(x))) + +plt.show() diff --git a/examples/solvers/pdhg_denoising.py b/examples/solvers/pdhg_denoising.py index ed2662d3cf9..321891efe3e 100644 --- a/examples/solvers/pdhg_denoising.py +++ b/examples/solvers/pdhg_denoising.py @@ -11,28 +11,29 @@ """ import numpy as np +import torch import scipy.misc import odl +impl = 'numpy' +# impl = 'pytorch' + # Read test image: use only every second pixel, convert integer to float, # and rotate to get the image upright -image = np.rot90(scipy.misc.ascent()[::2, ::2], 3).astype('float') +image = np.rot90(scipy.datasets.ascent()[::2, ::2], 3).astype('float') shape = image.shape # Rescale max to 1 image /= image.max() # Discretized spaces -space = odl.uniform_discr([0, 0], shape, shape) +space = odl.uniform_discr([0, 0], shape, shape, impl=impl) # Original image -orig = space.element(image) +orig = space.element(image.copy()) # Add noise -image += 0.1 * odl.phantom.white_noise(orig.space) - -# Data of noisy image -noisy = space.element(image) +noisy = space.element(image) + 0.1 * odl.phantom.white_noise(orig.space) # Gradient operator gradient = odl.Gradient(space) diff --git a/examples/solvers/pdhg_denoising_pytorch.py b/examples/solvers/pdhg_denoising_pytorch.py new file mode 100644 index 00000000000..a985ca504ab --- /dev/null +++ b/examples/solvers/pdhg_denoising_pytorch.py @@ -0,0 +1,97 @@ +"""Total variation denoising using PDHG. 
+ +Solves the optimization problem + + min_{x >= 0} 1/2 ||x - g||_2^2 + lam || |grad(x)| ||_1 + +Where ``grad`` the spatial gradient and ``g`` is given noisy data. + +For further details and a description of the solution method used, see +https://odlgroup.github.io/odl/guide/pdhg_guide.html in the ODL documentation. +""" + +import numpy as np +import torch +import scipy.misc +import odl +import cProfile + +# Read test image: use only every second pixel, convert integer to float, +# and rotate to get the image upright +image = np.rot90(scipy.datasets.ascent()[::2, ::2], 3).astype('float') +shape = image.shape + +# Rescale max to 1 +image /= image.max() + +# Discretized spaces +space = odl.uniform_discr([0, 0], shape, shape, impl='pytorch') + +# Original image +orig = space.element(image.copy()) + +orig.data.requires_grad = False + +# Add noise +noisy = space.element(image) + 0.1 * odl.phantom.white_noise(orig.space) + +noisy.data.requires_grad = False + +# Gradient operator +gradient = odl.Gradient(space) + +# grad_xmp = gradient(orig) +# grad_xmp.show(title = "Grad-op applied to original") + +# Matrix of operators +op = odl.BroadcastOperator(odl.IdentityOperator(space), gradient) + +# Set up the functionals + +# l2-squared data matching +l2_norm = odl.solvers.L2NormSquared(space).translated(noisy) + +# Isotropic TV-regularization: l1-norm of grad(x) +l1_norm = 0.15 * odl.solvers.L1Norm(gradient.range) + +# Make separable sum of functionals, order must correspond to the operator K +g = odl.solvers.SeparableSum(l2_norm, l1_norm) + +# Non-negativity constraint +f = odl.solvers.IndicatorNonnegativity(op.domain) + +# --- Select solver parameters and solve using PDHG --- # + +# Estimated operator norm, add 10 percent to ensure ||K||_2^2 * sigma * tau < 1 +op_norm = 1.1 * odl.power_method_opnorm(op, xstart=noisy, maxiter=10) + # 3.2833764101732785 +print(f"{op_norm=}") + +niter = 200 # Number of iterations +tau = 1.0 / op_norm # Step size for the primal variable +sigma = 1.0 / op_norm # Step size for the dual variable + +# Optional: pass callback objects to solver +callback = (odl.solvers.CallbackPrintIteration() & + odl.solvers.CallbackShow(step=5)) + +# Starting point +x = op.domain.zero() + +x.data.requires_grad = False + +print("Go solve...") + +# Run algorithm (and display intermediates) +def do_running(): + with torch.no_grad(): + odl.solvers.pdhg(x, f, g, op, niter=niter, tau=tau, sigma=sigma, + callback=callback) + +do_running() +# cProfile.run('do_running()') + +# Display images +orig.show(title='Original Image') +noisy.show(title='Noisy Image') +x.show(title='Reconstruction', force_show=True) diff --git a/odl/discr/diff_ops.py b/odl/discr/diff_ops.py index e7ba9d7f168..23e09678930 100644 --- a/odl/discr/diff_ops.py +++ b/odl/discr/diff_ops.py @@ -11,11 +11,13 @@ from __future__ import absolute_import, division, print_function import numpy as np +import torch +from math import prod from odl.discr.discr_space import DiscretizedSpace from odl.operator.tensor_ops import PointwiseTensorFieldOperator from odl.space import ProductSpace -from odl.util import indent, signature_string, writable_array +from odl.util import indent, signature_string, writable_array, uses_pytorch, dtype_type __all__ = ('PartialDerivative', 'Gradient', 'Divergence', 'Laplacian') @@ -344,20 +346,25 @@ def __init__(self, domain=None, range=None, method='forward', def _call(self, x, out=None): """Calculate the spatial gradient of ``x``.""" - if out is None: - out = self.range.element() - x_arr = x.asarray() ndim = 
self.domain.ndim dx = self.domain.cell_sides - for axis in range(ndim): - with writable_array(out[axis]) as out_arr: - finite_diff(x_arr, axis=axis, dx=dx[axis], method=self.method, + if out is None: + return self.range.element([ + finite_diff(x_arr, axis=axis, dx=dx[axis], method=self.method, pad_mode=self.pad_mode, pad_const=self.pad_const, - out=out_arr) - return out + ) + for axis in range(ndim)]) + else: + for axis in range(ndim): + with writable_array(out[axis]) as out_arr: + finite_diff(x_arr, axis=axis, dx=dx[axis], method=self.method, + pad_mode=self.pad_mode, + pad_const=self.pad_const, + out=out_arr) + return out def derivative(self, point=None): """Return the derivative operator. @@ -554,25 +561,38 @@ def __init__(self, domain=None, range=None, method='forward', def _call(self, x, out=None): """Calculate the divergence of ``x``.""" - if out is None: - out = self.range.element() ndim = self.range.ndim dx = self.range.cell_sides - tmp = np.empty(out.shape, out.dtype, order=out.space.default_order) - with writable_array(out) as out_arr: - for axis in range(ndim): - finite_diff(x[axis], axis=axis, dx=dx[axis], - method=self.method, pad_mode=self.pad_mode, - pad_const=self.pad_const, - out=tmp) - if axis == 0: - out_arr[:] = tmp - else: - out_arr += tmp + torch_impl = uses_pytorch(x[0]) - return out + def directional_derivative(axis, dd_out=None): + return finite_diff( x[axis], axis=axis, dx=dx[axis] + , method=self.method, pad_mode=self.pad_mode + , pad_const=self.pad_const + , out=dd_out ) + + if out is None: + result = directional_derivative(0) + for axis in range(1,len(x)): + result += directional_derivative(axis) + + return self.range.element(result) + + else: + assert(not torch_impl) + + tmp = self.range.element().asarray() + with writable_array(out) as out_arr: + for axis in range(ndim): + directional_derivative(axis, out=tmp) + if axis == 0: + out_arr[:] = tmp + else: + out_arr += tmp + + return out def derivative(self, point=None): """Return the derivative operator. @@ -785,106 +805,10 @@ def __str__(self): return '{}:\n{}'.format(self.__class__.__name__, indent(dom_ran_str)) -def finite_diff(f, axis, dx=1.0, method='forward', out=None, +def _finite_diff_numpy(f_arr, axis, dx=1.0, method='forward', out=None, pad_mode='constant', pad_const=0): - """Calculate the partial derivative of ``f`` along a given ``axis``. - - In the interior of the domain of f, the partial derivative is computed - using first-order accurate forward or backward difference or - second-order accurate central differences. - - With padding the same method and thus accuracy is used on endpoints as - in the interior i.e. forward and backward differences use first-order - accuracy on edges while central differences use second-order accuracy at - edges. - - Without padding one-sided forward or backward differences are used at - the boundaries. The accuracy at the endpoints can then also be - triggered by the edge order. + """ NumPy-specific version of `finite_diff`. """ - The returned array has the same shape as the input array ``f``. - - Per default forward difference with dx=1 and no padding is used. - - Parameters - ---------- - f : `array-like` - An N-dimensional array. - axis : int - The axis along which the partial derivative is evaluated. - dx : float, optional - Scalar specifying the distance between sampling points along ``axis``. - method : {'central', 'forward', 'backward'}, optional - Finite difference method which is used in the interior of the domain - of ``f``. 
- out : `numpy.ndarray`, optional - An N-dimensional array to which the output is written. Has to have - the same shape as the input array ``f``. - pad_mode : string, optional - The padding mode to use outside the domain. - - ``'constant'``: Fill with ``pad_const``. - - ``'symmetric'``: Reflect at the boundaries, not doubling the - outmost values. - - ``'periodic'``: Fill in values from the other side, keeping - the order. - - ``'order0'``: Extend constantly with the outmost values - (ensures continuity). - - ``'order1'``: Extend with constant slope (ensures continuity of - the first derivative). This requires at least 2 values along - each axis where padding is applied. - - ``'order2'``: Extend with second order accuracy (ensures continuity - of the second derivative). This requires at least 3 values along - each axis where padding is applied. - - pad_const : float, optional - For ``pad_mode == 'constant'``, ``f`` assumes ``pad_const`` for - indices outside the domain of ``f`` - - Returns - ------- - out : `numpy.ndarray` - N-dimensional array of the same shape as ``f``. If ``out`` was - provided, the returned object is a reference to it. - - Examples - -------- - >>> f = np.array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) - - >>> finite_diff(f, axis=0) - array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., -9.]) - - Without arguments the above defaults to: - - >>> finite_diff(f, axis=0, dx=1.0, method='forward', pad_mode='constant') - array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., -9.]) - - Parameters can be changed one by one: - - >>> finite_diff(f, axis=0, dx=0.5) - array([ 2., 2., 2., 2., 2., 2., 2., 2., 2., -18.]) - >>> finite_diff(f, axis=0, pad_mode='order1') - array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]) - - Central differences and different edge orders: - - >>> finite_diff(0.5 * f ** 2, axis=0, method='central', pad_mode='order1') - array([ 0.5, 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 8.5]) - >>> finite_diff(0.5 * f ** 2, axis=0, method='central', pad_mode='order2') - array([-0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) - - In-place evaluation: - - >>> out = f.copy() - >>> out is finite_diff(f, axis=0, out=out) - True - """ - f_arr = np.asarray(f) ndim = f_arr.ndim if f_arr.shape[axis] < 2: @@ -909,42 +833,44 @@ def finite_diff(f, axis, dx=1.0, method='forward', out=None, raise ValueError('`pad_mode` {} not understood' ''.format(pad_mode)) - pad_const = f.dtype.type(pad_const) + pad_const = np.array([pad_const], dtype = f_arr.dtype) if out is None: out = np.empty_like(f_arr) else: - if out.shape != f.shape: + if out.shape != f_arr.shape: raise ValueError('expected output shape {}, got {}' ''.format(f.shape, out.shape)) + orig_shape = f_arr.shape - if f_arr.shape[axis] < 2 and pad_mode == 'order1': + if orig_shape[axis] < 2 and pad_mode == 'order1': raise ValueError("size of array to small to use 'order1', needs at " "least 2 elements along axis {}.".format(axis)) - if f_arr.shape[axis] < 3 and pad_mode == 'order2': + if orig_shape[axis] < 3 and pad_mode == 'order2': raise ValueError("size of array to small to use 'order2', needs at " "least 3 elements along axis {}.".format(axis)) - # create slice objects: initially all are [:, :, ..., :] - - # Swap axes so that the axis of interest is first. This is a O(1) - # operation and is done to simplify the code below. + # Swap axes so that the axis of interest is first. In NumPy (but not PyTorch), + # this is a O(1) operation and is done to simplify the code below. 
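+    # (np.swapaxes returns a view sharing memory with its input, so writing into
+    # the swapped `out` below also fills `out_in`, which is ultimately returned.)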
out, out_in = np.swapaxes(out, 0, axis), out f_arr = np.swapaxes(f_arr, 0, axis) + def fd_subtraction(a, b): + np.subtract(a, b, out=out[1:-1]) + # Interior of the domain of f if method == 'central': # 1D equivalent: out[1:-1] = (f[2:] - f[:-2])/2.0 - np.subtract(f_arr[2:], f_arr[:-2], out=out[1:-1]) + fd_subtraction(f_arr[2:], f_arr[:-2]) out[1:-1] /= 2.0 elif method == 'forward': # 1D equivalent: out[1:-1] = (f[2:] - f[1:-1]) - np.subtract(f_arr[2:], f_arr[1:-1], out=out[1:-1]) + fd_subtraction(f_arr[2:], f_arr[1:-1]) elif method == 'backward': # 1D equivalent: out[1:-1] = (f[1:-1] - f[:-2]) - np.subtract(f_arr[1:-1], f_arr[:-2], out=out[1:-1]) + fd_subtraction(f_arr[1:-1], f_arr[:-2]) # Boundaries if pad_mode == 'constant': @@ -1129,6 +1055,194 @@ def finite_diff(f, axis, dx=1.0, method='forward', out=None, return out_in +def _finite_diff_pytorch(f_arr, axis, dx=1.0, method='forward', + pad_mode='constant', pad_const=0): + """ PyTorch-specific version of `finite_diff`. Notice that this has no output argument. """ + + ndim = f_arr.ndim + + if f_arr.shape[axis] < 2: + raise ValueError('in axis {}: at least two elements required, got {}' + ''.format(axis, f_arr.shape[axis])) + + if axis < 0: + axis += ndim + if not (0 <= axis < ndim): + raise IndexError('`axis` {} outside the valid range 0 ... {}' + ''.format(axis, ndim - 1)) + + dx, dx_in = float(dx), dx + if dx <= 0 or not np.isfinite(dx): + raise ValueError("`dx` must be positive, got {}".format(dx_in)) + + method, method_in = str(method).lower(), method + if method not in _SUPPORTED_DIFF_METHODS: + raise ValueError('`method` {} was not understood'.format(method_in)) + + if pad_mode not in _SUPPORTED_PAD_MODES: + raise ValueError('`pad_mode` {} not understood' + ''.format(pad_mode)) + + orig_shape = f_arr.shape + + if orig_shape[axis] < 2 and pad_mode == 'order1': + raise ValueError("size of array to small to use 'order1', needs at " + "least 2 elements along axis {}.".format(axis)) + if orig_shape[axis] < 3 and pad_mode == 'order2': + raise ValueError("size of array to small to use 'order2', needs at " + "least 3 elements along axis {}.".format(axis)) + + # Reshape (in O(1)), so the axis of interest is the pænultimate, all previous + # axes are flattened into the batch dimension, and all subsequent axes flattened + # into the final dimension. This allows a batched 2D convolution of final size 1 + # to perform the differentiation in only the axis of interest. 
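+    # For example (hypothetical shape): an input of shape (3, 4, 5) with axis=1 is
+    # viewed as (3, 1, 4, 5), i.e. batch=3, channels=1, the differentiated axis of
+    # length 4, and a trailing axis of length 5.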
+ f_arr = f_arr.reshape([ prod(orig_shape[:axis]) + , 1 + , orig_shape[axis] + , prod(orig_shape[axis+1:]) + ]) + + dtype = f_arr.dtype + + # Kernel for convolution that expresses the finite-difference operator on, at least, + # the interior of the domain of f + def as_kernel(mat): + return torch.tensor(mat, dtype=dtype, device=f_arr.device) + if method == 'central': + fd_kernel = as_kernel([[[[-1],[0],[1]]]]) / (2*dx) + elif method == 'forward': + fd_kernel = as_kernel([[[[0],[-1],[1]]]]) / dx + elif method == 'backward': + fd_kernel = as_kernel([[[[-1],[1],[0]]]]) / dx + + if pad_mode == 'constant': + if pad_const==0: + result = torch.conv2d(f_arr, fd_kernel, padding='same') + else: + padding_arr = torch.ones_like(f_arr[:,:,0:1,:]) * pad_const + result = torch.conv2d( torch.cat([padding_arr, f_arr, padding_arr], dim=-2) + , fd_kernel, padding='valid' ) + elif pad_mode == 'periodic': + result = torch.conv2d(f_arr, fd_kernel, padding='circular') + + else: + raise NotImplementedError(f'{pad_mode=} not implemented for PyTorch') + + return result.reshape(orig_shape) + + +def finite_diff(f, axis, dx=1.0, method='forward', out=None, + pad_mode='constant', pad_const=0): + """Calculate the partial derivative of ``f`` along a given ``axis``. + + In the interior of the domain of f, the partial derivative is computed + using first-order accurate forward or backward difference or + second-order accurate central differences. + + With padding the same method and thus accuracy is used on endpoints as + in the interior i.e. forward and backward differences use first-order + accuracy on edges while central differences use second-order accuracy at + edges. + + Without padding one-sided forward or backward differences are used at + the boundaries. The accuracy at the endpoints can then also be + triggered by the edge order. + + The returned array has the same shape as the input array ``f``. + + Per default forward difference with dx=1 and no padding is used. + + Parameters + ---------- + f : `array-like` + An N-dimensional array. + axis : int + The axis along which the partial derivative is evaluated. + dx : float, optional + Scalar specifying the distance between sampling points along ``axis``. + method : {'central', 'forward', 'backward'}, optional + Finite difference method which is used in the interior of the domain + of ``f``. + out : `numpy.ndarray`, optional + An N-dimensional array to which the output is written. Has to have + the same shape as the input array ``f``. + pad_mode : string, optional + The padding mode to use outside the domain. + + ``'constant'``: Fill with ``pad_const``. + + ``'symmetric'``: Reflect at the boundaries, not doubling the + outmost values. + + ``'periodic'``: Fill in values from the other side, keeping + the order. + + ``'order0'``: Extend constantly with the outmost values + (ensures continuity). + + ``'order1'``: Extend with constant slope (ensures continuity of + the first derivative). This requires at least 2 values along + each axis where padding is applied. + + ``'order2'``: Extend with second order accuracy (ensures continuity + of the second derivative). This requires at least 3 values along + each axis where padding is applied. + + pad_const : float, optional + For ``pad_mode == 'constant'``, ``f`` assumes ``pad_const`` for + indices outside the domain of ``f`` + + Returns + ------- + out : `numpy.ndarray` + N-dimensional array of the same shape as ``f``. If ``out`` was + provided, the returned object is a reference to it. 
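+
+    Notes
+    -----
+    For elements of a PyTorch-backed space the computation is dispatched to a
+    tensor-based implementation. There, only ``'constant'`` (and, tentatively,
+    ``'periodic'``) padding is handled; other ``pad_mode`` values raise
+    ``NotImplementedError``.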
+ + Examples + -------- + >>> f = np.array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) + + >>> finite_diff(f, axis=0) + array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., -9.]) + + Without arguments the above defaults to: + + >>> finite_diff(f, axis=0, dx=1.0, method='forward', pad_mode='constant') + array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., -9.]) + + Parameters can be changed one by one: + + >>> finite_diff(f, axis=0, dx=0.5) + array([ 2., 2., 2., 2., 2., 2., 2., 2., 2., -18.]) + >>> finite_diff(f, axis=0, pad_mode='order1') + array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]) + + Central differences and different edge orders: + + >>> finite_diff(0.5 * f ** 2, axis=0, method='central', pad_mode='order1') + array([ 0.5, 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 8.5]) + >>> finite_diff(0.5 * f ** 2, axis=0, method='central', pad_mode='order2') + array([-0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) + + In-place evaluation: + + >>> out = f.copy() + >>> out is finite_diff(f, axis=0, out=out) + True + """ + if uses_pytorch(f): + if out is None: + return _finite_diff_pytorch(f.data, axis, dx=dx, method=method, + pad_mode=pad_mode, pad_const=pad_const) + else: + assert(isinstance(out, torch.Tensor)), f"{type(out)=}" + out[:] = _finite_diff_pytorch(f.data, axis, dx=dx, method=method, + pad_mode=pad_mode, pad_const=pad_const) + else: + return _finite_diff_numpy(np.asarray(f), axis, dx=dx, method=method, out=out, + pad_mode=pad_mode, pad_const=pad_const) + + if __name__ == '__main__': from odl.util.testutils import run_doctests diff --git a/odl/discr/discr_space.py b/odl/discr/discr_space.py index 033fb5e0c95..6fab39cacf8 100644 --- a/odl/discr/discr_space.py +++ b/odl/discr/discr_space.py @@ -251,6 +251,18 @@ def available_dtypes(self): """ return self.tspace.available_dtypes() + def is_suitable_scalar(self, s): + """Determine whether `s` has a type that can be scalar-multiplied with + elements of this space. + """ + return self.tspace.is_suitable_scalar(s) + + def as_suitable_scalar(self, s): + """Try to convert `s` to a type that can be scalar-multiplied with + elements of this space. + """ + return self.tspace.as_suitable_scalar(s) + # --- Derived properties @property @@ -778,6 +790,10 @@ def __ipow__(self, p): self.tensor.__ipow__(p) return self + def __rmul__(self, μ): + """Implement ``μ * self``.""" + return self.space.element(μ * self.tensor) + @property def real(self): """Real part of this element. @@ -956,6 +972,22 @@ def __setitem__(self, indices, values): values = values.tensor self.tensor.__setitem__(indices, values) + def __array__(self, dtype=None): + """Return a Numpy array from this tensor. + (Contrast with the `asarray` method, which may give other types of array, + not just NumPy.) + + Parameters + ---------- + dtype : + Specifier for the data type of the output array. + + Returns + ------- + array : `numpy.ndarray` + """ + return self.tensor.__array__() + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): """Interface to Numpy's ufunc machinery. 
@@ -1528,7 +1560,7 @@ def show(self, title=None, method='', coords=None, indices=None, # Squeeze grid and values according to the index expression part = self.space.partition[indices].squeeze() - values = self.asarray()[indices].squeeze() + values = np.array(self)[indices].squeeze() return show_discrete_data(values, part, title=title, method=method, force_show=force_show, fig=fig, @@ -1593,8 +1625,15 @@ def uniform_discr_frompartition(partition, dtype=None, impl='numpy', **kwargs): else: weighting = partition.cell_volume + tensor_impl_args = {} + + if impl=='pytorch': + for arg in ['torch_device']: + if arg in kwargs: + tensor_impl_args[arg] = kwargs.pop(arg) + tspace = tspace_type(partition.shape, dtype, exponent=exponent, - weighting=weighting) + weighting=weighting, **tensor_impl_args) return DiscretizedSpace(partition, tspace, **kwargs) diff --git a/odl/operator/default_ops.py b/odl/operator/default_ops.py index 448da71f2c2..0f1df24f9d9 100644 --- a/odl/operator/default_ops.py +++ b/odl/operator/default_ops.py @@ -320,13 +320,14 @@ def multiplicand(self): def _call(self, x, out=None): """Multiply ``x`` and write to ``out`` if given.""" + μ = x.space.as_suitable_scalar(self.multiplicand) if out is None: - return x * self.multiplicand + return x * μ elif not self.__range_is_field: if self.__domain_is_field: - out.lincomb(x, self.multiplicand) + out.lincomb(x, μ) else: - out.assign(self.multiplicand * x) + out.assign(x * μ) else: raise ValueError('can only use `out` with `LinearSpace` range') diff --git a/odl/operator/oputils.py b/odl/operator/oputils.py index 74f478cc5eb..0cc1b2f8660 100644 --- a/odl/operator/oputils.py +++ b/odl/operator/oputils.py @@ -14,7 +14,7 @@ from future.utils import native from odl.space import ProductSpace from odl.space.base_tensors import TensorSpace -from odl.util import nd_iterator +from odl.util import nd_iterator, uses_pytorch from odl.util.testutils import noise_element __all__ = ( @@ -228,17 +228,27 @@ def calc_opnorm(x_norm): # initial guess of opnorm opnorm = calc_opnorm(x_norm) - # temporary to improve performance - tmp = op.range.element() + if uses_pytorch(x): + calc_in_place = False # In-place updates are not efficient in PyTorch + else: + calc_in_place = True + # temporary to improve performance in NumPy + tmp = op.range.element() # Use the power method to estimate opnorm for i in range(ncalls): if use_normal: - op(x, out=tmp) - op.adjoint(tmp, out=x) + if calc_in_place: + op(x, out=tmp) + op.adjoint(tmp, out=x) + else: + x = op.adjoint(op(x), out=x) else: - op(x, out=tmp) - x, tmp = tmp, x + if calc_in_place: + op(x, out=tmp) + x, tmp = tmp, x + else: + x = op(x) # Calculate x norm and verify it is valid x_norm = x.norm() diff --git a/odl/set/space.py b/odl/set/space.py index b1b4b380b8c..2d10651ab40 100644 --- a/odl/set/space.py +++ b/odl/set/space.py @@ -506,6 +506,18 @@ def __mul__(self, other): return ProductSpace(self, other) + def is_suitable_scalar(self, s): + """Determine whether `s` has a type that can be scalar-multiplied with + elements of this space. + """ + raise NotImplementedError(f'Abstract method not implemented for {type(self)}') + + def as_suitable_scalar(self, s): + """Try to convert `s` to a type that can be scalar-multiplied with + elements of this space. 
+ """ + raise NotImplementedError(f'Abstract method not implemented for {type(self)}') + def __str__(self): """Return ``str(self)``.""" return repr(self) diff --git a/odl/solvers/nonsmooth/primal_dual_hybrid_gradient.py b/odl/solvers/nonsmooth/primal_dual_hybrid_gradient.py index ae7aea3cdd9..9630c5b3b7b 100644 --- a/odl/solvers/nonsmooth/primal_dual_hybrid_gradient.py +++ b/odl/solvers/nonsmooth/primal_dual_hybrid_gradient.py @@ -14,7 +14,7 @@ from __future__ import print_function, division, absolute_import import numpy as np - +from odl.util import uses_pytorch from odl.operator import Operator @@ -263,14 +263,22 @@ def pdhg(x, f, g, L, niter, tau=None, sigma=None, **kwargs): dual_tmp = L.range.element() primal_tmp = L.domain.element() + if uses_pytorch(x): + calc_in_place = False # In-place updates are not efficient in PyTorch + else: + calc_in_place = True + for _ in range(niter): # Copy required for relaxation x_old.assign(x) # Gradient ascent in the dual variable y # Compute dual_tmp = y + sigma * L(x_relax) - L(x_relax, out=dual_tmp) - dual_tmp.lincomb(1, y, sigma, dual_tmp) + if calc_in_place: + L(x_relax, out=dual_tmp) + dual_tmp.lincomb(1, y, sigma, dual_tmp) + else: + dual_tmp = y + sigma*L(x_relax) # Apply the dual proximal if not proximal_constant: @@ -279,13 +287,20 @@ def pdhg(x, f, g, L, niter, tau=None, sigma=None, **kwargs): # Gradient descent in the primal variable x # Compute primal_tmp = x + (- tau) * L.derivative(x).adjoint(y) - L.derivative(x).adjoint(y, out=primal_tmp) - primal_tmp.lincomb(1, x, -tau, primal_tmp) + if calc_in_place: + L.derivative(x).adjoint(y, out=primal_tmp) + primal_tmp.lincomb(1, x, -tau, primal_tmp) + else: + primal_tmp = x - L.derivative(x).adjoint(y)*tau # Apply the primal proximal if not proximal_constant: proximal_primal_tau = proximal_primal(tau) - proximal_primal_tau(primal_tmp, out=x) + + if True or calc_in_place: + proximal_primal_tau(primal_tmp, out=x) + else: + x.assign(proximal_primal_tau(primal_tmp)) # Acceleration if gamma_primal is not None: @@ -299,7 +314,10 @@ def pdhg(x, f, g, L, niter, tau=None, sigma=None, **kwargs): sigma *= theta # Over-relaxation in the primal variable x - x_relax.lincomb(1 + theta, x, -theta, x_old) + if calc_in_place: + x_relax.lincomb(1 + theta, x, -theta, x_old) + else: + x_relax = x*(1+theta) - x_old*theta if callback is not None: callback(x) diff --git a/odl/solvers/nonsmooth/proximal_operators.py b/odl/solvers/nonsmooth/proximal_operators.py index 0d83472ff27..03b0e40e332 100644 --- a/odl/solvers/nonsmooth/proximal_operators.py +++ b/odl/solvers/nonsmooth/proximal_operators.py @@ -391,7 +391,7 @@ def quadratic_perturbation_prox_factory(sigma): return (MultiplyOperator(const, domain=u.space, range=u.space) * prox * (MultiplyOperator(const, domain=u.space, range=u.space) - - sigma * const * u)) + u.space.as_suitable_scalar(sigma * const) * u)) else: space = prox.domain return (MultiplyOperator(const, domain=space, range=space) * diff --git a/odl/space/__init__.py b/odl/space/__init__.py index 59368edebf7..e66e59d6fa0 100644 --- a/odl/space/__init__.py +++ b/odl/space/__init__.py @@ -12,10 +12,12 @@ from . 
import base_tensors, entry_points, weighting from .npy_tensors import * +from .pytorch_tensors import * from .pspace import * from .space_utils import * __all__ = () __all__ += npy_tensors.__all__ +__all__ += pytorch_tensors.__all__ __all__ += pspace.__all__ __all__ += space_utils.__all__ diff --git a/odl/space/base_tensors.py b/odl/space/base_tensors.py index 3396b6d1142..6c4678e1d5f 100644 --- a/odl/space/base_tensors.py +++ b/odl/space/base_tensors.py @@ -22,7 +22,7 @@ array_str, dtype_str, indent, is_complex_floating_dtype, is_floating_dtype, is_numeric_dtype, is_real_dtype, is_real_floating_dtype, safe_int_conv, signature_string, writable_array) -from odl.util.ufuncs import TensorSpaceUfuncs +from odl.util.ufuncs import NumpyTensorSpaceUfuncs, PytorchTensorSpaceUfuncs from odl.util.utility import TYPE_MAP_C2R, TYPE_MAP_R2C, nullcontext __all__ = ('TensorSpace',) @@ -510,19 +510,20 @@ class Tensor(LinearSpaceElement): """Abstract class for representation of `TensorSpace` elements.""" def asarray(self, out=None): - """Extract the data of this tensor as a Numpy array. + """Extract the data of this tensor as an array. This could be a NumPy array + or a PyTorch tensor, depending on what implementation backend is used. This method should be overridden by subclasses. Parameters ---------- - out : `numpy.ndarray`, optional + out : `array_like`, optional Array to write the result to. Returns ------- - asarray : `numpy.ndarray` - Numpy array of the same data type and shape as the space. + asarray : `array_like` + Array of the same type, data type and shape as the space. If ``out`` was given, the returned object is a reference to it. """ @@ -652,6 +653,8 @@ def __bool__(self): def __array__(self, dtype=None): """Return a Numpy array from this tensor. + (Contrast with the `asarray` method, which may give other types of array, + not just NumPy.) Parameters ---------- @@ -663,9 +666,9 @@ def __array__(self, dtype=None): array : `numpy.ndarray` """ if dtype is None: - return self.asarray() + return np.array(self.asarray()) else: - return self.asarray().astype(dtype, copy=AVOID_UNNECESSARY_COPY) + return np.array(self.asarray()).astype(dtype, copy=AVOID_UNNECESSARY_COPY) def __array_wrap__(self, array): """Return a new tensor wrapping the ``array``. @@ -889,7 +892,13 @@ def ufuncs(self): the minimum required version. Use Numpy ufuncs directly, e.g., ``np.sqrt(x)`` instead of ``x.ufuncs.sqrt()``. 
""" - return TensorSpaceUfuncs(self) + if self.impl == "numpy": + return NumpyTensorSpaceUfuncs(self) + elif self.impl == "pytorch": + return PytorchTensorSpaceUfuncs(self) + else: + raise NotImplementedError() + def show(self, title=None, method='', indices=None, force_show=False, fig=None, **kwargs): diff --git a/odl/space/entry_points.py b/odl/space/entry_points.py index fe1fc7644f8..e571869938b 100644 --- a/odl/space/entry_points.py +++ b/odl/space/entry_points.py @@ -23,12 +23,15 @@ from __future__ import print_function, division, absolute_import from odl.space.npy_tensors import NumpyTensorSpace +from odl.space.pytorch_tensors import PytorchTensorSpace # We don't expose anything to odl.space __all__ = () IS_INITIALIZED = False -TENSOR_SPACE_IMPLS = {'numpy': NumpyTensorSpace} +TENSOR_SPACE_IMPLS = {'numpy': NumpyTensorSpace, + 'pytorch': PytorchTensorSpace + } def _initialize_if_needed(): diff --git a/odl/space/npy_tensors.py b/odl/space/npy_tensors.py index c5a497ab231..f456f890cdc 100644 --- a/odl/space/npy_tensors.py +++ b/odl/space/npy_tensors.py @@ -879,6 +879,15 @@ def element_type(self): """Type of elements in this space: `NumpyTensor`.""" return NumpyTensor + def is_suitable_scalar(self, s): + return type(s) is self.dtype.type + + def as_suitable_scalar(self, s): + """Try to convert `s` to a type that can be scalar-multiplied with + numpy arrays. + """ + return self.dtype.type(s) + class NumpyTensor(Tensor): diff --git a/odl/space/pspace.py b/odl/space/pspace.py index 6273e19532a..2131c187e3d 100644 --- a/odl/space/pspace.py +++ b/odl/space/pspace.py @@ -389,6 +389,17 @@ def dtype(self): else: raise AttributeError("`dtype`'s of subspaces not equal") + def is_suitable_scalar(self, s): + return all(space.is_suitable_scalar(s) for space in self.spaces) + + def as_suitable_scalar(self, s): + """Try to convert `s` to a type that can be scalar-multiplied with + elements of this space. + """ + s_sui = self.spaces[0].as_suitable_scalar(s) + assert(self.is_suitable_scalar(s_sui)) + return s_sui + @property def supported_num_operation_paradigms(self) -> NumOperationParadigmSupport: """Whether in-place operations an out-of-place operations are supported @@ -1556,6 +1567,12 @@ def show(self, title=None, indices=None, **kwargs): return tuple(figs) + def __rmul__(self, other): + if self.space.is_suitable_scalar(other): + return self.space.element([other*part for part in self.parts]) + else: + raise TypeError("Only multiplication with suitable scalar supported for product spaces.") + # --- Add arithmetic operators that broadcast --- # diff --git a/odl/space/pytorch_tensors.py b/odl/space/pytorch_tensors.py new file mode 100644 index 00000000000..4aa1f3f8d08 --- /dev/null +++ b/odl/space/pytorch_tensors.py @@ -0,0 +1,1919 @@ +# Copyright 2024 The ODL contributors +# +# This file is part of ODL. +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v. 2.0. If a copy of the MPL was not distributed with this file, You can +# obtain one at https://mozilla.org/MPL/2.0/. 
+ +"""PyTorch implementation of tensor spaces.""" + +from __future__ import absolute_import, division, print_function +from future.utils import native + +import ctypes +from builtins import object +from functools import partial + +import numpy as np +import torch + +from odl.set.sets import ComplexNumbers, RealNumbers +from odl.set.space import (LinearSpaceTypeError, + NumOperationParadigmSupport, SupportedNumOperationParadigms) +from odl.space.base_tensors import Tensor, TensorSpace +from odl.space.weighting import ( + ArrayWeighting, ConstWeighting, CustomDist, CustomInner, CustomNorm, + Weighting) +from odl.util.utility import ArrayOnPytorchManager, _CORRESPONDING_PYTORCH_DTYPES +from odl.util import ( + dtype_str, is_floating_dtype, is_numeric_dtype, is_real_dtype, nullcontext, + signature_string, writable_array) + +__all__ = ('PytorchTensorSpace',) + + +# Define size thresholds to switch implementations +THRESHOLD_SMALL = 100 +THRESHOLD_MEDIUM = 50000 + + +class PytorchTensorSpace(TensorSpace): + + """Set of tensors of arbitrary data type, implemented with Pytorch. + + A tensor is, in the most general sense, a multi-dimensional array + that allows operations per entry (keep the rank constant), + reductions / contractions (reduce the rank) and broadcasting + (raises the rank). + For non-numeric data type like ``object``, the range of valid + operations is rather limited since such a set of tensors does not + define a vector space. + Any numeric data type, on the other hand, is considered valid for + a tensor space, although certain operations - like division with + integer dtype - are not guaranteed to yield reasonable results. + + Under these restrictions, all basic vector space operations are + supported by this class, along with reductions based on arithmetic + or comparison, and element-wise mathematical functions ("ufuncs"). + + This class is implemented using `torch.Tensor`'s as back-end. + + See the `Wikipedia article on tensors`_ for further details. + See also [Hac2012] "Part I Algebraic Tensors" for a rigorous + treatment of tensors with a definition close to this one. + + Note also that this notion of tensors is the same as in popular + Deep Learning frameworks. + + References + ---------- + [Hac2012] Hackbusch, W. *Tensor Spaces and Numerical Tensor Calculus*. + Springer, 2012. + + .. _Wikipedia article on tensors: https://en.wikipedia.org/wiki/Tensor + """ + + def __init__(self, shape, dtype=None, **kwargs): + r"""Initialize a new instance. + + Parameters + ---------- + shape : positive int or sequence of positive ints + Number of entries per axis for elements in this space. A + single integer results in a space with rank 1, i.e., 1 axis. + dtype : + Data type of each element. Can be provided in any + way the `torch.dtype` function understands, e.g. + as built-in type or as a string. For ``None``, + the `default_dtype` of this space (``float64``) is used. + exponent : positive float, optional + Exponent of the norm. For values other than 2.0, no + inner product is defined. + + This option has no impact if either ``dist``, ``norm`` or + ``inner`` is given, or if ``dtype`` is non-numeric. + + Default: 2.0 + + Other Parameters + ---------------- + torch_device : optional, PyTorch device identifier + Where to store and process data (i.e. arrays) representing elements + of this space. Should typically be a GPU (cuda) if available, else + CPU as also used by NumPy. + + weighting : optional + Use weighted inner product, norm, and dist. 
The following + types are supported as ``weighting``: + + ``None``: no weighting, i.e. weighting with ``1.0`` (default). + + `Weighting`: Use this weighting as-is. Compatibility + with this space's elements is not checked during init. + + ``float``: Weighting by a constant. + + array-like: Pointwise weighting by an array. + + This option cannot be combined with ``dist``, + ``norm`` or ``inner``. It also cannot be used in case of + non-numeric ``dtype``. + + dist : callable, optional + Distance function defining a metric on the space. + It must accept two `PytorchTensor` arguments and return + a non-negative real number. See ``Notes`` for + mathematical requirements. + + By default, ``dist(x, y)`` is calculated as ``norm(x - y)``. + + This option cannot be combined with ``weight``, + ``norm`` or ``inner``. It also cannot be used in case of + non-numeric ``dtype``. + + norm : callable, optional + The norm implementation. It must accept a + `PytorchTensor` argument, return a non-negative real number. + See ``Notes`` for mathematical requirements. + + By default, ``norm(x)`` is calculated as ``inner(x, x)``. + + This option cannot be combined with ``weight``, + ``dist`` or ``inner``. It also cannot be used in case of + non-numeric ``dtype``. + + inner : callable, optional + The inner product implementation. It must accept two + `PytorchTensor` arguments and return an element of the field + of the space (usually real or complex number). + See ``Notes`` for mathematical requirements. + + This option cannot be combined with ``weight``, + ``dist`` or ``norm``. It also cannot be used in case of + non-numeric ``dtype``. + + kwargs : + Further keyword arguments are passed to the weighting + classes. + + See Also + -------- + odl.space.space_utils.rn : constructor for real tensor spaces + odl.space.space_utils.cn : constructor for complex tensor spaces + odl.space.space_utils.tensor_space : + constructor for tensor spaces of arbitrary scalar data type + + Notes + ----- + - A distance function or metric on a space :math:`\mathcal{X}` + is a mapping + :math:`d:\mathcal{X} \times \mathcal{X} \to \mathbb{R}` + satisfying the following conditions for all space elements + :math:`x, y, z`: + + * :math:`d(x, y) \geq 0`, + * :math:`d(x, y) = 0 \Leftrightarrow x = y`, + * :math:`d(x, y) = d(y, x)`, + * :math:`d(x, y) \leq d(x, z) + d(z, y)`. + + - A norm on a space :math:`\mathcal{X}` is a mapping + :math:`\| \cdot \|:\mathcal{X} \to \mathbb{R}` + satisfying the following conditions for all + space elements :math:`x, y`: and scalars :math:`s`: + + * :math:`\| x\| \geq 0`, + * :math:`\| x\| = 0 \Leftrightarrow x = 0`, + * :math:`\| sx\| = |s| \cdot \| x \|`, + * :math:`\| x+y\| \leq \| x\| + + \| y\|`. + + - An inner product on a space :math:`\mathcal{X}` over a field + :math:`\mathbb{F} = \mathbb{R}` or :math:`\mathbb{C}` is a + mapping + :math:`\langle\cdot, \cdot\rangle: \mathcal{X} \times + \mathcal{X} \to \mathbb{F}` + satisfying the following conditions for all + space elements :math:`x, y, z`: and scalars :math:`s`: + + * :math:`\langle x, y\rangle = + \overline{\langle y, x\rangle}`, + * :math:`\langle sx + y, z\rangle = s \langle x, z\rangle + + \langle y, z\rangle`, + * :math:`\langle x, x\rangle = 0 \Leftrightarrow x = 0`. 
+ + Examples + -------- + Explicit initialization with the class constructor: + + >>> space = PytorchTensorSpace(3, float) + >>> space + rn(3) + >>> space.shape + (3,) + >>> space.dtype + dtype('float64') + """ + super(PytorchTensorSpace, self).__init__(shape, dtype) + if self.dtype not in self.available_dtypes(): + raise ValueError('`dtype` {!r} not supported' + ''.format(dtype_str(dtype))) + + torch_device = kwargs.pop('torch_device', "cpu") + dist = kwargs.pop('dist', None) + norm = kwargs.pop('norm', None) + inner = kwargs.pop('inner', None) + weighting = kwargs.pop('weighting', None) + exponent = kwargs.pop('exponent', getattr(weighting, 'exponent', 2.0)) + + self._torch_device = torch.device(torch_device) + + if (not is_numeric_dtype(self.dtype) and + any(x is not None for x in (dist, norm, inner, weighting))): + raise ValueError('cannot use any of `weighting`, `dist`, `norm` ' + 'or `inner` for non-numeric `dtype` {}' + ''.format(dtype)) + else: + self._torch_dtype = _CORRESPONDING_PYTORCH_DTYPES[self.dtype] + if exponent != 2.0 and any(x is not None for x in (dist, norm, inner)): + raise ValueError('cannot use any of `dist`, `norm` or `inner` ' + 'for exponent != 2') + # Check validity of option combination (0 or 1 may be provided) + num_extra_args = sum(a is not None + for a in (dist, norm, inner, weighting)) + if num_extra_args > 1: + raise ValueError('invalid combination of options `weighting`, ' + '`dist`, `norm` and `inner`') + + # Set the weighting + if weighting is not None: + if isinstance(weighting, Weighting): + if weighting.impl != 'pytorch': + raise ValueError("`weighting.impl` must be 'pytorch', " + '`got {!r}'.format(weighting.impl)) + if weighting.exponent != exponent: + raise ValueError('`weighting.exponent` conflicts with ' + '`exponent`: {} != {}' + ''.format(weighting.exponent, exponent)) + self.__weighting = weighting + else: + self.__weighting = _weighting(weighting, exponent) + + # Check (afterwards) that the weighting input was sane + if isinstance(self.weighting, PytorchTensorSpaceArrayWeighting): + if self.weighting.array.dtype == object: + raise ValueError('invalid `weighting` argument: {}' + ''.format(weighting)) + elif not np.can_cast(self.weighting.array.dtype, self.dtype): + raise ValueError( + 'cannot cast from `weighting` data type {} to ' + 'the space `dtype` {}' + ''.format(dtype_str(self.weighting.array.dtype), + dtype_str(self.dtype))) + if self.weighting.array.shape != self.shape: + raise ValueError('array-like weights must have same ' + 'shape {} as this space, got {}' + ''.format(self.shape, + self.weighting.array.shape)) + + elif dist is not None: + self.__weighting = PytorchTensorSpaceCustomDist(dist) + elif norm is not None: + self.__weighting = PytorchTensorSpaceCustomNorm(norm) + elif inner is not None: + self.__weighting = PytorchTensorSpaceCustomInner(inner) + else: + # No weighting, i.e., weighting with constant 1.0 + self.__weighting = PytorchTensorSpaceConstWeighting(1.0, exponent) + + self._use_in_place_ops = kwargs.pop('use_in_place_ops', True) + + # Make sure there are no leftover kwargs + if kwargs: + raise TypeError('got unknown keyword arguments {}'.format(kwargs)) + + @property + def impl(self): + """Name of the implementation back-end: ``'pytorch'``.""" + return 'pytorch' + + @property + def supported_num_operation_paradigms(self) -> NumOperationParadigmSupport: + """PyTorch supports both in-place and out of place operations, but the + former are problematic especially when automatic differentiation is + used: PyTorch needs 
to ensure the modification does not interfere with + the backwards pass. This makes the performance much worse than for the + out-of-place style.""" + if self._use_in_place_ops: + return SupportedNumOperationParadigms( + in_place = NumOperationParadigmSupport.SUPPORTED, + out_of_place = NumOperationParadigmSupport.PREFERRED) + else: + return SupportedNumOperationParadigms( + in_place = NumOperationParadigmSupport.NOT_SUPPORTED, + out_of_place = NumOperationParadigmSupport.PREFERRED) + + @property + def default_order(self): + """Default (and only) storage order for new elements in this space: ``'C'``.""" + return 'C' + + def is_suitable_scalar(self, s): + if self._torch_dtype in [torch.complex64, torch.complex128]: + return type(s) is complex + else: + return type(s) is float + # Singleton-tensor version: + # if not isinstance(s, torch.Tensor): + # return False + # elif s.dtype != self._torch_dtype: + # return False + # elif s.shape != (): + # return False + # else: + # return True + + def as_suitable_scalar(self, s): + """Try to convert `s` to a type that can be scalar-multiplied with + torch tensors. + """ + if self._torch_dtype in [torch.complex64, torch.complex128]: + return complex(s) + # Arguably, this would be more appropriate: + # return torch.tensor(complex(s), dtype=self._torch_dtype) + # But this results in wrong PyTorch multiplication functions + # being called. + else: + return float(s) + # return torch.tensor(float(s), dtype=self._torch_dtype) + + @property + def weighting(self): + """This space's weighting scheme.""" + return self.__weighting + + @property + def is_weighted(self): + """Return ``True`` if the space is not weighted by constant 1.0.""" + return not ( + isinstance(self.weighting, PytorchTensorSpaceConstWeighting) and + self.weighting.const == 1.0) + + @property + def exponent(self): + """Exponent of the norm and the distance.""" + return self.weighting.exponent + + def element(self, inp=None, data_ptr=None, order=None): + """Create a new element. + + Parameters + ---------- + inp : `array-like`, optional + Input used to initialize the new element. + + If ``inp`` is `None`, an empty element is created with no + guarantee of its state (memory allocation only). + All tensors use row-major storage (corrsponding to + `order='C'` in NumPy). + + Otherwise, a copy is avoided whenever possible. This requires + correct `shape` and `dtype`, and if ``order`` is provided, + also contiguousness in that ordering. If any of these + conditions is not met, a copy is made. + + data_ptr : int, optional + Pointer to the start memory address of a contiguous PyTorch array + or an equivalent raw container with the same total number of + bytes. + The option is also mutually exclusive with ``inp``. + + Returns + ------- + element : `PytorchTensor` + The new element, created from ``inp`` or from scratch. + + Examples + -------- + Without arguments, an uninitialized element is created. 
With an + array-like input, the element can be initialized: + + >>> space = odl.rn(3) # TODO adapt / test + >>> empty = space.element() + >>> empty.shape + (3,) + >>> empty.space + rn(3) + >>> x = space.element([1, 2, 3]) + >>> x + rn(3).element([ 1., 2., 3.]) + + If the input already is a `torch.Tensor` of correct `dtype`, it + will merely be wrapped, i.e., both array and space element access + the same memory, such that mutations will affect both: + + >>> arr = torch.Tensor([1, 2, 3], dtype=float) # TODO test + >>> elem = odl.rn(3).element(arr) + >>> elem[0] = 0 + >>> elem + rn(3).element([ 0., 2., 3.]) + >>> arr + array([ 0., 2., 3.]) + + Elements can also be constructed from a data pointer, resulting + again in shared memory: + + >>> int_space = odl.tensor_space((2, 3), dtype=int) + >>> arr = torch.Tensor([[1, 2, 3], + ... [4, 5, 6]], dtype=int, order='F') + >>> ptr = arr.ctypes.data + >>> y = int_space.element(data_ptr=ptr, order='F') + >>> y + tensor_space((2, 3), dtype=int).element( + [[1, 2, 3], + [4, 5, 6]] + ) + >>> y[0, 1] = -1 + >>> arr + array([[ 1, -1, 3], + [ 4, 5, 6]]) + """ + if order is not None and str(order).upper() not in ('C'): + raise ValueError(f"Only row-major order supported ('C'), not '{order}'.") + + def wrapped_array(arr): + if arr.shape != self.shape: + raise ValueError('shape of `inp` not equal to space shape: ' + '{} != {}'.format(arr.shape, self.shape)) + return self.element_type(self, arr) + + if inp is None and data_ptr is None: + return wrapped_array(torch.empty( + self.shape, dtype=self._torch_dtype, device=self._torch_device)) + + elif inp is None and data_ptr is not None: + if order is None: + raise ValueError('`order` cannot be None for element ' + 'creation from pointer') + + ctype_array_def = ctypes.c_byte * self.nbytes + as_ctype_array = ctype_array_def.from_address(data_ptr) + as_numpy_array = np.ctypeslib.as_array(as_ctype_array) + arr = as_numpy_array.view(dtype=self._torch_dtype) + arr = arr.reshape(self.shape, order=order) + return wrapped_array(torch.tensor( + arr, dtype=self._torch_dtype, device=self._torch_device)) + + elif inp is not None and data_ptr is None: + if inp in self and order is None: + # Short-circuit for space elements and no enforced ordering + return inp + + # TODO avoid copy when it's not necessary + return wrapped_array(ArrayOnPytorchManager(device=self._torch_device) + .as_compatible_array(inp, dtype=self._torch_dtype)) + + else: + raise TypeError('cannot provide both `inp` and `data_ptr`') + + def zero(self): + """Return a tensor of all zeros. + + Examples + -------- + >>> space = odl.rn(3) # TODO adapt + >>> x = space.zero() + >>> x + rn(3).element([ 0., 0., 0.]) + """ + return self.element(torch.zeros(self.shape, dtype=self._torch_dtype)) + + def one(self): + """Return a tensor of all ones. + + Examples + -------- + >>> space = odl.rn(3) # TODO adapt + >>> x = space.one() + >>> x + rn(3).element([ 1., 1., 1.]) + """ + return self.element(torch.ones(self.shape, dtype=self._torch_dtype)) + + @staticmethod + def available_dtypes(): + """Return the set of data types available in this implementation. + + Notes + ----- + Currently only a conservative selection of the types supported + by Pytorch. + """ + return [np.float16, np.float32, np.float64, + np.complex64, np.complex128] + + @staticmethod + def default_dtype(field=None): + """Return the default data type of this class for a given field. + + Parameters + ---------- + field : `Field`, optional + Set of numbers to be represented by a data type. 
+ Currently supported : `RealNumbers`, `ComplexNumbers` + The default ``None`` means `RealNumbers` + + Returns + ------- + dtype : `torch.dtype` + Pytorch data type specifier. The returned defaults are: + + ``RealNumbers()`` : ``np.dtype('float64')`` + + ``ComplexNumbers()`` : ``np.dtype('complex128')`` + """ + # Note that we're using the NumPy versions of the types, rather + # than the equivalent Pytorch ones. This is for compatibility + # with the rest of ODL, which is not aware of Pytorch. + if field is None or field == RealNumbers(): + return np.float64 + elif field == ComplexNumbers(): + return np.complex128 + else: + raise ValueError('no default data type defined for field {}' + ''.format(field)) + + def _lincomb(self, a, x1, b, x2, out): + """Implement the linear combination of ``x1`` and ``x2``. + + Compute ``out = a*x1 + b*x2`` using optimized + BLAS routines if possible. + + This function is part of the subclassing API. Do not + call it directly. + + Parameters + ---------- + a, b : `TensorSpace.field` element + Scalars to multiply ``x1`` and ``x2`` with. + x1, x2 : `PytorchTensor` + Summands in the linear combination. + out : `PytorchTensor` + Tensor to which the result is written. + + Examples + -------- + >>> space = odl.rn(3) # TODO adapt + >>> x = space.element([0, 1, 1]) + >>> y = space.element([0, 0, 1]) + >>> out = space.element() + >>> result = space.lincomb(1, x, 2, y, out) + >>> result + rn(3).element([ 0., 1., 3.]) + >>> result is out + True + """ + if self._use_in_place_ops and out is not None: + torch.add(input=a*x1.data, other=x2.data, alpha=b, out=out.data) + else: + assert(out is None) + return self.element(a * x1.data + b * x2.data) + + def _dist(self, x1, x2): + """Return the distance between ``x1`` and ``x2``. + + This function is part of the subclassing API. Do not + call it directly. + + Parameters + ---------- + x1, x2 : `PytorchTensor` + Elements whose mutual distance is calculated. + + Returns + ------- + dist : `float` + Distance between the elements. + + Examples + -------- + Different exponents result in difference metrics: + + >>> space_2 = odl.rn(3, exponent=2) + >>> x = space_2.element([-1, -1, 2]) + >>> y = space_2.one() + >>> space_2.dist(x, y) + 3.0 + + >>> space_1 = odl.rn(3, exponent=1) + >>> x = space_1.element([-1, -1, 2]) + >>> y = space_1.one() + >>> space_1.dist(x, y) + 5.0 + + Weighting is supported, too: + + >>> space_1_w = odl.rn(3, exponent=1, weighting=[2, 1, 1]) + >>> x = space_1_w.element([-1, -1, 2]) + >>> y = space_1_w.one() + >>> space_1_w.dist(x, y) + 7.0 + """ + return self.weighting.dist(x1, x2) + + def _norm(self, x): + """Return the norm of ``x``. + + This function is part of the subclassing API. Do not + call it directly. + + Parameters + ---------- + x : `PytorchTensor` + Element whose norm is calculated. + + Returns + ------- + norm : `float` + Norm of the element. + + Examples + -------- + Different exponents result in difference norms: + + >>> space_2 = odl.rn(3, exponent=2) + >>> x = space_2.element([3, 0, 4]) + >>> space_2.norm(x) + 5.0 + >>> space_1 = odl.rn(3, exponent=1) + >>> x = space_1.element([3, 0, 4]) + >>> space_1.norm(x) + 7.0 + + Weighting is supported, too: + + >>> space_1_w = odl.rn(3, exponent=1, weighting=[2, 1, 1]) + >>> x = space_1_w.element([3, 0, 4]) + >>> space_1_w.norm(x) + 10.0 + """ + return self.weighting.norm(x) + + def _inner(self, x1, x2): + """Return the inner product of ``x1`` and ``x2``. + + This function is part of the subclassing API. Do not + call it directly. 
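A standalone sketch of the fused form used by ``_lincomb`` above, reproducing the docstring numbers (uses only stock PyTorch):

    import torch

    a, b = 1.0, 2.0
    x1 = torch.tensor([0., 1., 1.])
    x2 = torch.tensor([0., 0., 1.])

    out = torch.empty(3)
    torch.add(a * x1, x2, alpha=b, out=out)   # out = a*x1 + b*x2
    print(out)                                # tensor([0., 1., 3.])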
+ + Parameters + ---------- + x1, x2 : `PytorchTensor` + Elements whose inner product is calculated. + + Returns + ------- + inner : `field` `element` + Inner product of the elements. + + Examples + -------- + >>> space = odl.rn(3) + >>> x = space.element([1, 0, 3]) + >>> y = space.one() + >>> space.inner(x, y) + 4.0 + + Weighting is supported, too: + + >>> space_w = odl.rn(3, weighting=[2, 1, 1]) + >>> x = space_w.element([1, 0, 3]) + >>> y = space_w.one() + >>> space_w.inner(x, y) + 5.0 + """ + return self.weighting.inner(x1, x2) + + def _multiply(self, x1, x2, out): + """Compute the entry-wise product ``out = x1 * x2``. + + This function is part of the subclassing API. Do not + call it directly. + + Parameters + ---------- + x1, x2 : `PytorchTensor` + Factors in the product. + out : `PytorchTensor` + Element to which the result is written. + + Examples + -------- + >>> space = odl.rn(3) + >>> x = space.element([1, 0, 3]) + >>> y = space.element([-1, 1, -1]) + >>> space.multiply(x, y) + rn(3).element([-1., 0., -3.]) + >>> out = space.element() + >>> result = space.multiply(x, y, out=out) + >>> result + rn(3).element([-1., 0., -3.]) + >>> result is out + True + """ + torch.mul(x1.data, x2.data, out=out.data) + + def _divide(self, x1, x2, out): + """Compute the entry-wise quotient ``x1 / x2``. + + This function is part of the subclassing API. Do not + call it directly. + + Parameters + ---------- + x1, x2 : `PytorchTensor` + Dividend and divisor in the quotient. + out : `PytorchTensor` + Element to which the result is written. + + Examples + -------- + >>> space = odl.rn(3) + >>> x = space.element([2, 0, 4]) + >>> y = space.element([1, 1, 2]) + >>> space.divide(x, y) + rn(3).element([ 2., 0., 2.]) + >>> out = space.element() + >>> result = space.divide(x, y, out=out) + >>> result + rn(3).element([ 2., 0., 2.]) + >>> result is out + True + """ + torch.div(x1.data, x2.data, out=out.data) + + def __eq__(self, other): + """Return ``self == other``. + + Returns + ------- + equals : bool + True if ``other`` is an instance of ``type(self)`` + with the same `PytorchTensorSpace.shape`, `PytorchTensorSpace.dtype` + and `PytorchTensorSpace.weighting`, otherwise False. + + Examples + -------- + >>> space = odl.rn(3) + >>> same_space = odl.rn(3, exponent=2) + >>> same_space == space + True + + Different `shape`, `exponent` or `dtype` all result in different + spaces: + + >>> diff_space = odl.rn((3, 4)) + >>> diff_space == space + False + >>> diff_space = odl.rn(3, exponent=1) + >>> diff_space == space + False + >>> diff_space = odl.rn(3, dtype='float32') + >>> diff_space == space + False + >>> space == object + False + """ + if other is self: + return True + + return (super(PytorchTensorSpace, self).__eq__(other) and + self.weighting == other.weighting) + + def __hash__(self): + """Return ``hash(self)``.""" + return hash((super(PytorchTensorSpace, self).__hash__(), + self.weighting)) + + @property + def byaxis(self): + """Return the subspace defined along one or several dimensions. 
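The weighted inner product from the docstring example above can be reproduced directly with ``torch.dot``; this is essentially what the array weighting defined further down does (sketch, not taken from the patch):

    import torch

    w = torch.tensor([2., 1., 1.])
    x = torch.tensor([1., 0., 3.])
    y = torch.ones(3)

    print(float(torch.dot(w * x, y)))   # 5.0, matching space_w.inner(x, y)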
+ + Examples + -------- + Indexing with integers or slices: + + >>> space = odl.rn((2, 3, 4)) # TODO adapt + >>> space.byaxis[0] + rn(2) + >>> space.byaxis[1:] + rn((3, 4)) + + Lists can be used to stack spaces arbitrarily: + + >>> space.byaxis[[2, 1, 2]] + rn((4, 3, 4)) + """ + space = self + + class PytorchTensorSpacebyaxis(object): + + """Helper class for indexing by axis.""" + + def __getitem__(self, indices): + """Return ``self[indices]``.""" + try: + iter(indices) + except TypeError: + newshape = space.shape[indices] + else: + newshape = tuple(space.shape[i] for i in indices) + + if isinstance(space.weighting, ArrayWeighting): + new_array = np.asarray(space.weighting.array[indices]) + weighting = PytorchTensorSpaceArrayWeighting( + new_array, space.weighting.exponent) + else: + weighting = space.weighting + + return type(space)(newshape, space.dtype, weighting=weighting) + + def __repr__(self): + """Return ``repr(self)``.""" + return repr(space) + '.byaxis' + + return PytorchTensorSpacebyaxis() + + def __repr__(self): + """Return ``repr(self)``.""" + if self.ndim == 1: + posargs = [self.size] + else: + posargs = [self.shape] + + if self.is_real: + ctor_name = 'rn' # TODO adapt + elif self.is_complex: + ctor_name = 'cn' + else: + ctor_name = 'tensor_space' + + if (ctor_name == 'tensor_space' or + not is_numeric_dtype(self.dtype) or + self.dtype != self.default_dtype(self.field)): + optargs = [('dtype', dtype_str(self.dtype), '')] + if self.dtype in (float, complex, int, bool): + optmod = '!s' + else: + optmod = '' + else: + optargs = [] + optmod = '' + + inner_str = signature_string(posargs, optargs, mod=['', optmod]) + weight_str = self.weighting.repr_part + if weight_str: + inner_str += ', ' + weight_str + + return '{}({})'.format(ctor_name, inner_str) + + @property + def element_type(self): + """Type of elements in this space: `PytorchTensor`.""" + return PytorchTensor + + +class PytorchTensor(Tensor): + + """Representation of a `PytorchTensorSpace` element.""" + + def __init__(self, space, data): + """Initialize a new instance.""" + Tensor.__init__(self, space) + self.__data = data + + @property + def data(self): + """The `torch.Tensor` representing the data of ``self``.""" + return self.__data + + def _assign(self, other, avoid_deep_copy): + """Assign the values of ``other``, which is assumed to be in the + same space, to ``self``.""" + if avoid_deep_copy or not self.space._use_in_place_ops: + self.__data = other.__data + else: + self.__data[:] = other.__data + + def asarray(self, out=None): + """Extract the data of this array as a ``torch.Tensor``. + + This method is invoked when calling `torch.tensor` on this + tensor. + + Parameters + ---------- + out : `np.ndarray`, optional + Array in which the result should be written in-place. + Has to be contiguous and of the correct dtype. + + Returns + ------- + asarray : `torch.Tensor` + Pytorch array with the same data type as ``self``. If + ``out`` was given, the returned object is a reference + to it. 
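For orientation on the copy/aliasing behaviour that ``asarray`` and ``copy`` rely on, a small sketch (assuming CPU tensors) of which PyTorch conversions share memory and which do not:

    import torch

    t = torch.tensor([1., 2., 3.])
    a = t.numpy()          # CPU tensors share memory with the returned ndarray
    a[0] = 0.0
    print(t)               # tensor([0., 2., 3.])

    c = t.clone()          # independent deep copy, as PytorchTensor.copy uses
    c[1] = 99.0
    print(t)               # unchanged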
+ + Examples + -------- + >>> space = odl.rn(3, dtype='float32') # TODO adapt + >>> x = space.element([1, 2, 3]) + >>> x.asarray() + array([ 1., 2., 3.], dtype=float32) + >>> np.asarray(x) is x.asarray() + True + >>> out = np.empty(3, dtype='float32') + >>> result = x.asarray(out=out) + >>> out + array([ 1., 2., 3.], dtype=float32) + >>> result is out + True + >>> space = odl.rn((2, 3)) + >>> space.one().asarray() + array([[ 1., 1., 1.], + [ 1., 1., 1.]]) + """ + if out is None: + return self.data + else: + out[:] = self.data + return out + + def astype(self, dtype): + """Return a copy of this element with new ``dtype``. + + Parameters + ---------- + dtype : + Scalar data type of the returned space. Can be provided + in any way the `numpy.dtype` constructor understands, e.g. + as built-in type or as a string. Data types with non-trivial + shapes are not allowed. + + Returns + ------- + newelem : `PytorchTensor` + Version of this element with given data type. + """ + return self.space.astype(dtype).element(self.data.astype(dtype)) + + @property + def data_ptr(self): + """A raw pointer to the data container of ``self``. + + Examples + -------- + >>> import ctypes + >>> space = odl.tensor_space(3, dtype='uint16') # TODO check example + >>> x = space.element([1, 2, 3]) + >>> arr_type = ctypes.c_uint16 * 3 # C type "array of 3 uint16" + >>> buffer = arr_type.from_address(x.data_ptr) + >>> arr = np.frombuffer(buffer, dtype='uint16') + >>> arr + array([1, 2, 3], dtype=uint16) + + In-place modification via pointer: + + >>> arr[0] = 42 # TODO doubtful if this actually works + >>> x + tensor_space(3, dtype='uint16').element([42, 2, 3]) + """ + return self.data.data_ptr() + + def __eq__(self, other): + """Return ``self == other``. + + Returns + ------- + equals : bool + True if all entries of ``other`` are equal to this + the entries of ``self``, False otherwise. + + Examples + -------- + >>> space = odl.rn(3) + >>> x = space.element([1, 2, 3]) + >>> y = space.element([1, 2, 3]) + >>> x == y + True + + >>> y = space.element([-1, 2, 3]) + >>> x == y + False + >>> x == object + False + + Space membership matters: + + >>> space2 = odl.tensor_space(3, dtype='int64') + >>> y = space2.element([1, 2, 3]) + >>> x == y or y == x + False + """ + if other is self: + return True + elif other not in self.space: + return False + else: + return torch.equal(self.data, other.data) + + def copy(self): + """Return an identical (deep) copy of this tensor. + + Parameters + ---------- + None + + Returns + ------- + copy : `PytorchTensor` + The deep copy + + Examples + -------- + >>> space = odl.rn(3) # TODO adapt + >>> x = space.element([1, 2, 3]) + >>> y = x.copy() + >>> y == x + True + >>> y is x + False + """ + return self.space.element(self.data.clone()) + + def __copy__(self): + """Return ``copy(self)``. + + This implements the (shallow) copy interface of the ``copy`` + module of the Python standard library. + + See Also + -------- + copy + + Examples + -------- + >>> from copy import copy + >>> space = odl.rn(3) + >>> x = space.element([1, 2, 3]) + >>> y = copy(x) + >>> y == x + True + >>> y is x + False + """ + return self.copy() + + def __getitem__(self, indices): + """Return ``self[indices]``. + + Parameters + ---------- + indices : index expression + Integer, slice or sequence of these, defining the positions + of the data array which should be accessed. + + Returns + ------- + values : `PytorchTensorSpace.dtype` or `PytorchTensor` + The value(s) at the given indices. 
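Note that ``torch.Tensor`` has no ``astype`` method; the idiomatic counterpart is ``Tensor.to``, so a dtype conversion along the lines of ``astype`` above would presumably look like this (sketch):

    import torch

    x = torch.tensor([1., 2., 3.])        # float32 by default
    x64 = x.to(torch.float64)             # PyTorch's analogue of NumPy's astype
    xc = x.to(torch.complex64)            # real to complex promotion also works
    print(x64.dtype, xc.dtype)            # torch.float64 torch.complex64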
Note that the returned + object is a writable view into the original tensor, except + for the case when ``indices`` is a list. + + Examples + -------- + For one-dimensional spaces, indexing is as in linear arrays: + + >>> space = odl.rn(3) + >>> x = space.element([1, 2, 3]) + >>> x[0] + 1.0 + >>> x[1:] + rn(2).element([ 2., 3.]) + + In higher dimensions, the i-th index expression accesses the + i-th axis: + + >>> space = odl.rn((2, 3)) + >>> x = space.element([[1, 2, 3], + ... [4, 5, 6]]) + >>> x[0, 1] + 2.0 + >>> x[:, 1:] + rn((2, 2)).element( + [[ 2., 3.], + [ 5., 6.]] + ) + + Slices can be assigned to, except if lists are used for indexing: + + >>> y = x[:, ::2] # view into x + >>> y[:] = -9 + >>> x + rn((2, 3)).element( + [[-9., 2., -9.], + [-9., 5., -9.]] + ) + >>> y = x[[0, 1], [1, 2]] # not a view, won't modify x + >>> y + rn(2).element([ 2., -9.]) + >>> y[:] = 0 + >>> x + rn((2, 3)).element( + [[-9., 2., -9.], + [-9., 5., -9.]] + ) + """ + # Lazy implementation: index the array and deal with it + if isinstance(indices, PytorchTensor): + indices = indices.data + arr = self.data[indices] + + if arr.shape == (): # scalar + if self.space.field is not None: + return self.space.field.element(arr) + else: + return arr + else: + if is_numeric_dtype(self.dtype): + weighting = self.space.weighting + else: + weighting = None + space = type(self.space)( + arr.shape, dtype=self.dtype, exponent=self.space.exponent, + weighting=weighting) + return space.element(arr) + + def __setitem__(self, indices, values): + """Implement ``self[indices] = values``. + + Parameters + ---------- + indices : index expression + Integer, slice or sequence of these, defining the positions + of the data array which should be written to. + values : scalar, array-like or `PytorchTensor` + The value(s) that are to be assigned. + + If ``index`` is an integer, ``value`` must be a scalar. + + If ``index`` is a slice or a sequence of slices, ``value`` + must be broadcastable to the shape of the slice. + + Examples + -------- + For 1d spaces, entries can be set with scalars or sequences of + correct shape: + + >>> space = odl.rn(3) + >>> x = space.element([1, 2, 3]) + >>> x[0] = -1 + >>> x[1:] = (0, 1) + >>> x + rn(3).element([-1., 0., 1.]) + + It is also possible to use tensors of other spaces for + casting and assignment: + + >>> space = odl.rn((2, 3)) + >>> x = space.element([[1, 2, 3], + ... [4, 5, 6]]) + >>> x[0, 1] = -1 + >>> x + rn((2, 3)).element( + [[ 1., -1., 3.], + [ 4., 5., 6.]] + ) + >>> short_space = odl.tensor_space((2, 2), dtype='short') + >>> y = short_space.element([[-1, 2], + ... [0, 0]]) + >>> x[:, :2] = y + >>> x + rn((2, 3)).element( + [[-1., 2., 3.], + [ 0., 0., 6.]] + ) + + The Numpy assignment and broadcasting rules apply: + + >>> x[:] = np.array([[0, 0, 0], + ... [1, 1, 1]]) + >>> x + rn((2, 3)).element( + [[ 0., 0., 0.], + [ 1., 1., 1.]] + ) + >>> x[:, 1:] = [7, 8] + >>> x + rn((2, 3)).element( + [[ 0., 7., 8.], + [ 1., 7., 8.]] + ) + >>> x[:, ::2] = -2. + >>> x + rn((2, 3)).element( + [[-2., 7., -2.], + [-2., 7., -2.]] + ) + """ + if isinstance(indices, type(self)): + indices = indices.data + if isinstance(values, type(self)): + values = values.data + + self.data[indices] = values + + def __array__(self, dtype=None): + """Return a Numpy array from this tensor. + + Parameters + ---------- + dtype : + Specifier for the data type of the output array. + + Returns + ------- + array : `numpy.ndarray` + """ + return self.data.cpu().numpy() + + @property + def real(self): + """Real part of ``self``. 
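The view-versus-copy behaviour described here matches plain PyTorch indexing; a self-contained sketch of the docstring example:

    import torch

    x = torch.tensor([[1., 2., 3.],
                      [4., 5., 6.]])

    y = x[:, ::2]            # basic slicing returns a view into x
    y[:] = -9.0
    print(x)                 # [[-9., 2., -9.], [-9., 5., -9.]]

    z = x[[0, 1], [1, 2]]    # advanced (list) indexing returns a copy
    z[:] = 0.0
    print(x)                 # unchanged by the assignment to z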
+ + Returns + ------- + real : `PytorchTensor` + Real part of this element as a member of a + `PytorchTensorSpace` with corresponding real data type. + + Examples + -------- + Get the real part: + + >>> space = odl.cn(3) + >>> x = space.element([1 + 1j, 2, 3 - 3j]) + >>> x.real + rn(3).element([ 1., 2., 3.]) + + Set the real part: + + >>> space = odl.cn(3) + >>> x = space.element([1 + 1j, 2, 3 - 3j]) + >>> zero = odl.rn(3).zero() + >>> x.real = zero + >>> x + cn(3).element([ 0.+1.j, 0.+0.j, 0.-3.j]) + + Other array-like types and broadcasting: + + >>> x.real = 1.0 + >>> x + cn(3).element([ 1.+1.j, 1.+0.j, 1.-3.j]) + >>> x.real = [2, 3, 4] + >>> x + cn(3).element([ 2.+1.j, 3.+0.j, 4.-3.j]) + """ + if self.space.is_real: + return self + elif self.space.is_complex: + real_space = self.space.astype(self.space.real_dtype) + return real_space.element(self.data.real) + else: + raise NotImplementedError('`real` not defined for non-numeric ' + 'dtype {}'.format(self.dtype)) + + @real.setter + def real(self, newreal): + """Setter for the real part. + + This method is invoked by ``x.real = other``. + + Parameters + ---------- + newreal : array-like or scalar + Values to be assigned to the real part of this element. + """ + self.real.data[:] = newreal + + @property + def imag(self): + """Imaginary part of ``self``. + + Returns + ------- + imag : `PytorchTensor` + Imaginary part this element as an element of a + `PytorchTensorSpace` with real data type. + + Examples + -------- + Get the imaginary part: + + >>> space = odl.cn(3) + >>> x = space.element([1 + 1j, 2, 3 - 3j]) + >>> x.imag + rn(3).element([ 1., 0., -3.]) + + Set the imaginary part: + + >>> space = odl.cn(3) + >>> x = space.element([1 + 1j, 2, 3 - 3j]) + >>> zero = odl.rn(3).zero() + >>> x.imag = zero + >>> x + cn(3).element([ 1.+0.j, 2.+0.j, 3.+0.j]) + + Other array-like types and broadcasting: + + >>> x.imag = 1.0 + >>> x + cn(3).element([ 1.+1.j, 2.+1.j, 3.+1.j]) + >>> x.imag = [2, 3, 4] + >>> x + cn(3).element([ 1.+2.j, 2.+3.j, 3.+4.j]) + """ + if self.space.is_real: + return self.space.zero() + elif self.space.is_complex: + real_space = self.space.astype(self.space.real_dtype) + return real_space.element(self.data.imag) + else: + raise NotImplementedError('`imag` not defined for non-numeric ' + 'dtype {}'.format(self.dtype)) + + @imag.setter + def imag(self, newimag): + """Setter for the imaginary part. + + This method is invoked by ``x.imag = other``. + + Parameters + ---------- + newimag : array-like or scalar + Values to be assigned to the imaginary part of this element. + + Raises + ------ + ValueError + If the space is real, i.e., no imagninary part can be set. + """ + if self.space.is_real: + raise ValueError('cannot set imaginary part in real spaces') + self.imag.data[:] = newimag + + def conj(self, out=None): + """Return the complex conjugate of ``self``. + + Parameters + ---------- + out : `PytorchTensor`, optional + Element to which the complex conjugate is written. + Must be an element of ``self.space``. + + Returns + ------- + out : `PytorchTensor` + The complex conjugate element. If ``out`` was provided, + the returned object is a reference to it. 
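The real/imag/conj accessors map onto PyTorch's complex-tensor views; a minimal sketch (assuming a PyTorch build with complex support):

    import torch

    z = torch.tensor([1 + 1j, 2 + 0j, 3 - 3j])
    print(z.real)            # tensor([1., 2., 3.]), a real-valued view of z
    print(z.imag)            # tensor([ 1.,  0., -3.])
    z.real[:] = 0.0          # writing through the view modifies z itself
    print(z)                 # tensor([0.+1.j, 0.+0.j, 0.-3.j])
    print(torch.conj(z))     # conjugate view of z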
+ + Examples + -------- + >>> space = odl.cn(3) + >>> x = space.element([1 + 1j, 2, 3 - 3j]) + >>> x.conj() + cn(3).element([ 1.-1.j, 2.-0.j, 3.+3.j]) + >>> out = space.element() + >>> result = x.conj(out=out) + >>> result + cn(3).element([ 1.-1.j, 2.-0.j, 3.+3.j]) + >>> result is out + True + + In-place conjugation: + + >>> result = x.conj(out=x) + >>> x + cn(3).element([ 1.-1.j, 2.-0.j, 3.+3.j]) + >>> result is x + True + """ + if self.space.is_real: + if out is None: + return self + else: + out[:] = self + return out + + if not is_numeric_dtype(self.space.dtype): + raise NotImplementedError('`conj` not defined for non-numeric ' + 'dtype {}'.format(self.dtype)) + + if out is None: + return self.space.element(self.data.conj()) + else: + if out not in self.space: + raise LinearSpaceTypeError('`out` {!r} not in space {!r}' + ''.format(out, self.space)) + self.data.conj(out.data) + return out + + def __ipow__(self, other): + """Return ``self **= other``.""" + try: + if other == int(other): + return super(PytorchTensor, self).__ipow__(other) + except TypeError: + pass + + torch.pow(self.data, other, out=self.data) + return self + + def __rmul__(self, other): + result = self.space.element(other * self.data) + return result + + def __int__(self): + """Return ``int(self)``.""" + return int(self.data) + + def __long__(self): + """Return ``long(self)``. + + This method is only useful in Python 2. + """ + return long(self.data) + + def __float__(self): + """Return ``float(self)``.""" + return float(self.data) + + def __complex__(self): + """Return ``complex(self)``.""" + if self.size != 1: + raise TypeError('only size-1 tensors can be converted to ' + 'Python scalars') + return complex(self.data.ravel()[0]) + + + + +def _weighting(weights, exponent): + """Return a weighting whose type is inferred from the arguments.""" + if np.isscalar(weights) or weights.shape == (): + weighting = PytorchTensorSpaceConstWeighting(weights, exponent) + elif weights is None: + weighting = PytorchTensorSpaceConstWeighting(1.0, exponent) + else: # last possibility: make an array + arr = torch.tensor(weights) + weighting = PytorchTensorSpaceArrayWeighting(arr, exponent) + return weighting + + +def pytorch_weighted_inner(weights): + """Weighted inner product on `TensorSpace`'s as free function. + + Parameters + ---------- + weights : scalar or `array-like` + Weights of the inner product. A scalar is interpreted as a + constant weight, a 1-dim. array as a weighting vector. + + Returns + ------- + inner : `callable` + Inner product function with given weight. Constant weightings + are applicable to spaces of any size, for arrays the sizes + of the weighting and the space must match. + + See Also + -------- + PytorchTensorSpaceConstWeighting + PytorchTensorSpaceArrayWeighting + """ + return _weighting(weights, exponent=2.0).inner + + +def pytorch_weighted_norm(weights, exponent=2.0): + """Weighted norm on `TensorSpace`'s as free function. + + Parameters + ---------- + weights : scalar or `array-like` + Weights of the norm. A scalar is interpreted as a + constant weight, a 1-dim. array as a weighting vector. + exponent : positive `float` + Exponent of the norm. + + Returns + ------- + norm : `callable` + Norm function with given weight. Constant weightings + are applicable to spaces of any size, for arrays the sizes + of the weighting and the space must match. 
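The diagonally weighted p-norm behind these free functions reduces to elementary tensor operations; a sketch reproducing the weighted 1-norm value 10.0 from the ``_norm`` docstring example:

    import torch

    w = torch.tensor([2., 1., 1.])
    x = torch.tensor([3., 0., 4.])
    p = 1.0

    # sum(w * |x|**p) ** (1/p), as in the diagonal weighting helper below
    print(float(torch.sum(w * torch.abs(x) ** p) ** (1 / p)))   # 10.0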
+ + See Also + -------- + PytorchTensorSpaceConstWeighting + PytorchTensorSpaceArrayWeighting + """ + return _weighting(weights, exponent=exponent).norm + + +def pytorch_weighted_dist(weights, exponent=2.0): + """Weighted distance on `TensorSpace`'s as free function. + + Parameters + ---------- + weights : scalar or `array-like` + Weights of the distance. A scalar is interpreted as a + constant weight, a 1-dim. array as a weighting vector. + exponent : positive `float` + Exponent of the norm. + + Returns + ------- + dist : `callable` + Distance function with given weight. Constant weightings + are applicable to spaces of any size, for arrays the sizes + of the weighting and the space must match. + + See Also + -------- + PytorchTensorSpaceConstWeighting + PytorchTensorSpaceArrayWeighting + """ + return _weighting(weights, exponent=exponent).dist + + +def _norm_default(x): + """Default Euclidean norm implementation.""" + + return x.data.norm(p=2) + + +def _pnorm_default(x, p): + """Default p-norm implementation.""" + return x.data.norm(p=p) + + +def _pnorm_diagweight(x, p, w): + """Diagonally weighted p-norm implementation.""" + xp = torch.abs(x.data) + if p == float('inf'): + xp *= w + return torch.max(xp) + else: + torch.pow(xp, p, out=xp) + xp *= w + return torch.sum(xp) ** (1 / p) + + +def _inner_default(x1, x2): + """Default Euclidean inner product implementation.""" + + if is_real_dtype(x1.dtype): + return torch.dot(x1.data, x2.data) + else: + # x2 as first argument because we want linearity in x1 + return torch.vdot(x2.data, x1.data) + + + +class PytorchTensorSpaceArrayWeighting(ArrayWeighting): + + """Weighting of a `PytorchTensorSpace` by an array. + + This class defines a weighting by an array that has the same shape + as the tensor space. Since the space is not known to this class, + no checks of shape or data type are performed. + See ``Notes`` for mathematical details. + """ + + def __init__(self, array, exponent=2.0): + r"""Initialize a new instance. + + Parameters + ---------- + array : `array-like`, one-dim. + Weighting array of the inner product, norm and distance. + All its entries must be positive, however this is not + verified during initialization. + exponent : positive `float` + Exponent of the norm. For values other than 2.0, no inner + product is defined. + + Notes + ----- + - For exponent 2.0, a new weighted inner product with array + :math:`W` is defined as + + .. math:: + \langle A, B\rangle_W := + \langle W \odot A, B\rangle = + \langle w \odot a, b\rangle = + b^{\mathrm{H}} (w \odot a), + + where :math:`a, b, w` are the "flattened" counterparts of + tensors :math:`A, B, W`, respectively, :math:`b^{\mathrm{H}}` + stands for transposed complex conjugate and :math:`w \odot a` + for element-wise multiplication. + + - For other exponents, only norm and dist are defined. In the + case of exponent :math:`\infty`, the weighted norm is + + .. math:: + \| A\|_{W, \infty} := + \| W \odot A\|_{\infty} = + \| w \odot a\|_{\infty}, + + otherwise it is (using point-wise exponentiation) + + .. math:: + \| A\|_{W, p} := + \| W^{1/p} \odot A\|_{p} = + \| w^{1/p} \odot a\|_{\infty}. + + - Note that this definition does **not** fulfill the limit + property in :math:`p`, i.e. + + .. math:: + \| A\|_{W, p} \not\to + \| A\|_{W, \infty} \quad (p \to \infty) + + unless all weights are equal to 1. + + - The array :math:`W` may only have positive entries, otherwise + it does not define an inner product or norm, respectively. This + is not checked during initialization. 
+ """ + if isinstance(array, PytorchTensor): + array = array.data + elif not isinstance(array, torch.Tensor): + array = torch.tensor(array) + super(PytorchTensorSpaceArrayWeighting, self).__init__( + array, impl='pytorch', exponent=exponent) + + def __hash__(self): + """Return ``hash(self)``.""" + return hash((type(self), hash(self.array), self.exponent)) + + def inner(self, x1, x2): + """Return the weighted inner product of ``x1`` and ``x2``. + + Parameters + ---------- + x1, x2 : `PytorchTensor` + Tensors whose inner product is calculated. + + Returns + ------- + inner : float or complex + The inner product of the two provided vectors. + """ + if self.exponent != 2.0: + raise NotImplementedError('no inner product defined for ' + 'exponent != 2 (got {})' + ''.format(self.exponent)) + else: + inner = _inner_default(x1 * self.array, x2) + if is_real_dtype(x1.dtype): + return float(inner) + else: + return complex(inner) + + def norm(self, x): + """Return the weighted norm of ``x``. + + Parameters + ---------- + x : `PytorchTensor` + Tensor whose norm is calculated. + + Returns + ------- + norm : float + The norm of the provided tensor. + """ + if self.exponent == 2.0: + norm_squared = self.inner(x, x).real # TODO: optimize?! + if norm_squared < 0: + norm_squared = 0.0 # Compensate for numerical error + return float(np.sqrt(norm_squared)) + else: + return float(_pnorm_diagweight(x, self.exponent, self.array)) + + +class PytorchTensorSpaceConstWeighting(ConstWeighting): + + """Weighting of a `PytorchTensorSpace` by a constant. + + See ``Notes`` for mathematical details. + """ + + def __init__(self, const, exponent=2.0): + r"""Initialize a new instance. + + Parameters + ---------- + const : positive float + Weighting constant of the inner product, norm and distance. + exponent : positive float + Exponent of the norm. For values other than 2.0, the inner + product is not defined. + + Notes + ----- + - For exponent 2.0, a new weighted inner product with constant + :math:`c` is defined as + + .. math:: + \langle a, b\rangle_c := + c \, \langle a, b\rangle_c = + c \, b^{\mathrm{H}} a, + + where :math:`b^{\mathrm{H}}` standing for transposed complex + conjugate. + + - For other exponents, only norm and dist are defined. In the + case of exponent :math:`\infty`, the weighted norm is defined + as + + .. math:: + \| a \|_{c, \infty} := + c\, \| a \|_{\infty}, + + otherwise it is + + .. math:: + \| a \|_{c, p} := + c^{1/p}\, \| a \|_{p}. + + - Note that this definition does **not** fulfill the limit + property in :math:`p`, i.e. + + .. math:: + \| a\|_{c, p} \not\to + \| a \|_{c, \infty} \quad (p \to \infty) + + unless :math:`c = 1`. + + - The constant must be positive, otherwise it does not define an + inner product or norm, respectively. + """ + super(PytorchTensorSpaceConstWeighting, self).__init__( + const, impl='pytorch', exponent=exponent) + + def inner(self, x1, x2): + """Return the weighted inner product of ``x1`` and ``x2``. + + Parameters + ---------- + x1, x2 : `PytorchTensor` + Tensors whose inner product is calculated. + + Returns + ------- + inner : float or complex + The inner product of the two provided tensors. + """ + if self.exponent != 2.0: + raise NotImplementedError('no inner product defined for ' + 'exponent != 2 (got {})' + ''.format(self.exponent)) + else: + inner = self.const * _inner_default(x1, x2) + if x1.space.field is None: + return inner + else: + return x1.space.field.element(inner) + + def norm(self, x): + """Return the weighted norm of ``x``. 
+ + Parameters + ---------- + x1 : `PytorchTensor` + Tensor whose norm is calculated. + + Returns + ------- + norm : float + The norm of the tensor. + """ + if self.exponent == 2.0: + return float(np.sqrt(self.const) * _norm_default(x)) + elif self.exponent == float('inf'): + return float(self.const * _pnorm_default(x, self.exponent)) + else: + return float((self.const ** (1 / self.exponent) * + _pnorm_default(x, self.exponent))) + + def dist(self, x1, x2): + """Return the weighted distance between ``x1`` and ``x2``. + + Parameters + ---------- + x1, x2 : `PytorchTensor` + Tensors whose mutual distance is calculated. + + Returns + ------- + dist : float + The distance between the tensors. + """ + if self.exponent == 2.0: + return float(np.sqrt(self.const) * _norm_default(x1 - x2)) + elif self.exponent == float('inf'): + return float(self.const * _pnorm_default(x1 - x2, self.exponent)) + else: + return float((self.const ** (1 / self.exponent) * + _pnorm_default(x1 - x2, self.exponent))) + + +class PytorchTensorSpaceCustomInner(CustomInner): + + """Class for handling a user-specified inner product.""" + + def __init__(self, inner): + """Initialize a new instance. + + Parameters + ---------- + inner : `callable` + The inner product implementation. It must accept two + `Tensor` arguments, return an element from their space's + field (real or complex number) and satisfy the following + conditions for all vectors ``x, y, z`` and scalars ``s``: + + - `` = conj()`` + - `` = s * + `` + - `` = 0`` if and only if ``x = 0`` + """ + super(PytorchTensorSpaceCustomInner, self).__init__(inner, impl='pytorch') + + +class PytorchTensorSpaceCustomNorm(CustomNorm): + + """Class for handling a user-specified norm. + + Note that this removes ``inner``. + """ + + def __init__(self, norm): + """Initialize a new instance. + + Parameters + ---------- + norm : `callable` + The norm implementation. It must accept a `Tensor` + argument, return a `float` and satisfy the following + conditions for all any two elements ``x, y`` and scalars + ``s``: + + - ``||x|| >= 0`` + - ``||x|| = 0`` if and only if ``x = 0`` + - ``||s * x|| = |s| * ||x||`` + - ``||x + y|| <= ||x|| + ||y||`` + """ + super(PytorchTensorSpaceCustomNorm, self).__init__(norm, impl='pytorch') + + +class PytorchTensorSpaceCustomDist(CustomDist): + + """Class for handling a user-specified distance in `TensorSpace`. + + Note that this removes ``inner`` and ``norm``. + """ + + def __init__(self, dist): + """Initialize a new instance. + + Parameters + ---------- + dist : `callable` + The distance function defining a metric on `TensorSpace`. 
It + must accept two `Tensor` arguments, return a `float` and + fulfill the following mathematical conditions for any three + elements ``x, y, z``: + + - ``dist(x, y) >= 0`` + - ``dist(x, y) = 0`` if and only if ``x = y`` + - ``dist(x, y) = dist(y, x)`` + - ``dist(x, y) <= dist(x, z) + dist(z, y)`` + """ + super(PytorchTensorSpaceCustomDist, self).__init__(dist, impl='pytorch') + + +if __name__ == '__main__': + from odl.util.testutils import run_doctests + run_doctests() diff --git a/odl/test/space/tensors_test.py b/odl/test/space/tensors_test.py index e722d29303e..645f0fdb58d 100644 --- a/odl/test/space/tensors_test.py +++ b/odl/test/space/tensors_test.py @@ -22,6 +22,10 @@ NumpyTensor, NumpyTensorSpace, NumpyTensorSpaceArrayWeighting, NumpyTensorSpaceConstWeighting, NumpyTensorSpaceCustomDist, NumpyTensorSpaceCustomInner, NumpyTensorSpaceCustomNorm) +from odl.space.pytorch_tensors import ( + PytorchTensor, PytorchTensorSpace, PytorchTensorSpaceArrayWeighting, + PytorchTensorSpaceConstWeighting, PytorchTensorSpaceCustomDist, + PytorchTensorSpaceCustomInner, PytorchTensorSpaceCustomNorm) from odl.util.testutils import ( all_almost_equal, all_equal, noise_array, noise_element, noise_elements, simple_fixture) @@ -72,6 +76,19 @@ def _weighting_cls(impl, kind): return NumpyTensorSpaceCustomDist else: assert False + elif impl == 'pytorch': + if kind == 'array': + return PytorchTensorSpaceArrayWeighting + elif kind == 'const': + return PytorchTensorSpaceConstWeighting + elif kind == 'inner': + return PytorchTensorSpaceCustomInner + elif kind == 'norm': + return PytorchTensorSpaceCustomNorm + elif kind == 'dist': + return PytorchTensorSpaceCustomDist + else: + assert False else: assert False diff --git a/odl/test/trafos/fourier_test.py b/odl/test/trafos/fourier_test.py index 45bf60c5993..9bdec0f175a 100644 --- a/odl/test/trafos/fourier_test.py +++ b/odl/test/trafos/fourier_test.py @@ -29,6 +29,7 @@ impl = simple_fixture( 'impl', [pytest.param('numpy'), + pytest.param('pytorch'), pytest.param('pyfftw', marks=skip_if_no_pyfftw)] ) exponent = simple_fixture('exponent', [2.0, 1.0, float('inf'), 1.5]) @@ -45,8 +46,10 @@ def _params_from_dtype(dtype): halfcomplex = False return halfcomplex, complex_dtype(dtype) +def _dft_domain_impl(impl): + return 'pytorch' if impl=='pytorch' else 'numpy' -def _dft_space(shape, dtype='float64'): +def _dft_space(shape, dtype='float64', impl='numpy'): try: ndim = len(shape) except TypeError: @@ -57,6 +60,7 @@ def _dft_space(shape, dtype='float64'): shape, dtype=dtype, nodes_on_bdry=True, + impl = _dft_domain_impl(impl) ) @@ -71,12 +75,12 @@ def sinc(x): def test_dft_init(impl): # Just check if the code runs at all shape = (4, 5) - dom = _dft_space(shape) - dom_nonseq = odl.uniform_discr([0, 0], [1, 1], shape) + dom = _dft_space(shape, impl=impl) + dom_nonseq = odl.uniform_discr([0, 0], [1, 1], shape, impl=impl) dom_f32 = dom.astype('float32') - ran = _dft_space(shape, dtype='complex128') + ran = _dft_space(shape, dtype='complex128', impl=impl) ran_c64 = ran.astype('complex64') - ran_hc = _dft_space((3, 5), dtype='complex128') + ran_hc = _dft_space((3, 5), dtype='complex128', impl=impl) # Implicit range DiscreteFourierTransform(dom, impl=impl) @@ -191,10 +195,10 @@ def test_idft_init(impl): # Just check if the code runs at all; this uses the init function of # DiscreteFourierTransform, so we don't need exhaustive tests here shape = (4, 5) - ran = _dft_space(shape, dtype='complex128') - ran_hc = _dft_space(shape, dtype='float64') - dom = _dft_space(shape, 
dtype='complex128') - dom_hc = _dft_space((3, 5), dtype='complex128') + ran = _dft_space(shape, dtype='complex128', impl=impl) + ran_hc = _dft_space(shape, dtype='float64', impl=impl) + dom = _dft_space(shape, dtype='complex128', impl=impl) + dom_hc = _dft_space((3, 5), dtype='complex128', impl=impl) # Implicit range DiscreteFourierTransformInverse(dom, impl=impl) @@ -209,7 +213,7 @@ def test_dft_call(impl): # 2d, complex, all ones and random back & forth shape = (4, 5) - dft_dom = _dft_space(shape, dtype='complex64') + dft_dom = _dft_space(shape, dtype='complex64', impl=impl) dft = DiscreteFourierTransform(domain=dft_dom, impl=impl) idft = DiscreteFourierTransformInverse(range=dft_dom, impl=impl) @@ -243,7 +247,7 @@ def test_dft_call(impl): # 2d, halfcomplex, first axis shape = (4, 5) axes = 0 - dft_dom = _dft_space(shape, dtype='float32') + dft_dom = _dft_space(shape, dtype='float32', impl=impl) dft = DiscreteFourierTransform(domain=dft_dom, impl=impl, halfcomplex=True, axes=axes) idft = DiscreteFourierTransformInverse(range=dft_dom, impl=impl, @@ -276,7 +280,7 @@ def test_dft_sign(impl): # 2d, complex, all ones and random back & forth shape = (4, 5) - dft_dom = _dft_space(shape, dtype='complex64') + dft_dom = _dft_space(shape, dtype='complex64', impl=impl) dft_minus = DiscreteFourierTransform(domain=dft_dom, impl=impl, sign='-') dft_plus = DiscreteFourierTransform(domain=dft_dom, impl=impl, sign='+') @@ -297,7 +301,7 @@ def test_dft_sign(impl): # 2d, halfcomplex, first axis shape = (4, 5) axes = (0,) - dft_dom = _dft_space(shape, dtype='float32') + dft_dom = _dft_space(shape, dtype='float32', impl=impl) arr = dft_dom.element([[0, 0, 0, 0, 0], [0, 0, 1, 1, 0], [0, 0, 1, 1, 0], @@ -321,7 +325,7 @@ def test_dft_init_plan(impl): # 2d, halfcomplex, first axis shape = (4, 5) axes = 0 - dft_dom = _dft_space(shape, dtype='float32') + dft_dom = _dft_space(shape, dtype='float32', impl=impl) dft = DiscreteFourierTransform(dft_dom, impl=impl, axes=axes, halfcomplex=True) @@ -395,11 +399,14 @@ def test_fourier_trafo_init_plan(impl, odl_floating_dtype): # Not supported, skip if dtype == np.dtype('float16') and impl == 'pyfftw': return + elif (dtype in [np.dtype('float128'), np.dtype('complex256')] + and impl == 'pytorch'): + return shape = 10 halfcomplex, _ = _params_from_dtype(dtype) - space_discr = odl.uniform_discr(0, 1, shape, dtype=dtype) + space_discr = odl.uniform_discr(0, 1, shape, dtype=dtype, impl=_dft_domain_impl(impl)) ft = FourierTransform(space_discr, impl=impl, halfcomplex=halfcomplex) if impl != 'pyfftw': @@ -474,10 +481,13 @@ def test_fourier_trafo_call(impl, odl_floating_dtype): # Not supported, skip if dtype == np.dtype('float16') and impl == 'pyfftw': return + elif (dtype in [np.dtype('float16'), np.dtype('float128'), np.dtype('complex256')] + and impl == 'pytorch'): + return shape = 10 halfcomplex, _ = _params_from_dtype(dtype) - space_discr = odl.uniform_discr(0, 1, shape, dtype=dtype) + space_discr = odl.uniform_discr(0, 1, shape, dtype=dtype, impl=_dft_domain_impl(impl)) ft = FourierTransform(space_discr, impl=impl, halfcomplex=halfcomplex) ift = ft.inverse @@ -548,7 +558,7 @@ def test_fourier_trafo_sign(impl, odl_real_floating_dtype): def char_interval(x): return (x >= 0) & (x <= 1) - discr = odl.uniform_discr(-2, 2, 40, impl='numpy', dtype=discrspace_dtype) + discr = odl.uniform_discr(-2, 2, 40, impl=_dft_domain_impl(impl), dtype=discrspace_dtype) ft_minus = FourierTransform(discr, sign='-', impl=impl) ft_plus = FourierTransform(discr, sign='+', impl=impl) @@ -593,7 +603,7 
@@ def char_interval(x): return (x >= 0) & (x <= 1) # Complex-to-complex - discr = odl.uniform_discr(-2, 2, 40, impl='numpy', dtype='complex64') + discr = odl.uniform_discr(-2, 2, 40, impl=_dft_domain_impl(impl), dtype='complex64') discr_char = discr.element(char_interval) ft = FourierTransform(discr, sign=sign, impl=impl) @@ -601,7 +611,7 @@ def char_interval(x): assert all_almost_equal(ft.adjoint(ft(char_interval)), discr_char) # Half-complex - discr = odl.uniform_discr(-2, 2, 40, impl='numpy', dtype='float32') + discr = odl.uniform_discr(-2, 2, 40, impl=_dft_domain_impl(impl), dtype='float32') ft = FourierTransform(discr, impl=impl, halfcomplex=True) assert all_almost_equal(ft.inverse(ft(char_interval)), discr_char) @@ -609,7 +619,7 @@ def char_rect(x): return (x[0] >= 0) & (x[0] <= 1) & (x[1] >= 0) & (x[1] <= 1) # 2D with axes, C2C - discr = odl.uniform_discr([-2, -2], [2, 2], (20, 10), impl='numpy', + discr = odl.uniform_discr([-2, -2], [2, 2], (20, 10), impl=_dft_domain_impl(impl), dtype='complex64') discr_rect = discr.element(char_rect) @@ -619,7 +629,7 @@ def char_rect(x): assert all_almost_equal(ft.adjoint(ft(char_rect)), discr_rect) # 2D with axes, halfcomplex - discr = odl.uniform_discr([-2, -2], [2, 2], (20, 10), impl='numpy', + discr = odl.uniform_discr([-2, -2], [2, 2], (20, 10), impl=_dft_domain_impl(impl), dtype='float32') discr_rect = discr.element(char_rect) diff --git a/odl/trafos/fourier.py b/odl/trafos/fourier.py index 15424f402f7..36a77ad0a39 100644 --- a/odl/trafos/fourier.py +++ b/odl/trafos/fourier.py @@ -11,6 +11,7 @@ from __future__ import absolute_import, division, print_function import numpy as np +import torch from odl.util.npy_compat import AVOID_UNNECESSARY_COPY @@ -26,15 +27,32 @@ complex_dtype, conj_exponent, dtype_repr, is_complex_floating_dtype, is_real_dtype, normalized_axes_tuple, normalized_scalar_param_list) +from typing import Optional + __all__ = ('DiscreteFourierTransform', 'DiscreteFourierTransformInverse', 'FourierTransform', 'FourierTransformInverse') -_SUPPORTED_FOURIER_IMPLS = ('numpy',) -_DEFAULT_FOURIER_IMPL = 'numpy' +_SUPPORTED_FOURIER_IMPLS = {'numpy': ('numpy',), 'pytorch': ('pytorch',)} +_DEFAULT_FOURIER_IMPL = {'numpy': 'numpy', 'pytorch': 'pytorch'} if PYFFTW_AVAILABLE: - _SUPPORTED_FOURIER_IMPLS += ('pyfftw',) - _DEFAULT_FOURIER_IMPL = 'pyfftw' + _SUPPORTED_FOURIER_IMPLS['numpy'] += ('pyfftw',) + _DEFAULT_FOURIER_IMPL['numpy'] = 'pyfftw' + + +def _select_fft_impl(impl_suggestion: Optional[str], domain_impl: str): + if impl_suggestion is None: + impl = _DEFAULT_FOURIER_IMPL.get(domain_impl) + if impl is None: + raise ValueError("There is no default FFT implementation for" + + " tensors with {domain_impl} implementation.") + else: + impl = impl_suggestion + impl, impl_in = str(impl).lower(), impl + if impl not in _SUPPORTED_FOURIER_IMPLS.get(domain_impl): + raise ValueError(f"`impl` '{impl_in}' not supported for" + + f" tensors with {domain_impl} implementation.") + return impl class DiscreteFourierTransformBase(Operator): @@ -91,12 +109,7 @@ def __init__(self, inverse, domain, range=None, axes=None, sign='-', ''.format(range)) # Implementation - if impl is None: - impl = _DEFAULT_FOURIER_IMPL - impl, impl_in = str(impl).lower(), impl - if impl not in _SUPPORTED_FOURIER_IMPLS: - raise ValueError("`impl` '{}' not supported".format(impl_in)) - self.__impl = impl + self.__impl = _select_fft_impl(impl, domain.impl) # Axes if axes is None: @@ -125,12 +138,13 @@ def __init__(self, inverse, domain, range=None, axes=None, sign='-', 
domain.grid, shift=False, halfcomplex=halfcomplex, axes=axes).shape if range is None: - impl = domain.tspace.impl + domain_impl = domain.tspace.impl shape = np.atleast_1d(ran_shape) range = uniform_discr( - [0] * len(shape), shape - 1, shape, ran_dtype, impl, - nodes_on_bdry=True, exponent=conj_exponent(domain.exponent)) + [0] * len(shape), shape - 1, shape, ran_dtype, + nodes_on_bdry=True, exponent=conj_exponent(domain.exponent), + impl=domain.impl) else: if range.shape != ran_shape: @@ -171,10 +185,15 @@ def _call(self, x, out, **kwargs): Call pyfftw backend directly """ # TODO: Implement zero padding - if self.impl == 'numpy': - out[:] = self._call_numpy(x.asarray()) - else: - out[:] = self._call_pyfftw(x.asarray(), out.asarray(), **kwargs) + match self.impl: + case 'numpy': + out[:] = self._call_numpy(x.asarray()) + case 'pyfftw': + out[:] = self._call_pyfftw(x.asarray(), out.asarray(), **kwargs) + case 'pytorch': + out[:] = self._call_pytorch(x.data) + case _: + raise NotImplementedError(self.impl) @property def impl(self): @@ -234,6 +253,21 @@ def _call_numpy(self, x): """ raise NotImplementedError('abstract method') + def _call_pytorch(self, x): + """Return ``self(x)`` for PyTorch back-end. + + Parameters + ---------- + x : `torch.Tensor` + Array representing the function to be transformed + + Returns + ------- + out : `torch.Tensor` + Result of the transform + """ + raise NotImplementedError(f'abstract method, not implemented on {type(self)}.') + def _call_pyfftw(self, x, out, **kwargs): """Implement ``self(x[, out, **kwargs])`` using pyfftw. @@ -467,6 +501,25 @@ def _call_numpy(self, x): return (np.prod(np.take(self.domain.shape, self.axes)) * np.fft.ifftn(x, axes=self.axes)) + def _call_pytorch(self, x): + """Return ``self(x)`` using PyTorch. + + See Also + -------- + DiscreteFourierTransformBase._call_pytorch + """ + assert isinstance(x, torch.Tensor) + + if self.halfcomplex: + return torch.fft.rfftn(x, dim=self.axes) + else: + if self.sign == '-': + return torch.fft.fftn(x, dim=self.axes) + else: + # Need to undo IFFT scaling + return (np.prod(np.take(self.domain.shape, self.axes)) + * torch.fft.ifftn(x, dim=self.axes)) + def _call_pyfftw(self, x, out, **kwargs): """Implement ``self(x[, out, **kwargs])`` using pyfftw. @@ -623,6 +676,28 @@ def _call_numpy(self, x): return (np.fft.fftn(x, axes=self.axes) / np.prod(np.take(self.domain.shape, self.axes))) + def _call_pytorch(self, x): + """Return ``self(x)`` using PyTorch. + + Parameters + ---------- + x : `torch.Tensor` + Input array to be transformed + + Returns + ------- + out : `torch.Tensor` + Result of the transform + """ + if self.halfcomplex: + return torch.fft.irfftn(x, dim=self.axes) + else: + if self.sign == '+': + return torch.fft.ifftn(x, dim=self.axes) + else: + return (torch.fft.fftn(x, dim=self.axes) + / np.prod(np.take(self.domain.shape, self.axes))) + def _call_pyfftw(self, x, out, **kwargs): """Implement ``self(x[, out, **kwargs])`` using pyfftw. @@ -739,9 +814,12 @@ def __init__(self, inverse, domain, range=None, impl=None, **kwargs): is determined from ``domain`` and the other parameters. The exponent is chosen to be the conjugate ``p / (p - 1)``, which reads as 'inf' for p=1 and 1 for p='inf'. - impl : {'numpy', 'pyfftw'}, optional - Backend for the FFT implementation. The 'pyfftw' backend - is faster but requires the ``pyfftw`` package. + impl : {'numpy', 'pyfftw', 'pytorch'}, optional + Backend for the FFT implementation. NumPy is slow but always + supported. 
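The rescaling in the ``sign='+'`` branches above compensates for PyTorch's default FFT normalization, which matches NumPy's: ``fftn`` is unnormalized while ``ifftn`` divides by the number of transformed elements. A quick check:

    import torch

    x = torch.ones(4, 5, dtype=torch.complex64)
    print(float(torch.fft.fftn(x)[0, 0].real))    # 20.0  (no normalization)
    print(float(torch.fft.ifftn(x)[0, 0].real))   # 1.0   (divided by 4*5)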
The 'pyfftw' backend is faster but requires the + ``pyfftw`` package. + 'pytorch' requires ``domain`` to be based on PyTorch tensors, + in which case this is the fastest option particularly on GPU. ``None`` selects the fastest available backend. axes : int or sequence of ints, optional Dimensions along which to take the transform. @@ -804,22 +882,13 @@ def __init__(self, inverse, domain, range=None, impl=None, **kwargs): if not isinstance(domain, DiscretizedSpace): raise TypeError('domain {!r} is not a `DiscretizedSpace` instance' ''.format(domain)) - if domain.impl != 'numpy': - raise NotImplementedError( - 'Only Numpy-based data spaces are supported, got {}' - ''.format(domain.tspace)) # axes axes = kwargs.pop('axes', np.arange(domain.ndim)) self.__axes = normalized_axes_tuple(axes, domain.ndim) # Implementation - if impl is None: - impl = _DEFAULT_FOURIER_IMPL - impl, impl_in = str(impl).lower(), impl - if impl not in _SUPPORTED_FOURIER_IMPLS: - raise ValueError("`impl` '{}' not supported".format(impl_in)) - self.__impl = impl + self.__impl = _select_fft_impl(impl, domain.impl) # Handle half-complex yes/no and shifts halfcomplex = kwargs.pop('halfcomplex', True) @@ -905,11 +974,16 @@ def _call(self, x, out, **kwargs): Call pyfftw backend directly """ # TODO: Implement zero padding - if self.impl == 'numpy': - out[:] = self._call_numpy(x.asarray()) - else: - # 0-overhead assignment if asarray() does not copy - out[:] = self._call_pyfftw(x.asarray(), out.asarray(), **kwargs) + match self.impl: + case 'numpy': + out[:] = self._call_numpy(x.asarray()) + case 'pyfftw': + # 0-overhead assignment if asarray() does not copy + out[:] = self._call_pyfftw(x.asarray(), out.asarray(), **kwargs) + case 'pytorch': + out[:] = self._call_pytorch(x.asarray(), **kwargs) + case _: + raise NotImplementedError(self.impl) def _call_numpy(self, x): """Return ``self(x)`` for numpy back-end. @@ -926,6 +1000,21 @@ def _call_numpy(self, x): """ raise NotImplementedError('abstract method') + def _call_pytorch(self, x): + """Return ``self(x)`` for PyTorch back-end. + + Parameters + ---------- + x : `torch.Tensor` + Array representing the function to be transformed + + Returns + ------- + out : `torch.Tensor` + Result of the transform + """ + raise NotImplementedError(f'abstract method, not implemented on {type(self)}.') + def _call_pyfftw(self, x, out, **kwargs): """Implement ``self(x[, out, **kwargs])`` for pyfftw back-end. @@ -1296,7 +1385,7 @@ def _postprocess(self, x, out=None): # TODO(kohr-h): Add `interp` to operator or simplify it by not # performing interpolation filter return dft_postprocess_data( - out, real_grid=self.domain.grid, recip_grid=self.range.grid, + x, real_grid=self.domain.grid, recip_grid=self.range.grid, shift=self.shifts, axes=self.axes, sign=self.sign, interp='nearest', op='multiply', out=out) @@ -1338,6 +1427,42 @@ def _call_numpy(self, x): self._postprocess(out, out=out) return out + def _call_pytorch(self, x): + """Return ``self(x)`` for PyTorch back-end. 
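For the halfcomplex branches, ``torch.fft.rfftn`` halves the last transformed axis, and the inverse needs the original size ``s`` to recover odd lengths, which is why the inverse transform below passes ``s`` explicitly. A sketch:

    import torch

    x = torch.randn(4, 5)
    xf = torch.fft.rfftn(x, dim=(0, 1))
    print(xf.shape)                                  # torch.Size([4, 3])

    xr = torch.fft.irfftn(xf, dim=(0, 1), s=(4, 5))  # recovers the odd length 5
    print(torch.allclose(xr, x, atol=1e-6))          # True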
+ + Parameters + ---------- + x : `torch.Tensor` + Array representing the function to be transformed + + Returns + ------- + out : `torch.Tensor` + Result of the transform + """ + + preproc = self._preprocess(x) + + # The actual call to the FFT library + if self.halfcomplex: + out = torch.fft.rfftn(preproc, dim=self.axes) + else: + if self.sign == '-': + out = torch.fft.fftn(preproc, dim=self.axes) + else: + out = torch.fft.ifftn(preproc, dim=self.axes) + # Numpy's FFT normalizes by 1 / prod(shape[axes]), we + # need to undo that + # TODO(Justus) select PyTorch normalization mode so this + # is unnecessary + out *= float(np.prod(np.take(self.domain.shape, self.axes))) + + # Post-processing accounting for shift, scaling and interpolation + assert(isinstance(out, torch.Tensor)) + out = self._postprocess(out, out=out) + assert(isinstance(out, torch.Tensor)) + return out + def _call_pyfftw(self, x, out, **kwargs): """Implement ``self(x[, out, **kwargs])`` for pyfftw back-end. @@ -1533,7 +1658,7 @@ def _postprocess(self, x, out=None): The result is stored in ``out`` if given, otherwise in a temporary or a new array. """ - if out is None: + if out is None and self.impl!='pytorch': if self.range.field == ComplexNumbers(): out = self._tmp_r if self._tmp_r is not None else self._tmp_f elif self.range.field == RealNumbers() and not self.halfcomplex: @@ -1583,6 +1708,46 @@ def _call_numpy(self, x): else: return out + def _call_pytorch(self, x): + """Return ``self(x)`` for numpy back-end. + + Parameters + ---------- + x : `torch.Tensor` + Array representing the function to be transformed + + Returns + ------- + out : `torch.Tensor` + Result of the transform + """ + # Pre-processing before calculating the DFT + preproc = self._preprocess(x) + + # The actual call to the FFT library + # Normalization by 1 / prod(shape[axes]) is done by Numpy's FFT if + # one of the "i" functions is used. For sign='-' we need to do it + # ourselves. + if self.halfcomplex: + s = tuple(np.asarray(self.range.shape)[list(self.axes)]) + out = torch.fft.irfftn(preproc, dim=self.axes, s=s) + else: + if self.sign == '-': + out = torch.fft.fftn(preproc, dim=self.axes) + out /= np.prod(np.take(self.domain.shape, self.axes)) + else: + out = torch.fft.ifftn(preproc, dim=self.axes) + + # Post-processing in IFT = pre-processing in FT (in-place) + out = self._postprocess(out) + if self.halfcomplex: + assert is_real_dtype(out.dtype) + + if self.range.field == RealNumbers(): + return out.real + else: + return out + def _call_pyfftw(self, x, out, **kwargs): """Implement ``self(x[, out, **kwargs])`` for pyfftw back-end. diff --git a/odl/trafos/util/ft_utils.py b/odl/trafos/util/ft_utils.py index d4d4f65dcbf..956288eacf6 100644 --- a/odl/trafos/util/ft_utils.py +++ b/odl/trafos/util/ft_utils.py @@ -11,6 +11,7 @@ from __future__ import absolute_import, division, print_function import numpy as np +import torch from odl.util.npy_compat import AVOID_UNNECESSARY_COPY @@ -18,8 +19,10 @@ DiscretizedSpace, uniform_discr_frompartition, uniform_grid, uniform_partition_fromgrid) from odl.set import RealNumbers +from odl.space.base_tensors import Tensor from odl.util import ( complex_dtype, conj_exponent, dtype_repr, fast_1d_tensor_mult, + uses_pytorch, compatible_array_manager, is_complex_floating_dtype, is_numeric_dtype, is_real_dtype, is_real_floating_dtype, is_string, normalized_axes_tuple, normalized_scalar_param_list) @@ -296,7 +299,20 @@ def dft_preprocess_data(arr, shift=True, axes=None, sign='-', out=None): type and ``shift`` is not ``True``. 
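Regarding the pre-processing done by ``dft_preprocess_data``: for ``shift=True`` the one-dimensional factor reduces to ``(-1)**k``, which relocates the zero frequency to the centre of the transformed axis. A sketch of that identity for a single even-length axis (it does not cover the general phase factors the helper also handles):

    import torch

    x = torch.randn(8, dtype=torch.complex64)
    k = torch.arange(8)

    pre = x * (-1.0) ** k
    print(torch.allclose(torch.fft.fft(pre),
                         torch.fft.fftshift(torch.fft.fft(x))))   # True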
In this case, the return type is the complex counterpart of ``arr.dtype``. """ - arr = np.asarray(arr) + + use_pytorch = uses_pytorch(arr) + array_mgr = compatible_array_manager(arr) + + if use_pytorch: + assert(out is None or isinstance(out, torch.Tensor)), f"{type(out)=}" + else: + if hasattr(arr, 'impl'): + assert(arr.impl=='numpy'), f"{arr.impl=}" + else: + assert(isinstance(arr, np.ndarray)), f"{type(arr)=}" + assert(out is None or isinstance(out, np.ndarray)), f"{type(out)=}" + + arr = array_mgr.as_compatible_array(arr) if not is_numeric_dtype(arr.dtype): raise ValueError('array has non-numeric data type {}' ''.format(dtype_repr(arr.dtype))) @@ -320,7 +336,7 @@ def dft_preprocess_data(arr, shift=True, axes=None, sign='-', out=None): if is_real_dtype(arr.dtype) and not all(shift_list): out = np.array(arr, dtype=complex_dtype(arr.dtype), copy=True) else: - out = arr.copy() + out = array_mgr.make_copy(arr) else: out[:] = arr @@ -338,13 +354,13 @@ def dft_preprocess_data(arr, shift=True, axes=None, sign='-', out=None): def _onedim_arr(length, shift): if shift: # (-1)^indices - factor = np.ones(length, dtype=out.dtype) + factor = array_mgr.compatible_ones(length, dtype=out.dtype) factor[1::2] = -1 else: - factor = np.arange(length, dtype=out.dtype) + factor = array_mgr.as_compatible_array(np.arange(length), dtype=out.dtype) factor *= -imag * np.pi * (1 - 1.0 / length) np.exp(factor, out=factor) - return factor.astype(out.dtype, copy=AVOID_UNNECESSARY_COPY) + return array_mgr.select_dtype(factor, out.dtype, copy=AVOID_UNNECESSARY_COPY) onedim_arrs = [] for axis, shift in zip(axes, shift_list): @@ -460,7 +476,23 @@ def dft_postprocess_data(arr, real_grid, recip_grid, shift, axes, *Numerical Recipes in C - The Art of Scientific Computing* (Volume 3). Cambridge University Press, 2007. 
""" - arr = np.asarray(arr) + + use_pytorch = uses_pytorch(arr) + array_mgr = compatible_array_manager(arr) + + if use_pytorch: + assert(out is None or isinstance(out, torch.Tensor)), f"{type(out)=}" + assert(arr.dtype in [torch.float32, torch.float64, torch.complex64, torch.complex128]) + else: + if hasattr(arr, 'impl'): + assert(arr.impl=='numpy'), f"{arr.impl=}" + else: + assert(isinstance(arr, np.ndarray)), f"{type(arr)=}" + assert(out is None or isinstance(out, np.ndarray)), f"{type(out)=}" + assert(arr.dtype in map(np.dtype, ['float32', 'float64', 'float128', + 'complex64', 'complex128', 'complex256'])) + + arr = array_mgr.as_compatible_array(arr) if is_real_floating_dtype(arr.dtype): arr = arr.astype(complex_dtype(arr.dtype)) elif not is_complex_floating_dtype(arr.dtype): @@ -468,7 +500,7 @@ def dft_postprocess_data(arr, real_grid, recip_grid, shift, axes, 'data type'.format(dtype_repr(arr.dtype))) if out is None: - out = arr.copy() + out = array_mgr.make_copy(arr) elif out is not arr: out[:] = arr @@ -542,7 +574,7 @@ def dft_postprocess_data(arr, real_grid, recip_grid, shift, axes, else: onedim_arr /= interp_kernel - onedim_arrs.append(onedim_arr.astype(out.dtype, copy=AVOID_UNNECESSARY_COPY)) + onedim_arrs.append(array_mgr.as_compatible_array(onedim_arr, dtype=out.dtype)) fast_1d_tensor_mult(out, onedim_arrs, axes=axes, out=out) return out @@ -618,7 +650,7 @@ def reciprocal_space(space, axes=None, halfcomplex=False, shift=True, raise ValueError('{} is not a complex data type' ''.format(dtype_repr(dtype))) - impl = kwargs.pop('impl', 'numpy') + impl = kwargs.pop('impl', space.impl) # Calculate range recip_grid = reciprocal_grid(space.grid, shift=shift, diff --git a/odl/util/numerics.py b/odl/util/numerics.py index d5f59fbb67b..b2cdb24fe85 100644 --- a/odl/util/numerics.py +++ b/odl/util/numerics.py @@ -11,6 +11,8 @@ from __future__ import absolute_import, division, print_function import numpy as np +import torch +from odl.util.utility import is_castable_to, uses_pytorch, compatible_array_manager from odl.util.normalize import normalized_scalar_param_list, safe_int_conv __all__ = ( @@ -208,6 +210,8 @@ def fast_1d_tensor_mult(ndarr, onedim_arrs, axes=None, out=None): The advantage of this approach is that it is memory-friendly and loops over the big array only twice. + TODO update documentation WRT PyTorch + Parameters ---------- ndarr : `array-like` @@ -227,10 +231,18 @@ def fast_1d_tensor_mult(ndarr, onedim_arrs, axes=None, out=None): Result of the modification. If ``out`` was given, the returned object is a reference to it. 
""" + use_pytorch = uses_pytorch(ndarr) or uses_pytorch(out) + array_mgr = compatible_array_manager(ndarr) + if out is None: - out = np.array(ndarr, copy=True) - else: + if use_pytorch: + out = torch.Tensor(ndarr, copy=True) + else: + out = np.array(ndarr, copy=True) + elif type(out)==type(ndarr): out[:] = ndarr # Self-assignment is free if out is ndarr + else: + raise TypeError(f"{type(ndarr)=} should be the same as {type(out)=}") if not onedim_arrs: raise ValueError('no 1d arrays given') @@ -251,14 +263,17 @@ def fast_1d_tensor_mult(ndarr, onedim_arrs, axes=None, out=None): raise ValueError('`axes` {} out of bounds for {} dimensions' ''.format(axes_in, out.ndim)) + atleast_1d = torch.atleast_1d if use_pytorch else np.atleast_1d + + # Make scalars 1d arrays and squeezable arrays 1d + alist = [atleast_1d(array_mgr.as_compatible_array(a).squeeze()) for a in onedim_arrs] # Make scalars 1d arrays and squeezable arrays 1d - alist = [np.atleast_1d(np.asarray(a).squeeze()) for a in onedim_arrs] if any(a.ndim != 1 for a in alist): raise ValueError('only 1d arrays allowed') if len(axes) < out.ndim: # Make big factor array (start with 0d) - factor = np.array(1.0) + factor = array_mgr.as_compatible_array([1.0]) for ax, arr in zip(axes, alist): # Meshgrid-style slice slc = [None] * out.ndim @@ -272,11 +287,11 @@ def fast_1d_tensor_mult(ndarr, onedim_arrs, axes=None, out=None): # Get the axis to spare for the final multiplication, the one # with the largest stride. - last_ax = np.argmax(out.strides) + last_ax = out.ndim-1 if use_pytorch else np.argmax(out.strides) last_arr = alist[axes.index(last_ax)] # Build the semi-big array and multiply - factor = np.array(1.0) + factor = array_mgr.as_compatible_array([1.0]) for ax, arr in zip(axes, alist): if ax == last_ax: continue @@ -421,16 +436,31 @@ def resize_array(arr, newshp, offset=None, pad_mode='constant', pad_const=0, except TypeError: raise TypeError('`newshp` must be a sequence, got {!r}'.format(newshp)) + if isinstance(arr, np.ndarray): + impl = 'numpy' + elif isinstance(arr, torch.Tensor): + impl = 'pytorch' + else: + raise TypeError(f"Unknown how to resize array (?) 
of type {type(arr)}.") + if out is not None: - if not isinstance(out, np.ndarray): + if impl=='numpy' and not isinstance(out, np.ndarray): raise TypeError('`out` must be a `numpy.ndarray` instance, got ' '{!r}'.format(out)) + elif impl=='pytorch' and not isinstance(out, torch.Tensor): + raise TypeError('`out` must be a `torch.Tensor` instance, got ' + '{!r}'.format(out)) if out.shape != newshp: raise ValueError('`out` must have shape {}, got {}' ''.format(newshp, out.shape)) - order = 'C' if out.flags.c_contiguous else 'F' - arr = np.asarray(arr, dtype=out.dtype, order=order) + if impl=='pytorch': + if arr.dtype != out.dtype: + arr = torch.tensor(arr, dtype=out.dtype) + else: # NumPy + order = 'C' if out.flags.c_contiguous else 'F' + arr = np.asarray(arr, dtype=out.dtype, order=order) + if arr.ndim != out.ndim: raise ValueError('number of axes of `arr` and `out` do not match ' '({} != {})'.format(arr.ndim, out.ndim)) @@ -455,16 +485,14 @@ def resize_array(arr, newshp, offset=None, pad_mode='constant', pad_const=0, if pad_mode not in _SUPPORTED_RESIZE_PAD_MODES: raise ValueError("`pad_mode` '{}' not understood".format(pad_mode_in)) - if (pad_mode == 'constant' and - any(n_new > n_orig - for n_orig, n_new in zip(arr.shape, out.shape))): - try: - pad_const_scl = np.array([pad_const], out.dtype) - assert(pad_const_scl == np.array([pad_const])) - except Exception as e: - raise ValueError('`pad_const` {} cannot be safely cast to the data ' - 'type {} of the output array' - ''.format(pad_const, out.dtype)) + if pad_mode == 'constant': + incompatible_const_error = ValueError( + f'`pad_const` {pad_const} cannot be safely cast to the data ' + + f'type {out.dtype} of the output array') + if (not is_castable_to(pad_const, out.dtype) + and any(n_new > n_orig + for n_orig, n_new in zip(arr.shape, out.shape))): + raise incompatible_const_error # Handle direction direction, direction_in = str(direction).lower(), direction @@ -476,10 +504,11 @@ def resize_array(arr, newshp, offset=None, pad_mode='constant', pad_const=0, raise ValueError("`pad_const` must be 0 for 'adjoint' direction, " "got {}".format(pad_const)) + fill_with = out.fill_ if impl=='pytorch' else out.fill if direction == 'forward' and pad_mode == 'constant' and pad_const != 0: - out.fill(pad_const) + fill_with(pad_const) else: - out.fill(0) + fill_with(0) # Perform the resizing if direction == 'forward': diff --git a/odl/util/ufuncs.py b/odl/util/ufuncs.py index 6926e642501..caec596ae70 100644 --- a/odl/util/ufuncs.py +++ b/odl/util/ufuncs.py @@ -26,10 +26,11 @@ from __future__ import print_function, division, absolute_import from builtins import object import numpy as np +import torch import re -__all__ = ('TensorSpaceUfuncs', 'ProductSpaceUfuncs') +__all__ = ('NumpyTensorSpaceUfuncs', 'ProductSpaceUfuncs') # Some are ignored since they don't cooperate with dtypes, needs fix @@ -64,6 +65,36 @@ """.format(name) UFUNCS.append((name, n_in, n_out, doc)) +TORCH_RAW_UFUNCS = ['absolute', 'add', 'arccos', 'arccosh', 'arcsin', 'arcsinh', + 'arctan', 'arctan2', 'arctanh', 'bitwise_and', 'bitwise_or', + 'bitwise_xor', 'ceil', 'conj', 'copysign', 'cos', 'cosh', + 'deg2rad', 'divide', 'equal', 'exp', 'exp2', 'expm1', 'floor', + 'floor_divide', 'fmax', 'fmin', 'fmod', 'greater', + 'greater_equal', 'hypot', 'isfinite', 'isinf', 'isnan', + 'less', 'less_equal', 'log', 'log10', 'log1p', + 'log2', 'logaddexp', 'logaddexp2', 'logical_and', 'logical_not', + 'logical_or', 'logical_xor', 'maximum', 'minimum', + 'multiply', 'negative', 'not_equal', + 'rad2deg', 
'reciprocal', 'remainder',
+                    'sign', 'signbit', 'sin', 'sinh', 'sqrt', 'square', 'subtract',
+                    'tan', 'tanh', 'true_divide', 'trunc']
+# Add some standardized information
+TORCH_UFUNCS = []
+for name in TORCH_RAW_UFUNCS:
+    ufunc = getattr(np, name)
+    n_in, n_out = ufunc.nin, ufunc.nout
+    descr = ufunc.__doc__.splitlines()[2]
+    # Numpy occasionally uses single ticks for doc, we only use them for links
+    descr = re.sub('`+', '``', descr)
+    doc = descr + """
+
+See Also
+--------
+torch.{}
+""".format(name)
+    TORCH_UFUNCS.append((name, n_in, n_out, doc))
+
+
 # TODO: add the following reductions (to the CUDA implementation):
 # ['var', 'trace', 'tensordot', 'std', 'ptp', 'mean', 'diff', 'cumsum',
 #  'cumprod', 'average']
@@ -72,7 +103,7 @@

 # --- Wrappers for `Tensor` --- #

-def wrap_ufunc_base(name, n_in, n_out, doc):
+def wrap_ufunc_numpy(name, n_in, n_out, doc):
     """Return ufunc wrapper for implementation-agnostic ufunc classes."""
     ufunc = getattr(np, name)
     if n_in == 1:
@@ -111,8 +142,42 @@ def wrapper(self, x2, out=None, **kwargs):
     wrapper.__doc__ = doc
     return wrapper

+def wrap_ufunc_pytorch(name, n_in, n_out, doc):
+    """Return ufunc wrapper for PyTorch-backed ufunc classes."""
+    ufunc = getattr(torch, name)
+
+    if n_in == 1:
+        def wrapper(self, out=None, **kwargs):
+            if out is None:
+                return self.elem.space.element(ufunc(self.elem.data, **kwargs))
+            elif isinstance(out, type(self.elem)):
+                ufunc(self.elem.data, out=out.data, **kwargs)
+                return
+            raise NotImplementedError()
+
+    elif n_in == 2:
+        def wrapper(self, x2, out=None, **kwargs):
+            selfdata = self.elem.data
+            if isinstance(x2, type(self.elem)):
+                x2 = x2.data
+            elif isinstance(x2, (float, int)):
+                x2 = torch.tensor(x2).to(selfdata.device)
+            if out is None:
+                return self.elem.space.element(ufunc(selfdata, x2, **kwargs))
+            elif isinstance(out, type(self.elem)):
+                ufunc(selfdata, x2, out=out.data, **kwargs)
+                return
+            raise NotImplementedError()

-class TensorSpaceUfuncs(object):
+    else:
+        raise NotImplementedError
+
+    wrapper.__name__ = wrapper.__qualname__ = name
+    wrapper.__doc__ = doc
+    return wrapper
+
+
+class NumpyTensorSpaceUfuncs(object):

     """Ufuncs for `Tensor` objects.
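# A quick standalone check (not part of the patch) that the NumPy-style
# names collected in TORCH_RAW_UFUNCS resolve to PyTorch functions with the
# semantics the wrappers assume, including the ``out=`` argument.
import numpy as np
import torch

x_np = np.linspace(0.1, 1.0, 5, dtype=np.float32)
x_pt = torch.from_numpy(x_np.copy())

out_pt = torch.empty_like(x_pt)
torch.log1p(x_pt, out=out_pt)            # unary ufunc with out=
y_pt = torch.arctan2(x_pt, x_pt + 1.0)   # binary ufunc under its NumPy name

assert np.allclose(out_pt.numpy(), np.log1p(x_np))
assert np.allclose(y_pt.numpy(), np.arctan2(x_np, x_np + 1.0))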
@@ -176,9 +241,18 @@ def max(self, axis=None, dtype=None, out=None, keepdims=False): # Add ufunc methods to ufunc class for name, n_in, n_out, doc in UFUNCS: - method = wrap_ufunc_base(name, n_in, n_out, doc) - setattr(TensorSpaceUfuncs, name, method) + method = wrap_ufunc_numpy(name, n_in, n_out, doc) + setattr(NumpyTensorSpaceUfuncs, name, method) + + +class PytorchTensorSpaceUfuncs(object): + def __init__(self, elem): + """Create ufunc wrapper for elem.""" + self.elem = elem +for name, n_in, n_out, doc in TORCH_UFUNCS: + method = wrap_ufunc_pytorch(name, n_in, n_out, doc) + setattr(PytorchTensorSpaceUfuncs, name, method) # --- Wrappers for `ProductSpaceElement` --- # diff --git a/odl/util/utility.py b/odl/util/utility.py index 1df8f24896d..4ff17d55d14 100644 --- a/odl/util/utility.py +++ b/odl/util/utility.py @@ -15,8 +15,11 @@ from collections import OrderedDict from contextlib import contextmanager from itertools import product +from abc import ABC +from typing import Optional import numpy as np +import torch __all__ = ( 'REPR_PRECISION', @@ -26,6 +29,7 @@ 'array_str', 'dtype_repr', 'dtype_str', + 'dtype_type', 'cache_arguments', 'is_numeric_dtype', 'is_int_dtype', @@ -33,12 +37,16 @@ 'is_real_dtype', 'is_real_floating_dtype', 'is_complex_floating_dtype', + 'is_castable_to', + 'uses_pytorch', 'real_dtype', 'complex_dtype', 'is_string', 'nd_iterator', 'conj_exponent', 'nullcontext', + 'ArrayOnBackendManager', + 'compatible_array_manager', 'writable_array', 'signature_string', 'signature_string_parts', @@ -319,6 +327,27 @@ def dtype_str(dtype): else: return '{}'.format(dtype) +def dtype_type(dtype): + """Obtain a Python type corresponding to the given NumPy or PyTorch + dtype. This can be used for constructing values of a suitable type + for storing in an array of either backend.""" + if isinstance(dtype, str) or isinstance(dtype, type): + dtype = np.dtype(dtype) + + if hasattr(dtype, 'dtype'): + return dtype_type(dtype.dtype) + elif dtype == np.dtype(int): + return int + elif dtype == np.dtype(float): + return float + elif dtype == np.dtype(complex): + return complex + elif dtype == torch.float64: + return float + else: + raise ValueError(f"No suitable Python type available for {dtype}.") + + def cache_arguments(function): """Decorate function to cache the result with given arguments. 
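# A simplified standalone sketch (not part of the patch) of the idea behind
# ``dtype_type``: map a backend dtype to the plain Python type that can hold
# a single entry, e.g. for constructing fill values.
import numpy as np
import torch

def python_scalar_type(dtype):
    if isinstance(dtype, torch.dtype):
        if dtype.is_complex:
            return complex
        return float if dtype.is_floating_point else int
    return type(np.zeros(1, dtype=dtype)[0].item())

assert python_scalar_type(np.dtype('float64')) is float
assert python_scalar_type(torch.complex64) is complex
assert python_scalar_type(torch.float32) is float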
@@ -341,6 +370,13 @@ def cache_arguments(function):
 @cache_arguments
 def is_numeric_dtype(dtype):
     """Return ``True`` if ``dtype`` is a numeric type."""
+    if isinstance(dtype, torch.dtype):
+        # Only the floating and complex dtypes used by the PyTorch backend
+        # are supported here.
+        assert dtype in (torch.float32, torch.float64,
+                         torch.complex64, torch.complex128), \
+            f"unsupported PyTorch dtype {dtype}"
+        return True
     dtype = np.dtype(dtype)
     return np.issubdtype(getattr(dtype, 'base', None), np.number)

@@ -367,13 +403,24 @@ def is_real_dtype(dtype):
 @cache_arguments
 def is_real_floating_dtype(dtype):
     """Return ``True`` if ``dtype`` is a real floating point type."""
-    dtype = np.dtype(dtype)
+    if isinstance(dtype, torch.dtype):
+        if dtype in [torch.complex64, torch.complex128]:
+            return False
+        else:
+            assert(dtype in [torch.float32, torch.float64])
+            return True
+    dtype = np.dtype(dtype)
     return np.issubdtype(getattr(dtype, 'base', None), np.floating)


 @cache_arguments
 def is_complex_floating_dtype(dtype):
     """Return ``True`` if ``dtype`` is a complex floating point type."""
+    if isinstance(dtype, torch.dtype):
+        if dtype in [torch.float32, torch.float64]:
+            return False
+        assert(dtype in [torch.complex64, torch.complex128])
+        return True
     dtype = np.dtype(dtype)
     return np.issubdtype(getattr(dtype, 'base', None), np.complexfloating)

@@ -497,6 +544,51 @@ def complex_dtype(dtype, default=None):
     else:
         return np.dtype((complex_base_dtype, dtype.shape))

+_CORRESPONDING_PYTORCH_DTYPES = {
+    np.dtype('float16'): torch.float16,
+    np.dtype('float32'): torch.float32,
+    np.dtype('float64'): torch.float64,
+    np.dtype('complex64'): torch.complex64,
+    np.dtype('complex128'): torch.complex128}
+
+@cache_arguments
+def is_castable(from_dtype, to_dtype):
+    """Determine whether `from_dtype` is safely convertible to `to_dtype`.
+    Both should be either NumPy `dtype`s or PyTorch `dtype`s."""
+    if isinstance(from_dtype, np.dtype) and isinstance(to_dtype, np.dtype):
+        return np.can_cast(from_dtype, to_dtype)
+    elif isinstance(to_dtype, torch.dtype):
+        from_dtype = _CORRESPONDING_PYTORCH_DTYPES.get(from_dtype, from_dtype)
+        # Torch does not provide a satisfying way to determine castability,
+        # so we find it out by experiment.
+        # This is somewhat expensive, so it is important that this function
+        # is memoised (cache_arguments).
+        try:
+            gen = torch.Generator()
+            gen.manual_seed(1232451)  # Avoid nondeterministic behaviour
+            test_arr = torch.rand((1000,), generator=gen, dtype=from_dtype)
+            roundtripped = test_arr.to(to_dtype).to(from_dtype)
+        except TypeError:
+            return False
+        return torch.equal(roundtripped, test_arr)
+
+def is_castable_to(obj, dtype):
+    """Determine whether there is a safe way to cast `obj` to the type
+    specified by `dtype`, which can be either a NumPy dtype or a PyTorch
+    `dtype`."""
+    if hasattr(obj, 'dtype'):
+        obj_dtype = obj.dtype
+    else:
+        obj_dtype = np.array([obj]).dtype
+    return is_castable(obj_dtype, dtype)
+
+def uses_pytorch(obj):
+    if isinstance(obj, torch.Tensor):
+        return True
+    elif getattr(obj, "impl", None) == "pytorch":
+        return True
+    else:
+        return False

 def is_string(obj):
     """Return ``True`` if ``obj`` behaves like a string, ``False`` else."""
@@ -507,6 +599,70 @@ def is_string(obj):
     else:
         return True

+class ArrayOnBackendManager(ABC):
+    def __init__(self):
+        raise NotImplementedError()
+    def as_compatible_array(self, arr, **kwargs):
+        raise NotImplementedError()
+    def compatible_zeros(self, shape, **kwargs):
+        raise NotImplementedError()
+    def compatible_ones(self, shape, **kwargs):
+        raise NotImplementedError()
+    def select_dtype(self, arr, dtype, copy: Optional[bool]):
+        raise NotImplementedError()
+    def make_copy(self, arr):
+        raise NotImplementedError()
+
+class ArrayOnPytorchManager(ArrayOnBackendManager):
+    def __init__(self, device):
+        self._device = device
+    def as_compatible_array(self, arr, **kwargs):
+        dtype = kwargs.get('dtype', None)
+        if isinstance(arr, torch.Tensor):
+            arr = arr.detach()
+            if dtype is not None and arr.dtype != dtype:
+                arr = arr.type(dtype)
+            if self._device is not None and arr.device != self._device:
+                return arr.to(self._device)
+            else:
+                return arr
+        else:
+            return torch.tensor(arr, device=self._device, **kwargs)
+    def compatible_zeros(self, shape, **kwargs):
+        return torch.zeros(shape, device=self._device, **kwargs)
+    def compatible_ones(self, shape, **kwargs):
+        return torch.ones(shape, device=self._device, **kwargs)
+    def select_dtype(self, arr, dtype, copy):
+        if dtype in _CORRESPONDING_PYTORCH_DTYPES:
+            dtype = _CORRESPONDING_PYTORCH_DTYPES[dtype]
+        # PyTorch (as of version 2.7) only supports the values False and True
+        # for the `copy` argument of `Tensor.to`, the former being a
+        # non-binding request to avoid a copy if it is not necessary.
+        if copy == AVOID_UNNECESSARY_COPY:
+            copy = False
+        return arr.to(dtype, copy=copy)
+    def make_copy(self, arr):
+        return arr.clone().detach()
+
+class ArrayOnNumPyManager(ArrayOnBackendManager):
+    def __init__(self):
+        pass
+    def as_compatible_array(self, arr, **kwargs):
+        return np.array(arr, **kwargs)
+    def compatible_zeros(self, shape, **kwargs):
+        return np.zeros(shape, **kwargs)
+    def compatible_ones(self, shape, **kwargs):
+        return np.ones(shape, **kwargs)
+    def select_dtype(self, arr, dtype, copy):
+        return arr.astype(dtype, copy=copy)
+    def make_copy(self, arr):
+        return arr.copy()
+
+def compatible_array_manager(arr):
+    if uses_pytorch(arr):
+        return ArrayOnPytorchManager(arr.device)
+    else:
+        return ArrayOnNumPyManager()

 def nd_iterator(shape):
     """Iterator over n-d cube with shape.
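# A standalone sketch (not part of the patch) of the round-trip experiment
# behind ``is_castable`` for PyTorch dtypes: a cast counts as safe if a
# random sample survives the conversion there and back unchanged.
import torch

def roundtrip_castable(from_dtype, to_dtype, n=1000):
    gen = torch.Generator().manual_seed(1232451)
    sample = torch.rand((n,), generator=gen, dtype=from_dtype)
    return torch.equal(sample.to(to_dtype).to(from_dtype), sample)

assert roundtrip_castable(torch.float32, torch.float64)       # widening is safe
assert not roundtrip_castable(torch.float64, torch.float32)   # narrowing is not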
@@ -628,8 +784,22 @@ def writable_array(obj, **kwargs): [2, 4, 6] """ arr = None + torch_impl = uses_pytorch(obj) try: - arr = np.asarray(obj, **kwargs) + if torch_impl: + if isinstance(obj, torch.Tensor): + arr = obj + elif hasattr(obj, 'data') and isinstance(obj.data, torch.Tensor): + arr = obj.data + else: + if hasattr(obj, 'data'): + if 'dtype' not in kwargs: + kwargs['dtype'] = obj.data.dtype + arr = torch.tensor(obj.data, **kwargs) + else: + arr = torch.tensor(obj, **kwargs) + else: + arr = np.asarray(obj, **kwargs) yield arr finally: if arr is not None: