Skip to content

Commit 8ab5f7e

Browse files
authored
fix cubic interpolation kwarg passthrough in 1d sampling (#26)
1 parent 281715f commit 8ab5f7e

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

src/torch_image_interpolation/image_interpolation_1d.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,12 @@ def sample_image_1d(
7171
# Take the samples
7272
# We need to convert the coordinates to grid_sample format
7373
coords_2d = array_to_grid_sample(coords_2d, array_shape=(2, w))
74-
mode = 'bilinear' if interpolation == 'linear' else interpolation
74+
if interpolation == 'linear':
75+
mode = 'bilinear'
76+
elif interpolation == 'cubic':
77+
mode = 'bicubic'
78+
else:
79+
mode = interpolation
7580
samples = F.grid_sample(
7681
input=image,
7782
grid=coords_2d,

tests/test_image_interpolation_1d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def test_sample_image_1d():
1515
coords = torch.tensor(np.random.randint(low=0, high=27, size=(*arbitrary_shape,)))
1616

1717
# sample
18-
samples = sample_image_1d(image=image, coordinates=coords)
18+
for interpolation in ("nearest", "linear", "cubic"):
19+
samples = sample_image_1d(image=image, coordinates=coords, interpolation=interpolation)
1920
assert samples.shape == (6, 7, 8)
2021

2122

0 commit comments

Comments
 (0)