Skip to content

Commit 09f044d

Browse files
committed
[torchax] Added support for bicubic and billinear resampling
1 parent edc1a88 commit 09f044d

File tree

3 files changed

+240
-10
lines changed

3 files changed

+240
-10
lines changed

torchax/test/test_image.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import unittest
2+
from typing import Tuple
3+
import itertools
4+
from functools import partial
5+
import jax
6+
import torch
7+
8+
import torchax
9+
import torchax.interop
10+
11+
def to_xla_tensor(tensorstree):
12+
return torchax.interop.torch_view(torchax.tensor.t2j(tensorstree))
13+
14+
def to_torch_tensor(tensorstree):
15+
return torchax.tensor.j2t(torchax.interop.jax_view(tensorstree))
16+
17+
18+
19+
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
20+
def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool, antialias: bool, method: str):
21+
tensor = torchax.interop.torch_view(tensor)
22+
tensor = torch.nn.functional.interpolate(tensor, size=output_size, mode=method, align_corners=align_corners, antialias=antialias)
23+
return torchax.interop.jax_view(tensor)
24+
25+
26+
def test_upsampling(align_corners: bool, antialias: bool, method: str):
27+
28+
if method == 'bilinear':
29+
if align_corners:
30+
return # bilinear upsampling does not support align_corners
31+
32+
input_tensor = torch.rand((1, 1, 256, 512), dtype=torch.float32)
33+
output_size = (128, 64)
34+
35+
upsampled_tensor = torch.nn.functional.interpolate(input_tensor, size=output_size, mode=method, align_corners=align_corners, antialias=antialias)
36+
37+
with torchax.default_env():
38+
input_tensor_xla = to_xla_tensor(input_tensor)
39+
input_tensor_xla = torchax.interop.jax_view(input_tensor_xla)
40+
upsampled_tensor_xla = upsample_jit(input_tensor_xla, output_size, align_corners, antialias=antialias, method=method)
41+
42+
upsampled_tensor_xla = to_torch_tensor(upsampled_tensor_xla)
43+
abs_err = torch.abs(upsampled_tensor - upsampled_tensor_xla)
44+
45+
assert torch.allclose(upsampled_tensor, upsampled_tensor_xla, atol=1e-4, rtol=1e-5), f"{method} upsampling failed with error {abs_err.max()}"
46+
47+
class TestResampling(unittest.TestCase):
48+
def test_resampling_combinations(self):
49+
methods = [
50+
'bicubic',
51+
'bilinear',
52+
]
53+
antialias_options = [
54+
True,
55+
False,
56+
]
57+
58+
aligncorners_options = [
59+
False,
60+
True,
61+
]
62+
63+
for method, antialias, align_corners in itertools.product(methods, antialias_options, aligncorners_options):
64+
with self.subTest(method=method, antialias=antialias, align_corners=align_corners):
65+
test_upsampling(align_corners=align_corners, antialias=antialias, method=method)
66+
67+
68+
69+
if __name__ == '__main__':
70+
unittest.main()

torchax/torchax/ops/jaten.py

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
import torch.distributed._functional_collectives
1414
from torchax.ops import ops_registry
15+
from torchax.ops import jimage
1516
from torchax.ops import op_base, mappings
1617
from torchax import interop
1718
from torchax.ops import jax_reimplement
@@ -5181,17 +5182,16 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
51815182
return output
51825183

51835184

5184-
@op(torch.ops.aten._upsample_bilinear2d_aa)
5185-
def _aten_upsample_bilinear2d_aa(input,
5186-
output_size,
5187-
align_corners,
5188-
scale_factors=None,
5189-
scales_h=None,
5190-
scales_w=None):
5185+
def _aten_upsample(input,
5186+
output_size,
5187+
align_corners,
5188+
antialias,
5189+
method,
5190+
scale_factors=None,
5191+
scales_h=None,
5192+
scales_w=None):
51915193
# input: is of type jaxlib.xla_extension.ArrayImpl
51925194
image = input
5193-
method = "bilinear"
5194-
antialias = True # ignored for upsampling
51955195

51965196
# https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
51975197
# Resize does not distinguish batch, channel size.
@@ -5241,7 +5241,7 @@ def _aten_upsample_bilinear2d_aa(input,
52415241
])
52425242

