Skip to content

Commit 622b922

Browse files
committed
Make test pass
1 parent 5fb3f2e commit 622b922

File tree

4 files changed

+170
-133
lines changed

4 files changed

+170
-133
lines changed

torchax/test/test_image.py

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from absl.testing import parameterized
12
import unittest
23
from typing import Tuple
34
import itertools
@@ -8,63 +9,67 @@
89
import torchax
910
import torchax.interop
1011

12+
1113
def to_xla_tensor(tensorstree):
1214
return torchax.interop.torch_view(torchax.tensor.t2j(tensorstree))
1315

16+
1417
def to_torch_tensor(tensorstree):
1518
return torchax.tensor.j2t(torchax.interop.jax_view(tensorstree))
1619

1720

18-
1921
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
20-
def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool, antialias: bool, method: str):
21-
tensor = torchax.interop.torch_view(tensor)
22-
tensor = torch.nn.functional.interpolate(tensor, size=output_size, mode=method, align_corners=align_corners, antialias=antialias)
23-
return torchax.interop.jax_view(tensor)
24-
25-
26-
def test_upsampling(align_corners: bool, antialias: bool, method: str):
27-
28-
if method == 'bilinear':
29-
if align_corners:
30-
return # bilinear upsampling does not support align_corners
31-
22+
def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool,
23+
antialias: bool, method: str):
24+
tensor = torchax.interop.torch_view(tensor)
25+
tensor = torch.nn.functional.interpolate(
26+
tensor,
27+
size=output_size,
28+
mode=method,
29+
align_corners=align_corners,
30+
antialias=antialias)
31+
return torchax.interop.jax_view(tensor)
32+
33+
34+
class TestResampling(parameterized.TestCase):
35+
36+
@parameterized.product(
37+
antialias=[
38+
True,
39+
False,
40+
], align_corners=[
41+
False,
42+
True,
43+
])
44+
def test_resampling_combinations_bicubic(self, antialias, align_corners):
45+
method = 'bicubic'
3246
input_tensor = torch.rand((1, 1, 256, 512), dtype=torch.float32)
3347
output_size = (128, 64)
3448

35-
upsampled_tensor = torch.nn.functional.interpolate(input_tensor, size=output_size, mode=method, align_corners=align_corners, antialias=antialias)
36-
49+
upsampled_tensor = torch.nn.functional.interpolate(
50+
input_tensor,
51+
size=output_size,
52+
mode=method,
53+
align_corners=align_corners,
54+
antialias=antialias)
55+
3756
with torchax.default_env():
3857
input_tensor_xla = to_xla_tensor(input_tensor)
3958
input_tensor_xla = torchax.interop.jax_view(input_tensor_xla)
40-
upsampled_tensor_xla = upsample_jit(input_tensor_xla, output_size, align_corners, antialias=antialias, method=method)
41-
59+
upsampled_tensor_xla = upsample_jit(
60+
input_tensor_xla,
61+
output_size,
62+
align_corners,
63+
antialias=antialias,
64+
method=method)
65+
4266
upsampled_tensor_xla = to_torch_tensor(upsampled_tensor_xla)
4367
abs_err = torch.abs(upsampled_tensor - upsampled_tensor_xla)
44-
45-
assert torch.allclose(upsampled_tensor, upsampled_tensor_xla, atol=1e-4, rtol=1e-5), f"{method} upsampling failed with error {abs_err.max()}"
46-
47-
class TestResampling(unittest.TestCase):
48-
def test_resampling_combinations(self):
49-
methods = [
50-
'bicubic',
51-
'bilinear',
52-
]
53-
antialias_options = [
54-
True,
55-
False,
56-
]
57-
58-
aligncorners_options = [
59-
False,
60-
True,
61-
]
62-
63-
for method, antialias, align_corners in itertools.product(methods, antialias_options, aligncorners_options):
64-
with self.subTest(method=method, antialias=antialias, align_corners=align_corners):
65-
test_upsampling(align_corners=align_corners, antialias=antialias, method=method)
66-
67-
68-
68+
69+
assert torch.allclose(
70+
upsampled_tensor, upsampled_tensor_xla, atol=1e-4,
71+
rtol=1e-5), f"{method} upsampling failed with error {abs_err.max()}"
72+
73+
6974
if __name__ == '__main__':
70-
unittest.main()
75+
unittest.main()

