Backend/pytorch arrays 2 #1679

Draft · wants to merge 52 commits into base: master
Commits (52)
426a7e5
Copy the numpy-tensors module, for migration to pytorch.
leftaroundabout Mar 18, 2024
7bc54d7
A basic PyTorch pendant to NumpyTensorSpace.
leftaroundabout Mar 19, 2024
3ff5a97
Some numpy->pytorch changes that slipped past the previous commit.
leftaroundabout Mar 21, 2024
0e0e9b2
An example for an operator with Torch implementation.
leftaroundabout Apr 2, 2024
b74e8e9
Fix inconsistency about backend/impl naming.
leftaroundabout Apr 3, 2024
2b92872
Adapt the deconvolution example to something more comparable with com…
leftaroundabout Apr 8, 2024
ca336cc
More proper result-showing in examples for solvers.
leftaroundabout Apr 9, 2024
ee003d3
PyTorch version of the deconvolution example.
leftaroundabout Apr 9, 2024
6f6df0b
More consistent naming for the PyTorch versions of examples.
leftaroundabout Apr 9, 2024
595e166
Generalise some utility to support PyTorch in addition to NumPy.
leftaroundabout Apr 9, 2024
f650814
Implementation-consistent types for the `asarray` method.
leftaroundabout Apr 12, 2024
45386ab
A generic way of obtaining a compatible scalar dtype for various things.
leftaroundabout Apr 12, 2024
c3fe2c5
Attempt at making gradient operators compatible with PyTorch.
leftaroundabout Apr 12, 2024
e244f0e
More flexible plotting in 1D example.
leftaroundabout Apr 17, 2024
5227dac
Methods for converting numbers to scalars.
leftaroundabout Apr 19, 2024
eefca31
Multiplication operators with explicitly selected scalar types.
leftaroundabout Jun 12, 2024
b6975f8
Sketch of a `ufuncs` version for PyTorch.
leftaroundabout Jun 12, 2024
40b8d46
Propose using PyTorch convolution for finite-differences.
leftaroundabout Jun 17, 2024
7783af8
Correct axis association of the convolution FDs.
leftaroundabout Jun 17, 2024
b8d0e04
PyTorch version of finite-difference grad etc..
leftaroundabout Jun 17, 2024
3392da0
Consistent use of PyTorch finite_diff also for divergence operator.
leftaroundabout Jun 17, 2024
e47ff95
Abolish in-place updates for PyTorch in PDHG.
leftaroundabout Jun 18, 2024
21d0694
Update PDHG example and enable PyTorch in it.
leftaroundabout Jun 18, 2024
090b875
Add the torch device as a parameter to tensor spaces.
leftaroundabout Jun 18, 2024
76caa8e
Refactor finite-difference kernels.
leftaroundabout Jun 18, 2024
f053134
Use correct Torch device for FD convolutions.
leftaroundabout Jun 18, 2024
0845126
Make `tensor_impl_args` compatible (albeit empty) on NumPy.
leftaroundabout Jun 18, 2024
e815813
GPU-compatible conversions to NumPy.
leftaroundabout Jun 21, 2024
e12c2db
Refactor Fourier trafo classes.
leftaroundabout Aug 21, 2024
d8523e9
Move the lookup dict for PyTorch dtypes to a more global level.
leftaroundabout Oct 4, 2024
afeaf52
Propose a backend-agnostic way of checking dtype compatibility.
leftaroundabout Oct 4, 2024
21cfabd
Generalize array resizing to Torch.
leftaroundabout Oct 4, 2024
60e0b84
One more example using PyTorch storage.
leftaroundabout Oct 7, 2024
aa1d5b6
Support PyTorch in the dtype-categorization utils.
leftaroundabout Oct 7, 2024
c97c0ac
Make `as_writable_array` handle and reinstore PyTorch-based elements.
leftaroundabout Oct 7, 2024
f5a45de
Default Fourier implementation should be based on the space.
leftaroundabout Oct 8, 2024
7a176ae
Helpers for generating / converting arrays on NumPy or PyTorch as app…
leftaroundabout Oct 9, 2024
bf3eb16
More dtype categorisation with PyTorch.
leftaroundabout Oct 9, 2024
298f767
Make `fast_1d_tensor_mult` PyTorch-compatible.
leftaroundabout Oct 9, 2024
ec19e0f
Make some Fourier utils PyTorch-compatible.
leftaroundabout Oct 9, 2024
c21ef4d
Make Fourier transforms robust towards non-NumPy array storage.
leftaroundabout Oct 9, 2024
5f4b76b
Correct bug in Fourier post-processing.
leftaroundabout Oct 14, 2024
1ed2d6b
Add Fourier methods using PyTorch.
leftaroundabout Oct 14, 2024
068dbc7
Support the PyTorch-based Fourier transforms.
leftaroundabout Oct 14, 2024
c260a19
Generalize Fourier tests to support arrays other than NumPy.
leftaroundabout Oct 14, 2024
37dbb76
Add PyTorch to Fourier unit tests.
leftaroundabout Oct 14, 2024
74c75a5
Add half-precision dtypes for PyTorch.
leftaroundabout Oct 14, 2024
8c0dad6
Avoid Torch warning/error messages when managing arrays that are alre…
leftaroundabout Oct 21, 2024
2a0af48
Use the `ArrayOnBackendManager` classes for generating PyTorch-based …
leftaroundabout Oct 21, 2024
42b07bc
Implemented the methods related to in- vs out-of-place selection for …
leftaroundabout Dec 3, 2024
a6ace7e
Add a `copy` argument to the array-manager's dtype-switching method.
leftaroundabout Apr 30, 2025
58355f0
Start adapting the tensor space tests for PyTorch.
leftaroundabout Apr 30, 2025
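Several of the commits above (e.g. "Propose using PyTorch convolution for finite-differences") replace hand-rolled NumPy difference loops with a single PyTorch convolution. A minimal sketch of that idea, not taken from the PR's actual implementation, for a 1-D forward difference:

```python
# Sketch (not from the PR): forward differences expressed as a 1-D convolution.
import torch

x = torch.tensor([0.0, 1.0, 4.0, 9.0, 16.0])
# conv1d computes a cross-correlation, so the stencil [-1, 1] yields x[i+1] - x[i]
stencil = torch.tensor([-1.0, 1.0]).reshape(1, 1, 2)
fd = torch.conv1d(x.reshape(1, 1, -1), stencil).reshape(-1)
print(fd)   # tensor([1., 3., 5., 7.]), one sample shorter than x
```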
Files changed
2 changes: 1 addition & 1 deletion examples/operator/convolution_operator.py
@@ -56,4 +56,4 @@ def adjoint(self):
# Display the results using the show method
kernel.show('kernel')
phantom.show('phantom')
g.show('convolved phantom')
g.show('convolved phantom', force_show=True)
65 changes: 65 additions & 0 deletions examples/operator/convolution_operator_pytorch.py
@@ -0,0 +1,65 @@
"""Create a convolution operator by wrapping a library."""

import odl
import numpy as np
import torch


class Convolution(odl.Operator):
"""Operator calculating the convolution of a kernel with a function.

The operator inherits from ``odl.Operator`` to be able to be used with ODL.
"""

def __init__(self, kernel, domain, range):
"""Initialize a convolution operator with a known kernel."""

# Store the kernel
self.kernel = kernel

# Initialize the Operator class by calling its __init__ method.
# This sets properties such as domain and range and allows the other
# operator convenience functions to work.
super(Convolution, self).__init__(
domain=domain, range=range, linear=True)

def _call(self, x):
"""Implement calling the operator by calling PyTorch."""
return self.range.element(torch.conv2d( input=x.data.unsqueeze(0)
, weight=self.kernel.unsqueeze(0).unsqueeze(0)
, stride=(1,1)
, padding="same"
).squeeze(0)
)

@property
def adjoint(self):
"""Implement ``self.adjoint``.

For a convolution operator, the adjoint is given by the convolution
with a kernel with flipped axes. In particular, if the kernel is
symmetric the operator is self-adjoint.
"""
return Convolution( torch.flip(self.kernel, dims=(0,1))
, domain=self.range, range=self.domain )


# Define the space on which the problem should be solved
# Here the square [-1, 1] x [-1, 1] discretized on a 100x100 grid
space = odl.uniform_discr([-1, -1], [1, 1], [100, 100], impl='pytorch', dtype=np.float32)

# Convolution kernel, a small centered rectangle
kernel = torch.ones((5,5))

# Create convolution operator
A = Convolution(kernel, domain=space, range=space)

# Create phantom (the "unknown" solution)
phantom = odl.phantom.shepp_logan(space, modified=True)

# Apply convolution to phantom to create data
g = A.adjoint(phantom)

# Display the results using the show method
phantom.show('phantom')
g.show('convolved phantom', force_show=True)
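A quick way to sanity-check the flipped-kernel adjoint used in this new example is the inner-product test ⟨Ax, y⟩ = ⟨x, Aᵀy⟩. The following standalone sketch (not part of the PR) mirrors the example's torch.conv2d call on random tensors; note that torch.conv2d computes a cross-correlation, but the flip-based adjoint relation holds for 'same' zero padding either way.

```python
# Sketch (not part of the PR): inner-product test <A x, y> == <x, A^T y> for the
# flipped-kernel adjoint, mirroring the example's torch.conv2d call.
import torch

def apply(x, k):
    # channel dim added as in the example; 'same' zero padding, stride 1
    return torch.conv2d(x.unsqueeze(0), k.unsqueeze(0).unsqueeze(0),
                        stride=(1, 1), padding="same").squeeze(0)

torch.manual_seed(0)
k = torch.ones((5, 5))
k_adj = torch.flip(k, dims=(0, 1))           # adjoint kernel: both axes flipped
x = torch.randn(100, 100)
y = torch.randn(100, 100)

lhs = torch.sum(apply(x, k) * y)             # <A x, y>
rhs = torch.sum(x * apply(y, k_adj))         # <x, A^T y>
print(torch.allclose(lhs, rhs, rtol=1e-3))   # expected: True
```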
41 changes: 25 additions & 16 deletions examples/solvers/deconvolution_1d.py
@@ -27,11 +27,13 @@ def opnorm(self):


# Discretization
discr_space = odl.uniform_discr(0, 10, 500, impl='numpy')
discr_space = odl.uniform_discr(-5, 5, 500, impl='numpy')

# Complicated functions to check performance
kernel = discr_space.element(lambda x: np.exp(x / 2) * np.cos(x * 1.172))
phantom = discr_space.element(lambda x: x ** 2 * np.sin(x) ** 2 * (x > 5))
kernel = discr_space.element(lambda x: np.exp(-x**2 * 2) * np.cos(x * 1.172))

# phantom = discr_space.element(lambda x: (x+5) ** 2 * np.sin(x+5) ** 2 * (x > 0))
phantom = discr_space.element(lambda x: np.cos(0*x) * (x > -1) * (x < 1))

# Create operator
conv = Convolution(kernel)
@@ -41,21 +43,28 @@ def opnorm(self):
omega = 1 / conv.opnorm() ** 2


# Display callback
def callback(x):
plt.plot(conv(x))

def test_with_plot(conv, phantom, solver, **extra_args):
fig, axs = plt.subplots(2)
fig.suptitle("CGN")
axs[0].set_title("x")
axs[1].set_title("k*x")
axs[0].plot(phantom)
axs[1].plot(conv(phantom))
def plot_callback(x):
axs[0].plot(conv(x), '--')
axs[1].plot(conv(x), '--')
solver(conv, discr_space.zero(), phantom, iterations, callback=plot_callback, **extra_args)

# Test CGN
plt.figure()
plt.plot(phantom)
odl.solvers.conjugate_gradient_normal(conv, discr_space.zero(), phantom,
iterations, callback)
test_with_plot(conv, phantom, odl.solvers.conjugate_gradient_normal)

# Landweber
plt.figure()
plt.plot(phantom)
odl.solvers.landweber(conv, discr_space.zero(), phantom,
iterations, omega, callback)
# test_with_plot(conv, phantom, odl.solvers.landweber, omega=omega)

# # Landweber
# lw_fig, lw_axs = plt.subplots(1)
# lw_fig.suptitle("Landweber")
# lw_axs.plot(phantom)
# odl.solvers.landweber(conv, discr_space.zero(), phantom,
# iterations, omega, lambda x: lw_axs.plot(conv(x)))
#
plt.show()
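The Landweber runs are commented out in the updated example. For reference, a self-contained sketch (not the PR's code; A, A_adj and g are placeholder functions and data) of the update x_{k+1} = x_k + omega * A^T(g - A x_k) that those blocks would perform, with omega = 1 / opnorm()**2 as above:

```python
# Sketch (not the PR's code): the Landweber update the commented-out block would
# run.  A, A_adj, g are placeholders for the example's Convolution operator/data.
import numpy as np

def landweber(A, A_adj, g, x0, omega, iterations, callback=None):
    """Iterate x <- x + omega * A^T (g - A x)."""
    x = x0.copy()
    for _ in range(iterations):
        x = x + omega * A_adj(g - A(x))
        if callback is not None:
            callback(x)
    return x

# Toy usage: deblur a 1-D signal blurred by an 11-tap box kernel.
kernel = np.ones(11) / 11.0
A = lambda u: np.convolve(u, kernel, mode='same')
A_adj = A                                    # symmetric kernel, so self-adjoint
x_true = np.zeros(500)
x_true[200:300] = 1.0
g = A(x_true)
omega = 1.0 / np.sum(np.abs(kernel)) ** 2    # 1 / opnorm()**2, as in the example
x_rec = landweber(A, A_adj, g, np.zeros_like(g), omega, iterations=100)
```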
91 changes: 91 additions & 0 deletions examples/solvers/deconvolution_1d_pytorch.py
@@ -0,0 +1,91 @@
"""Example of a deconvolution problem with different solvers (CPU)."""

import numpy as np
import torch
import matplotlib.pyplot as plt
import scipy.signal
import odl


class Convolution(odl.Operator):
def __init__(self, kernel, domain, range, adjkernel=None):
self.kernel = kernel
self.adjkernel = torch.flip(kernel, dims=(0,)) if adjkernel is None else adjkernel
self.norm = float(torch.sum(torch.abs(self.kernel)))
super(Convolution, self).__init__(
domain=domain, range=range, linear=True)

def _call(self, x):
return self.range.element(
torch.conv1d( input=x.data.unsqueeze(0)
, weight=self.kernel.unsqueeze(0).unsqueeze(0)
, stride=1
, padding="same"
).squeeze(0)
)

@property
def adjoint(self):
return Convolution( self.adjkernel
, domain=self.range, range=self.domain
, adjkernel = self.kernel
)

def opnorm(self):
return self.norm


resolution = 50

# Discretization
discr_space = odl.uniform_discr(-5, 5, resolution*10, impl='pytorch', dtype=np.float32)

# Complicated functions to check performance
def mk_kernel():
q = 1.172
# Select main lobe and one side lobe on each side
r = np.ceil(3*np.pi/(2*q))
# Quantised to resolution
nr = int(np.ceil(r*resolution))
r = nr / resolution
x = torch.linspace(-r, r, nr*2 + 1)
return torch.exp(-x**2 * 2) * np.cos(x * q)
kernel = mk_kernel()

phantom = discr_space.element(lambda x: np.ones_like(x) ** 2 * (x > -1) * (x < 1))
# phantom = discr_space.element(lambda x: x ** 2 * np.sin(x) ** 2 * (x > 5))

# Create operator
conv = Convolution(kernel, domain=discr_space, range=discr_space)

# Dampening parameter for landweber
iterations = 100
omega = 1 / conv.opnorm() ** 2



def test_with_plot(conv, phantom, solver, **extra_args):
fig, axs = plt.subplots(2)
fig.suptitle("CGN")
def plot_fn(ax_id, fn, *plot_args, **plot_kwargs):
axs[ax_id].plot(fn, *plot_args, **plot_kwargs)
axs[0].set_title("x")
axs[1].set_title("k*x")
plot_fn(0, phantom)
plot_fn(1, conv(phantom))
def plot_callback(x):
plot_fn(0, conv(x), '--')
plot_fn(1, conv(x), '--')
solver(conv, discr_space.zero(), phantom, iterations, callback=plot_callback, **extra_args)

# Test CGN
test_with_plot(conv, phantom, odl.solvers.conjugate_gradient_normal)

# # Landweber
# lw_fig, lw_axs = plt.subplots(1)
# lw_fig.suptitle("Landweber")
# lw_axs.plot(phantom)
# odl.solvers.landweber(conv, discr_space.zero(), phantom,
# iterations, omega, lambda x: lw_axs.plot(conv(x)))

plt.show()
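The support chosen in mk_kernel() follows from the zeros of cos(q*x): with q = 1.172, the main lobe plus one side lobe on each side ends at 3*pi/(2*q) ≈ 4.02, rounded up to 5. A small sketch (not part of the PR, and using torch.cos rather than np.cos) working out the resulting sizes and the l1-norm that opnorm() reports as a bound on the operator norm:

```python
# Sketch (not part of the PR): the sizes produced by mk_kernel() and the l1-norm
# bound returned by opnorm().  torch.cos is used here instead of np.cos.
import numpy as np
import torch

q = 1.172
resolution = 50
r = np.ceil(3 * np.pi / (2 * q))        # ceil(4.02...) = 5.0: main lobe + one side lobe
nr = int(np.ceil(r * resolution))       # 250 samples per side
r = nr / resolution                     # 5.0 in physical units
x = torch.linspace(-r, r, nr * 2 + 1)   # 501 samples at grid spacing 1/resolution
kernel = torch.exp(-x**2 * 2) * torch.cos(x * q)

# opnorm() in the example returns sum(|kernel|); by Young's inequality this
# bounds the norm of the discrete convolution operator.
print(len(kernel), float(torch.sum(torch.abs(kernel))))
```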
15 changes: 8 additions & 7 deletions examples/solvers/pdhg_denoising.py
@@ -11,28 +11,29 @@
"""

import numpy as np
import torch
import scipy.misc
import odl

impl = 'numpy'
# impl = 'pytorch'

# Read test image: use only every second pixel, convert integer to float,
# and rotate to get the image upright
image = np.rot90(scipy.misc.ascent()[::2, ::2], 3).astype('float')
image = np.rot90(scipy.datasets.ascent()[::2, ::2], 3).astype('float')
shape = image.shape

# Rescale max to 1
image /= image.max()

# Discretized spaces
space = odl.uniform_discr([0, 0], shape, shape)
space = odl.uniform_discr([0, 0], shape, shape, impl=impl)

# Original image
orig = space.element(image)
orig = space.element(image.copy())

# Add noise
image += 0.1 * odl.phantom.white_noise(orig.space)

# Data of noisy image
noisy = space.element(image)
noisy = space.element(image) + 0.1 * odl.phantom.white_noise(orig.space)

# Gradient operator
gradient = odl.Gradient(space)
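The reshuffled noise handling above avoids an aliasing issue: if space.element() wraps the NumPy buffer without copying, the old in-place image += noise would silently modify orig as well. A tiny illustration of that pitfall (plain NumPy, assuming the no-copy wrapping; not part of the PR):

```python
# Sketch (not part of the PR): the aliasing pitfall that image.copy() guards
# against, assuming space.element() can wrap the NumPy buffer without copying.
import numpy as np

image = np.zeros((4, 4))
orig = image                 # stand-in for a no-copy space.element(image)
image += 1.0                 # old code: noise added in place to the same buffer
print(orig[0, 0])            # 1.0, so "orig" no longer holds the original image
```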
97 changes: 97 additions & 0 deletions examples/solvers/pdhg_denoising_pytorch.py
@@ -0,0 +1,97 @@
"""Total variation denoising using PDHG.

Solves the optimization problem

min_{x >= 0} 1/2 ||x - g||_2^2 + lam || |grad(x)| ||_1

where ``grad`` is the spatial gradient and ``g`` is given noisy data.

For further details and a description of the solution method used, see
https://odlgroup.github.io/odl/guide/pdhg_guide.html in the ODL documentation.
"""

import numpy as np
import torch
import scipy.misc
import odl
import cProfile

# Read test image: use only every second pixel, convert integer to float,
# and rotate to get the image upright
image = np.rot90(scipy.datasets.ascent()[::2, ::2], 3).astype('float')
shape = image.shape

# Rescale max to 1
image /= image.max()

# Discretized spaces
space = odl.uniform_discr([0, 0], shape, shape, impl='pytorch')

# Original image
orig = space.element(image.copy())

orig.data.requires_grad = False

# Add noise
noisy = space.element(image) + 0.1 * odl.phantom.white_noise(orig.space)

noisy.data.requires_grad = False

# Gradient operator
gradient = odl.Gradient(space)

# grad_xmp = gradient(orig)
# grad_xmp.show(title = "Grad-op applied to original")

# Matrix of operators
op = odl.BroadcastOperator(odl.IdentityOperator(space), gradient)

# Set up the functionals

# l2-squared data matching
l2_norm = odl.solvers.L2NormSquared(space).translated(noisy)

# Isotropic TV-regularization: l1-norm of grad(x)
l1_norm = 0.15 * odl.solvers.L1Norm(gradient.range)

# Make separable sum of functionals, order must correspond to the operator K
g = odl.solvers.SeparableSum(l2_norm, l1_norm)

# Non-negativity constraint
f = odl.solvers.IndicatorNonnegativity(op.domain)

# --- Select solver parameters and solve using PDHG --- #

# Estimated operator norm, add 10 percent to ensure ||K||_2^2 * sigma * tau < 1
op_norm = 1.1 * odl.power_method_opnorm(op, xstart=noisy, maxiter=10)
# 3.2833764101732785
print(f"{op_norm=}")

niter = 200 # Number of iterations
tau = 1.0 / op_norm # Step size for the primal variable
sigma = 1.0 / op_norm # Step size for the dual variable

# Optional: pass callback objects to solver
callback = (odl.solvers.CallbackPrintIteration() &
odl.solvers.CallbackShow(step=5))

# Starting point
x = op.domain.zero()

x.data.requires_grad = False

print("Go solve...")

# Run algorithm (and display intermediates)
def do_running():
with torch.no_grad():
odl.solvers.pdhg(x, f, g, op, niter=niter, tau=tau, sigma=sigma,
callback=callback)

do_running()
# cProfile.run('do_running()')

# Display images
orig.show(title='Original Image')
noisy.show(title='Noisy Image')
x.show(title='Reconstruction', force_show=True)
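The 10 percent margin on the power-method estimate exists because PDHG convergence requires sigma * tau * ||K||^2 < 1. A short check (not part of the PR; the norm value is a hypothetical stand-in for the power-method output):

```python
# Sketch (not part of the PR): why the 10 % margin keeps the PDHG step sizes
# admissible.  Convergence requires sigma * tau * ||K||^2 < 1.
K_norm = 2.98                       # hypothetical power-method estimate of ||K||
op_norm = 1.1 * K_norm              # add 10 % as in the example
tau = sigma = 1.0 / op_norm
print(sigma * tau * K_norm ** 2)    # 1 / 1.1**2 ≈ 0.826 < 1, whatever K_norm is
```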