52435243
translation = jnp.array([0 for i in spatial_dims])
5244-
5244+
52455245
return jax_reimplement.scale_and_translate(
52465246
image,
52475247
shape,
@@ -5253,6 +5253,70 @@ def _aten_upsample_bilinear2d_aa(input,
52535253
)
52545254

52555255

5256+
@op(torch.ops.aten._upsample_bilinear2d_aa)
5257+
def _aten_upsample_billinear_aa(input,
5258+
output_size,
5259+
align_corners,
5260+
scale_factors=None,
5261+
scales_h=None,
5262+
scales_w=None):
5263+
return _aten_upsample(
5264+
input,
5265+
output_size,
5266+
align_corners,
5267+
True, # antialias
5268+
"bilinear", # method
5269+
scale_factors,
5270+
scales_h,
5271+
scales_w
5272+
)
5273+
5274+
@op(torch.nn.functional.interpolate)
5275+
def _interpolate(
5276+
input,
5277+
size: Tuple[int, int],
5278+
scale_factor: Optional[float],
5279+
mode: str,
5280+
align_corners: bool,
5281+
recompute_scale_factor: bool,
5282+
antialias: bool,
5283+
):
5284+
supported_methods = (
5285+
"nearest",
5286+
"linear",
5287+
"bilinear",
5288+
"trilinear",
5289+
"cubic",
5290+
"bicubic",
5291+
"tricubic",
5292+
"lanczos3",
5293+
"lanczos5",
5294+
)
5295+
is_jax_supported = mode in supported_methods
5296+
if not is_jax_supported:
5297+
raise NotImplementedError(
5298+
f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
5299+
)
5300+
# None check
5301+
antialias = antialias or False
5302+
align_corners = align_corners or False
5303+
5304+
if mode in ('cubic', 'bicubic', 'tricubic') and not antialias:
5305+
return jimage.interpolate_bicubic_no_aa(
5306+
input,
5307+
size[0],
5308+
size[1],
5309+
align_corners,
5310+
)
5311+
return _aten_upsample(
5312+
input,
5313+
size,
5314+
align_corners,
5315+
antialias,
5316+
mode,
5317+
scale_factor,
5318+
)
5319+
52565320
@op(torch.ops.aten.polar)
52575321
def _aten_polar(abs, angle, *, out=None):
52585322
return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))

torchax/torchax/ops/jimage.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
def cubic_kernel(x, a=-0.75):
5+
"""Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)"""
6+
absx = jnp.abs(x)
7+
x2 = absx * absx
8+
x3 = x2 * absx
9+
cond1 = (absx <= 1)
10+
cond2 = (absx > 1) & (absx < 2)
11+
f1 = (a + 2) * x3 - (a + 3) * x2 + 1
12+
f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
13+
return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))
14+
15+
16+
def compute_contribs(in_size, out_size, scale, support=2.0, align_corners=False):
17+
if align_corners:
18+
if out_size == 1:
19+
in_coords = jnp.zeros((1,))
20+
else:
21+
in_coords = jnp.linspace(0, in_size - 1, out_size)
22+
else:
23+
out_coords = jnp.arange(out_size) + 0.5
24+
in_coords = out_coords / scale - 0.5
25+
26+
left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
27+
idxs = left_idx[:, None] + jnp.arange(4)
28+
29+
dx = in_coords[:, None] - idxs
30+
31+
weights = cubic_kernel(dx)
32+
33+
weights = weights / jnp.sum(weights, axis=1, keepdims=True)
34+
return idxs, weights
35+
36+
def gather_weights(img, idxs, axis):
37+
"""Safely gather with boundary handling"""
38+
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)
39+
return jnp.take(img, idxs, axis=axis)
40+
41+
def interpolate_along_axis_bchw(img, idxs, weights, axis):
42+
"""
43+
Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
44+
idxs: (out_size, 4) int32 indices
45+
weights: (out_size, 4) float32 weights
46+
"""
47+
assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
48+
out_size = idxs.shape[0]
49+
k = idxs.shape[1] # Typically 4 for cubic
50+
51+
# Clip to input bounds
52+
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4)
53+
54+
def gather_and_weight(i):
55+
idx = idxs[i] # (4,)
56+
w = weights[i] # (4,)
57+
58+
def gather_one(offset):
59+
return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W)
60+
61+
gathered = jnp.stack([gather_one(o) for o in range(k)], axis=0) # (4, B, C, H, W)
62+
weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W)
63+
return weighted
64+
65+
out = jax.vmap(gather_and_weight)(jnp.arange(out_size)) # (out_size, B, C, H, W)
66+
67+
# Move the interpolated axis back into place
68+
if axis == 2: # interpolated over H
69+
return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W)
70+
else: # axis == 3, interpolated over W
71+
return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W)
72+
73+
74+
75+
def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
76+
h, w = img.shape[-2:]
77+
if align_corners and out_h > 1:
78+
scale_y = (h - 1) / (out_h - 1)
79+
else:
80+
scale_y = out_h / h
81+
82+
if align_corners and out_w > 1:
83+
scale_x = (w - 1) / (out_w - 1)
84+
else:
85+
scale_x = out_w / w
86+
87+
idxs_y, weights_y = compute_contribs(
88+
h, out_h, scale_y, align_corners=align_corners,
89+
)
90+
tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)
91+
92+
idxs_x, weights_x = compute_contribs(
93+
w, out_w, scale_x, align_corners=align_corners,
94+
)
95+
out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
96+
return out

0 commit comments

Comments
 (0)