torchax/torchax/ops/jaten.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5263,12 +5263,30 @@ def _aten_upsample_billinear_aa(input,
52635263
input,
52645264
output_size,
52655265
align_corners,
5266-
True, # antialias
5267-
"bilinear", # method
5266+
True, # antialias
5267+
"bilinear", # method
52685268
scale_factors,
52695269
scales_h,
5270-
scales_w
5271-
)
5270+
scales_w)
5271+
5272+
5273+
@op(torch.ops.aten._upsample_bicubic2d_aa)
5274+
def _aten_upsample_bicubic2d_aa(input,
5275+
output_size,
5276+
align_corners,
5277+
scale_factors=None,
5278+
scales_h=None,
5279+
scales_w=None):
5280+
return _aten_upsample(
5281+
input,
5282+
output_size,
5283+
align_corners,
5284+
True, # antialias
5285+
"bicubic", # method
5286+
scale_factors,
5287+
scales_h,
5288+
scales_w)
5289+
52725290

52735291
@op(torch.ops.aten.polar)
52745292
def _aten_polar(abs, angle, *, out=None):

torchax/torchax/ops/jimage.py

Lines changed: 87 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,113 @@
11
import jax
22
import jax.numpy as jnp
33

4+
45
def cubic_kernel(x, a=-0.75):
5-
"""Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)"""
6-
absx = jnp.abs(x)
7-
x2 = absx * absx
8-
x3 = x2 * absx
9-
cond1 = (absx <= 1)
10-
cond2 = (absx > 1) & (absx < 2)
11-
f1 = (a + 2) * x3 - (a + 3) * x2 + 1
12-
f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
13-
return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))
14-
15-
16-
def compute_contribs(in_size, out_size, scale, support=2.0, align_corners=False):
17-
if align_corners:
18-
if out_size == 1:
19-
in_coords = jnp.zeros((1,))
20-
else:
21-
in_coords = jnp.linspace(0, in_size - 1, out_size)
6+
"""Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)"""
7+
absx = jnp.abs(x)
8+
x2 = absx * absx
9+
x3 = x2 * absx
10+
cond1 = (absx <= 1)
11+
cond2 = (absx > 1) & (absx < 2)
12+
f1 = (a + 2) * x3 - (a + 3) * x2 + 1
13+
f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
14+
return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))
15+
16+
17+
def compute_contribs(in_size,
18+
out_size,
19+
scale,
20+
support=2.0,
21+
align_corners=False,
22+
dtype=None):
23+
if align_corners:
24+
if out_size == 1:
25+
in_coords = jnp.zeros((1,), dtype=dtype)
2226
else:
23-
out_coords = jnp.arange(out_size) + 0.5
24-
in_coords = out_coords / scale - 0.5
27+
in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype)
28+
else:
29+
out_coords = jnp.arange(out_size, dtype=dtype) + 0.5
30+
in_coords = out_coords / scale - 0.5
31+
32+
left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
33+
idxs = left_idx[:, None] + jnp.arange(4)
2534

26-
left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
27-
idxs = left_idx[:, None] + jnp.arange(4)
35+
dx = in_coords[:, None] - idxs
2836

29-
dx = in_coords[:, None] - idxs
37+
weights = cubic_kernel(dx)
3038

31-
weights = cubic_kernel(dx)
39+
weights = weights / jnp.sum(weights, axis=1, keepdims=True)
40+
return idxs, weights
3241

33-
weights = weights / jnp.sum(weights, axis=1, keepdims=True)
34-
return idxs, weights
3542

3643
def gather_weights(img, idxs, axis):
37-
"""Safely gather with boundary handling"""
38-
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)
39-
return jnp.take(img, idxs, axis=axis)
44+
"""Safely gather with boundary handling"""
45+
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)
46+
return jnp.take(img, idxs, axis=axis)
47+
4048

