Skip to content

Commit 78c8efd

Browse files
committed
Make test pass
1 parent fa99f20 commit 78c8efd

File tree

4 files changed

+45
-38
lines changed

4 files changed

+45
-38
lines changed

torchax/test/test_image.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from absl.testing import parameterized
12
import unittest
23
from typing import Tuple
34
import itertools
@@ -23,12 +24,20 @@ def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool, anti
2324
return torchax.interop.jax_view(tensor)
2425

2526

26-
def test_upsampling(align_corners: bool, antialias: bool, method: str):
2727

28-
if method == 'bilinear':
29-
if align_corners:
30-
return # bilinear upsampling does not support align_corners
31-
28+
class TestResampling(parameterized.TestCase):
29+
30+
@parameterized.product(
31+
antialias=[
32+
True,
33+
False,
34+
],
35+
align_corners=[
36+
False,
37+
True,
38+
])
39+
def test_resampling_combinations_bicubic(self, antialias, align_corners):
40+
method = 'bicubic'
3241
input_tensor = torch.rand((1, 1, 256, 512), dtype=torch.float32)
3342
output_size = (128, 64)
3443

@@ -44,26 +53,6 @@ def test_upsampling(align_corners: bool, antialias: bool, method: str):
4453

4554
assert torch.allclose(upsampled_tensor, upsampled_tensor_xla, atol=1e-4, rtol=1e-5), f"{method} upsampling failed with error {abs_err.max()}"
4655

47-
class TestResampling(unittest.TestCase):
48-
def test_resampling_combinations(self):
49-
methods = [
50-
'bicubic',
51-
'bilinear',
52-
]
53-
antialias_options = [
54-
True,
55-
False,
56-
]
57-
58-
aligncorners_options = [
59-
False,
60-
True,
61-
]
62-
63-
for method, antialias, align_corners in itertools.product(methods, antialias_options, aligncorners_options):
64-
with self.subTest(method=method, antialias=antialias, align_corners=align_corners):
65-
test_upsampling(align_corners=align_corners, antialias=antialias, method=method)
66-
6756

6857

6958
if __name__ == '__main__':

torchax/torchax/ops/jaten.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5270,6 +5270,25 @@ def _aten_upsample_billinear_aa(input,
52705270
scales_w
52715271
)
52725272

5273+
@op(torch.ops.aten._upsample_bicubic2d_aa)
5274+
def _aten_upsample_bicubic2d_aa(input,
5275+
output_size,
5276+
align_corners,
5277+
scale_factors=None,
5278+
scales_h=None,
5279+
scales_w=None):
5280+
return _aten_upsample(
5281+
input,
5282+
output_size,
5283+
align_corners,
5284+
True, # antialias
5285+
"bicubic", # method
5286+
scale_factors,
5287+
scales_h,
5288+
scales_w
5289+
)
5290+
5291+
52735292
@op(torch.ops.aten.polar)
52745293
def _aten_polar(abs, angle, *, out=None):
52755294
return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))

torchax/torchax/ops/jimage.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ def cubic_kernel(x, a=-0.75):
1313
return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))
1414

1515

16-
def compute_contribs(in_size, out_size, scale, support=2.0, align_corners=False):
16+
def compute_contribs(in_size, out_size, scale, support=2.0, align_corners=False, dtype=None):
1717
if align_corners:
1818
if out_size == 1:
19-
in_coords = jnp.zeros((1,))
19+
in_coords = jnp.zeros((1,), dtype=dtype)
2020
else:
21-
in_coords = jnp.linspace(0, in_size - 1, out_size)
21+
in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype)
2222
else:
23-
out_coords = jnp.arange(out_size) + 0.5
23+
out_coords = jnp.arange(out_size, dtype=dtype) + 0.5
2424
in_coords = out_coords / scale - 0.5
2525

2626
left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
@@ -86,11 +86,13 @@ def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
8686

8787
idxs_y, weights_y = compute_contribs(
8888
h, out_h, scale_y, align_corners=align_corners,
89+
dtype=img.dtype,
8990
)
9091
tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)
9192

9293
idxs_x, weights_x = compute_contribs(
9394
w, out_w, scale_x, align_corners=align_corners,
95+
dtype=img.dtype,
9496
)
9597
out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
9698
return out

torchax/torchax/ops/jtorch.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -545,18 +545,15 @@ def functional_interpolate(
545545
antialias = antialias or False
546546
align_corners = align_corners or False
547547

548-
if mode in ('cubic', 'bicubic', 'tricubic') and not antialias:
548+
if mode in ('cubic', 'bicubic', 'tricubic') and not antialias and size is not None:
549549
return jimage.interpolate_bicubic_no_aa(
550550
input,
551551
size[0],
552552
size[1],
553553
align_corners,
554554
)
555-
return jaten._aten_upsample(
556-
input,
557-
size,
558-
align_corners,
559-
antialias,
560-
mode,
561-
scale_factor,
562-
)
555+
else:
556+
# fallback
557+
raise torchax.tensor.OperatorNotFound(
558+
f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
559+
)

0 commit comments

Comments
 (0)