Skip to content

Commit eb6c388

Browse files
authored
Merge pull request #22 from McHaillet/fix_and_test_dtypes_insertion
dtype issue fix for insertion and additional checks
2 parents 70b4c6d + dd7c1c0 commit eb6c388

File tree

4 files changed

+92
-4
lines changed

4 files changed

+92
-4
lines changed

src/torch_image_interpolation/image_interpolation_2d.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ def insert_into_image_2d(
143143
raise ValueError('One coordinate pair is required for each value in data.')
144144
if coordinates_ndim != 2:
145145
raise ValueError('Coordinates must be 2D with shape (..., 2).')
146+
if image.dtype != values.dtype:
147+
raise ValueError('Image and values must have the same dtype.')
148+
146149
if weights is None:
147150
weights = torch.zeros(size=(h, w), dtype=torch.float32, device=image.device)
148151

@@ -238,7 +241,7 @@ def _insert_linear_2d(
238241
# calculate linear interpolation weights for each corner
239242
y, x = coordinates
240243
ty, tx = y - y0, x - x0 # fractional position between corners
241-
w = torch.empty(size=(b, 2, 2), device=image.device)
244+
w = torch.empty(size=(b, 2, 2), device=image.device, dtype=weights.dtype)
242245
w[:, 0, 0] = (1 - ty) * (1 - tx) # C00
243246
w[:, 0, 1] = (1 - ty) * tx # C01
244247
w[:, 1, 0] = ty * (1 - tx) # C10
@@ -254,7 +257,11 @@ def _insert_linear_2d(
254257
# make sure to do atomic adds
255258
data = einops.rearrange(data, 'b c -> b c 1 1')
256259
w = einops.rearrange(w, 'b h w -> b 1 h w')
257-
image.index_put_(indices=(idx_c, idx_h, idx_w), values=w * data, accumulate=True)
260+
image.index_put_(
261+
indices=(idx_c, idx_h, idx_w),
262+
values=data * w.to(data.dtype),
263+
accumulate=True
264+
)
258265
weights.index_put_(indices=(idx_h, idx_w), values=w, accumulate=True)
259266

260267
return image, weights

src/torch_image_interpolation/image_interpolation_3d.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ def insert_into_image_3d(
143143
raise ValueError('One coordinate triplet is required for each value in data.')
144144
if coordinates_ndim != 3:
145145
raise ValueError('Coordinates must be 3D with shape (..., 3).')
146+
if image.dtype != values.dtype:
147+
raise ValueError('Image and values must have the same dtype.')
148+
146149
if weights is None:
147150
weights = torch.zeros(size=(d, h, w), dtype=torch.float32, device=image.device)
148151

@@ -241,7 +244,7 @@ def _insert_linear_3d(
241244
# calculate trilinear interpolation weights for each corner
242245
z, y, x = coordinates
243246
tz, ty, tx = z - z0, y - y0, x - x0 # fractional position between voxel corners
244-
w = torch.empty(size=(b, 2, 2, 2), device=image.device)
247+
w = torch.empty(size=(b, 2, 2, 2), device=image.device, dtype=weights.dtype)
245248

246249
w[:, 0, 0, 0] = (1 - tz) * (1 - ty) * (1 - tx) # C000
247250
w[:, 0, 0, 1] = (1 - tz) * (1 - ty) * tx # C001
@@ -262,7 +265,11 @@ def _insert_linear_3d(
262265
# insert weighted data and weight values at each corner
263266
data = einops.rearrange(data, 'b c -> b c 1 1 1')
264267
w = einops.rearrange(w, 'b z y x -> b 1 z y x')
265-
image.index_put_(indices=(idx_c, idx_z, idx_y, idx_x), values=w * data, accumulate=True)
268+
image.index_put_(
269+
indices=(idx_c, idx_z, idx_y, idx_x),
270+
values=data * w.to(data.dtype),
271+
accumulate=True
272+
)
266273
weights.index_put_(indices=(idx_z, idx_y, idx_x), values=w, accumulate=True)
267274

268275
return image, weights

tests/test_image_interpolation_2d.py

+37
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import einops
22
import numpy as np
33
import torch
4+
import pytest
45

56
from torch_image_interpolation import sample_image_2d, insert_into_image_2d
67

@@ -153,3 +154,39 @@ def test_insert_into_image_nearest_interp_2d():
153154
expected = torch.zeros((28, 28)).float()
154155
expected[11, 14] = 5
155156
assert torch.allclose(image, expected)
157+
158+
159+
@pytest.mark.parametrize(
160+
"dtype",
161+
[torch.float32, torch.float64, torch.complex64, torch.complex128]
162+
)
163+
def test_insert_into_image_2d_type_consistency(dtype):
164+
image = torch.rand((4, 4), dtype=dtype)
165+
coords = torch.tensor(np.random.uniform(low=0, high=3, size=(3, 4, 2)))
166+
values = torch.rand(size=(3, 4), dtype=dtype)
167+
# cast the dtype to corresponding float for weights
168+
weights = torch.zeros_like(image, dtype=torch.float64)
169+
170+
for mode in ['bilinear', 'nearest']:
171+
image, weights = insert_into_image_2d(
172+
values,
173+
image=image,
174+
weights=weights,
175+
coordinates=coords,
176+
interpolation=mode,
177+
)
178+
assert image.dtype == dtype
179+
assert weights.dtype == torch.float64
180+
181+
182+
def test_insert_into_image_3d_type_error():
183+
image = torch.rand((4, 4), dtype=torch.complex64)
184+
coords = torch.tensor(np.random.uniform(low=0, high=3, size=(3, 4, 2)))
185+
values = torch.rand(size=(3, 4), dtype=torch.complex128)
186+
# cast the dtype to corresponding float for weights
187+
with pytest.raises(ValueError):
188+
insert_into_image_2d(
189+
values,
190+
image=image,
191+
coordinates=coords,
192+
)

tests/test_image_interpolation_3d.py

+37
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import torch
33
import einops
4+
import pytest
45

56
from torch_image_interpolation import sample_image_3d, insert_into_image_3d
67

@@ -154,3 +155,39 @@ def test_insert_multiple_values_into_multichannel_image_2d_nearest():
154155
# check output shapes
155156
assert image.shape == (n_channels, 28, 28, 28)
156157
assert weights.shape == (28, 28, 28)
158+
159+
160+
@pytest.mark.parametrize(
161+
"dtype",
162+
[torch.float32, torch.float64, torch.complex64, torch.complex128]
163+
)
164+
def test_insert_into_image_2d_type_consistency(dtype):
165+
image = torch.rand((4, 4, 4), dtype=dtype)
166+
coords = torch.tensor(np.random.uniform(low=0, high=3, size=(3, 4, 5, 3)))
167+
values = torch.rand(size=(3, 4, 5), dtype=dtype)
168+
# cast the dtype to corresponding float for weights
169+
weights = torch.zeros_like(image, dtype=torch.float64)
170+
171+
for mode in ['bilinear', 'nearest']:
172+
image, weights = insert_into_image_3d(
173+
values,
174+
image=image,
175+
weights=weights,
176+
coordinates=coords,
177+
interpolation=mode,
178+
)
179+
assert image.dtype == dtype
180+
assert weights.dtype == torch.float64
181+
182+
183+
def test_insert_into_image_3d_type_error():
184+
image = torch.rand((4, 4, 4), dtype=torch.complex64)
185+
coords = torch.tensor(np.random.uniform(low=0, high=3, size=(3, 4, 5, 3)))
186+
values = torch.rand(size=(3, 4, 5), dtype=torch.complex128)
187+
# cast the dtype to corresponding float for weights
188+
with pytest.raises(ValueError):
189+
insert_into_image_3d(
190+
values,
191+
image=image,
192+
coordinates=coords,
193+
)

0 commit comments

Comments
 (0)