4149
def interpolate_along_axis_bchw(img, idxs, weights, axis):
42-
"""
50+
"""
4351
Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
4452
idxs: (out_size, 4) int32 indices
4553
weights: (out_size, 4) float32 weights
4654
"""
47-
assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
48-
out_size = idxs.shape[0]
49-
k = idxs.shape[1] # Typically 4 for cubic
55+
assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
56+
out_size = idxs.shape[0]
57+
k = idxs.shape[1] # Typically 4 for cubic
5058

51-
# Clip to input bounds
52-
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4)
59+
# Clip to input bounds
60+
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4)
5361

54-
def gather_and_weight(i):
55-
idx = idxs[i] # (4,)
56-
w = weights[i] # (4,)
62+
def gather_and_weight(i):
63+
idx = idxs[i] # (4,)
64+
w = weights[i] # (4,)
5765

58-
def gather_one(offset):
59-
return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W)
66+
def gather_one(offset):
67+
return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W)
6068

61-
gathered = jnp.stack([gather_one(o) for o in range(k)], axis=0) # (4, B, C, H, W)
62-
weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W)
63-
return weighted
69+
gathered = jnp.stack([gather_one(o) for o in range(k)],
70+
axis=0) # (4, B, C, H, W)
71+
weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W)
72+
return weighted
6473

65-
out = jax.vmap(gather_and_weight)(jnp.arange(out_size)) # (out_size, B, C, H, W)
66-
67-
# Move the interpolated axis back into place
68-
if axis == 2: # interpolated over H
69-
return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W)
70-
else: # axis == 3, interpolated over W
71-
return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W)
74+
out = jax.vmap(gather_and_weight)(
75+
jnp.arange(out_size)) # (out_size, B, C, H, W)
7276

77+
# Move the interpolated axis back into place
78+
if axis == 2: # interpolated over H
79+
return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W)
80+
else: # axis == 3, interpolated over W
81+
return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W)
7382

7483

7584
def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
76-
h, w = img.shape[-2:]
77-
if align_corners and out_h > 1:
78-
scale_y = (h - 1) / (out_h - 1)
79-
else:
80-
scale_y = out_h / h
81-
82-
if align_corners and out_w > 1:
83-
scale_x = (w - 1) / (out_w - 1)
84-
else:
85-
scale_x = out_w / w
86-
87-
idxs_y, weights_y = compute_contribs(
88-
h, out_h, scale_y, align_corners=align_corners,
89-
)
90-
tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)
91-
92-
idxs_x, weights_x = compute_contribs(
93-
w, out_w, scale_x, align_corners=align_corners,
94-
)
95-
out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
96-
return out
85+
h, w = img.shape[-2:]
86+
if align_corners and out_h > 1:
87+
scale_y = (h - 1) / (out_h - 1)
88+
else:
89+
scale_y = out_h / h
90+
91+
if align_corners and out_w > 1:
92+
scale_x = (w - 1) / (out_w - 1)
93+
else:
94+
scale_x = out_w / w
95+
96+
idxs_y, weights_y = compute_contribs(
97+
h,
98+
out_h,
99+
scale_y,
100+
align_corners=align_corners,
101+
dtype=img.dtype,
102+
)
103+
tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)
104+
105+
idxs_x, weights_x = compute_contribs(
106+
w,
107+
out_w,
108+
scale_x,
109+
align_corners=align_corners,
110+
dtype=img.dtype,
111+
)
112+
out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
113+
return out

torchax/torchax/ops/jtorch.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -514,8 +514,7 @@ def functional_linear(self, weights, bias=None):
514514
return res
515515

516516

517-
518-
@register_function(torch.nn.functional.interpolate)
517+
@register_function(torch.nn.functional.interpolate)
519518
def functional_interpolate(
520519
input,
521520
size: Tuple[int, int],
@@ -544,19 +543,17 @@ def functional_interpolate(
544543
# None check
545544
antialias = antialias or False
546545
align_corners = align_corners or False
547-
548-
if mode in ('cubic', 'bicubic', 'tricubic') and not antialias:
546+
547+
if mode in ('cubic', 'bicubic',
548+
'tricubic') and not antialias and size is not None:
549549
return jimage.interpolate_bicubic_no_aa(
550-
input,
551-
size[0],
552-
size[1],
553-
align_corners,
550+
input,
551+
size[0],
552+
size[1],
553+
align_corners,
554+
)
555+
else:
556+
# fallback
557+
raise torchax.tensor.OperatorNotFound(
558+
f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
554559
)
555-
return jaten._aten_upsample(
556-
input,
557-
size,
558-
align_corners,
559-
antialias,
560-
mode,
561-
scale_factor,
562-
)

0 commit comments

Comments
 (0)