Skip to content

add multichannel insertion #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 27, 2025
22 changes: 15 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ This package provides a simple, consistent API for
- sampling values from 2D/3D images (`sample_image_2d()`/`sample_image_3d()`)
- inserting values into 2D/3D images (`insert_into_image_2d()`/`insert_into_image_3d`)

Operations are differentiable and interpolating from or into complex valued images is supported.
Operations are differentiable, multichannel data and complex valued images are supported.

For sampling [
`torch.nn.functional.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html)
is used under the hood.
[`torch.nn.functional.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html)
is used under the hood for sampling.

# Installation

Expand All @@ -37,7 +36,7 @@ Fractional coordinates are supported and values are interpolated appropriately.

### 2D Images

For 2D images with shape `(h, w)`:
For 2D images with shape `(h, w)` or `(c, h, w)`:

Coordinates are ordered as `[y, x]` where:

Expand All @@ -48,7 +47,7 @@ For example, in a `(28, 28)` image, valid coordinates range from `[0, 0]` to `[2

### 3D Images

For 3D images with shape `(d, h, w)`:
For 3D images with shape `(d, h, w)` or `(c, d, h, w)`:

Coordinates are ordered as `[z, y, x]` where:

Expand Down Expand Up @@ -84,6 +83,11 @@ samples_bicubic = sample_image_2d(image=image, coordinates=coords, interpolation
The API is identical for 3D `(d, h, w)` images but takes `(..., 3)` arrays of
coordinates.

Sampling is supported for multichannel images in both 2D `(c, h, w)` and 3D `(c, d, h, w)`.
Sampling multichannel images returns `(..., c)` arrays of values.



## Insert into image

```python
Expand Down Expand Up @@ -111,9 +115,13 @@ image_nearest, weights_nearest = insert_into_image_2d(
)
```

The API is identical for 3D `(d, h, w)` images but takes `(..., 3)` arrays of
The API is identical for 3D `(d, h, w)` images but requires `(..., 3)` arrays of
coordinates.

Insertion of is supported for multichannel images in both 2D `(c, h, w)` and 3D `(c, d, h, w)`.
Inserting into multichannel images requires `(..., c)` arrays of values.


## Similar packages

- https://github.com/balbasty/torch-interpol
Expand Down
154 changes: 110 additions & 44 deletions src/torch_image_interpolation/image_interpolation_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import torch
import torch.nn.functional as F

from .grid_sample_utils import array_to_grid_sample
from torch_image_interpolation import utils
from .grid_sample_utils import array_to_grid_sample


def sample_image_2d(
Expand Down Expand Up @@ -113,13 +113,13 @@ def insert_into_image_2d(
Parameters
----------
values: torch.Tensor
`(...)` array of values to be inserted into `image`.
`(...)` or `(..., c)` array of values to be inserted into `image`.
coordinates: torch.Tensor
`(..., 2)` array of 2D coordinates for each value in `data`.
- Coordinates are ordered `yx` and are positions in the `h` and `w` dimensions respectively.
- Coordinates span the range `[0, N-1]` for a dimension of length N.
image: torch.Tensor
`(h, w)` array containing the image into which data will be inserted.
`(h, w)` or `(c, h, w)` array containing the image into which data will be inserted.
weights: torch.Tensor | None
`(h, w)` array containing weights associated with each pixel in `image`.
This is useful for tracking weights across multiple calls to this function.
Expand All @@ -131,64 +131,130 @@ def insert_into_image_2d(
image, weights: tuple[torch.Tensor, torch.Tensor]
The image and weights after updating with data from `data` at `coordinates`.
"""
if values.shape != coordinates.shape[:-1]:
# keep track of a few properties of the inputs
input_image_is_multichannel = image.ndim == 3
h, w = image.shape[-2:]

# validate inputs
values_shape = values.shape[:-1] if input_image_is_multichannel else values.shape
coordinates_shape, coordinates_ndim = coordinates.shape[:-1], coordinates.shape[-1]

if values_shape != coordinates_shape:
raise ValueError('One coordinate pair is required for each value in data.')
if coordinates.shape[-1] != 2:
raise ValueError('Coordinates must be of shape (..., 2).')
if coordinates_ndim != 2:
raise ValueError('Coordinates must be 2D with shape (..., 2).')
if weights is None:
weights = torch.zeros_like(image)
weights = torch.zeros(size=(h, w), dtype=torch.float32, device=image.device)

# add channel dim to both image and values if input image is not multichannel
if not input_image_is_multichannel:
image = einops.rearrange(image, 'h w -> 1 h w')
values = einops.rearrange(values, '... -> ... 1')

# linearise data and coordinates
values, _ = einops.pack([values], pattern='*')
coordinates, _ = einops.pack([coordinates], pattern='* zyx')
values, _ = einops.pack([values], pattern='* c')
coordinates, _ = einops.pack([coordinates], pattern='* yx')
coordinates = coordinates.float()

# only keep data and coordinates inside the image
upper_bound = torch.tensor(image.shape, device=image.device) - 1
image_shape = torch.tensor((h, w), device=image.device, dtype=torch.float32)
upper_bound = image_shape - 1
idx_inside = (coordinates >= 0) & (coordinates <= upper_bound)
idx_inside = torch.all(idx_inside, dim=-1)
values, coordinates = values[idx_inside], coordinates[idx_inside]

# splat data onto grid
if interpolation == 'nearest':
image = _insert_nearest_2d(values, coordinates, image, weights)
image, weights = _insert_nearest_2d(values, coordinates, image, weights)
if interpolation == 'bilinear':
image, weights = _insert_linear_2d(values, coordinates, image, weights)

# ensure correct output image shape
# single channel input -> (h, w)
# multichannel input -> (c, h, w)
if not input_image_is_multichannel:
image = einops.rearrange(image, '1 h w -> h w')

return image, weights


def _insert_nearest_2d(data, coordinates, image, weights):
def _insert_nearest_2d(
data, # (b, c)
coordinates, # (b, yx)
image, # (c, h, w)
weights # (h, w)
):
# b is number of data points to insert per channel, c is number of channels
b, c = data.shape

# flatten data to insert values for all channels with one call to _index_put()
data = einops.rearrange(data, 'b c -> b c')

# find nearest voxel for each coordinate
coordinates = torch.round(coordinates).long()
idx_y, idx_x = einops.rearrange(coordinates, 'b yx -> yx b')
image.index_put_(indices=(idx_y, idx_x), values=data, accumulate=False)
w = torch.ones(len(coordinates), device=weights.device, dtype=weights.dtype)
weights.index_put_(indices=(idx_y, idx_x), values=w, accumulate=True)
return image


def _insert_linear_2d(data, coordinates, image, weights):
# calculate and cache floor and ceil of coordinates for each value to be inserted
corner_coords = torch.empty(size=(data.shape[0], 2, 2), dtype=torch.long, device=image.device)
corner_coords[:, 0] = torch.floor(coordinates)
corner_coords[:, 1] = torch.ceil(coordinates)

# calculate linear interpolation weights for each data point being inserted
_weights = torch.empty(size=(data.shape[0], 2, 2), device=image.device
) # (b, 2, yx)
_weights[:, 1] = coordinates - corner_coords[:, 0] # upper corner weights
_weights[:, 0] = 1 - _weights[:, 1] # lower corner weights

# define function for adding weighted data at nearest 4 pixels to each coordinate
# make sure to do atomic adds, don't just override existing data at each position
def add_data_at_corner(y: Literal[0, 1], x: Literal[0, 1]):
w = einops.reduce(_weights[:, [y, x], [0, 1]], 'b yx -> b', reduction='prod')
idx_y, idx_x = einops.rearrange(corner_coords[:, [y, x], [0, 1]],'b yx -> yx b')
image.index_put_(indices=(idx_y, idx_x), values=w * data, accumulate=True)
weights.index_put_(indices=(idx_y, idx_x), values=w, accumulate=True)

# insert correctly weighted data at each of 4 nearest pixels then return
add_data_at_corner(0, 0)
add_data_at_corner(0, 1)
add_data_at_corner(1, 0)
add_data_at_corner(1, 1)
idx_h, idx_w = einops.rearrange(coordinates, 'b yx -> yx b')

# insert ones into weights image (h, w) at each position
w = torch.ones(size=(b, 1), device=weights.device, dtype=weights.dtype)

# setup indices for insertion
idx_c = torch.arange(c, device=coordinates.device, dtype=torch.long)
idx_c = einops.rearrange(idx_c, 'c -> 1 c')
idx_h = einops.rearrange(idx_h, 'b -> b 1')
idx_w = einops.rearrange(idx_w, 'b -> b 1')

# insert image data and weights
image.index_put_(indices=(idx_c, idx_h, idx_w), values=data, accumulate=True)
weights.index_put_(indices=(idx_h, idx_w), values=w, accumulate=True)
return image, weights


def _insert_linear_2d(
data, # (b, c)
coordinates, # (b, yx)
image, # (c, h, w)
weights # (h, w)
):
# b is number of data points to insert per channel, c is number of channels
b, c = data.shape

# cache corner coordinates for each value to be inserted
# C10---C11
# | P |
# C00---C01
coordinates = einops.rearrange(coordinates, 'b yx -> yx b')
y0, x0 = torch.floor(coordinates)
y1, x1 = torch.ceil(coordinates)

# populate arrays of corner indices
idx_h = torch.empty(size=(b, 2, 2), dtype=torch.long, device=image.device)
idx_w = torch.empty(size=(b, 2, 2), dtype=torch.long, device=image.device)

idx_h[:, 0, 0], idx_w[:, 0, 0] = y0, x0 # C00
idx_h[:, 0, 1], idx_w[:, 0, 1] = y0, x1 # C01
idx_h[:, 1, 0], idx_w[:, 1, 0] = y1, x0 # C10
idx_h[:, 1, 1], idx_w[:, 1, 1] = y1, x1 # C11

# calculate linear interpolation weights for each corner
y, x = coordinates
ty, tx = y - y0, x - x0 # fractional position between corners
w = torch.empty(size=(b, 2, 2), device=image.device)
w[:, 0, 0] = (1 - ty) * (1 - tx) # C00
w[:, 0, 1] = (1 - ty) * tx # C01
w[:, 1, 0] = ty * (1 - tx) # C10
w[:, 1, 1] = ty * tx # C11

# make sure indices broadcast correctly
idx_c = torch.arange(c, device=coordinates.device, dtype=torch.long)
idx_c = einops.rearrange(idx_c, 'c -> 1 c 1 1')
idx_h = einops.rearrange(idx_h, 'b h w -> b 1 h w')
idx_w = einops.rearrange(idx_w, 'b h w -> b 1 h w')

# insert weighted data and weight values at each corner across all channels
# make sure to do atomic adds
data = einops.rearrange(data, 'b c -> b c 1 1')
w = einops.rearrange(w, 'b h w -> b 1 h w')
image.index_put_(indices=(idx_c, idx_h, idx_w), values=w * data, accumulate=True)
weights.index_put_(indices=(idx_h, idx_w), values=w, accumulate=True)

return image, weights
Loading