Skip to content

Commit 2196d0c

Browse files
qihqizmelumian
andauthored
[torchax] Added support for bicubic and billinear resampling (#9222)
Co-authored-by: zmelumian <[email protected]>
1 parent 06c5533 commit 2196d0c

File tree

4 files changed

+279
-11
lines changed

4 files changed

+279
-11
lines changed

torchax/test/test_image.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from absl.testing import parameterized
2+
import unittest
3+
from typing import Tuple
4+
import itertools
5+
from functools import partial
6+
import jax
7+
import torch
8+
9+
import torchax
10+
import torchax.interop
11+
12+
13+
def to_xla_tensor(tensorstree):
14+
return torchax.interop.torch_view(torchax.tensor.t2j(tensorstree))
15+
16+
17+
def to_torch_tensor(tensorstree):
18+
return torchax.tensor.j2t(torchax.interop.jax_view(tensorstree))
19+
20+
21+
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
22+
def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool,
23+
antialias: bool, method: str):
24+
tensor = torchax.interop.torch_view(tensor)
25+
tensor = torch.nn.functional.interpolate(
26+
tensor,
27+
size=output_size,
28+
mode=method,
29+
align_corners=align_corners,
30+
antialias=antialias)
31+
return torchax.interop.jax_view(tensor)
32+
33+
34+
class TestResampling(parameterized.TestCase):
35+
36+
@parameterized.product(
37+
antialias=[
38+
True,
39+
False,
40+
], align_corners=[
41+
False,
42+
True,
43+
])
44+
def test_resampling_combinations_bicubic(self, antialias, align_corners):
45+
method = 'bicubic'
46+
input_tensor = torch.rand((1, 1, 256, 512), dtype=torch.float32)
47+
output_size = (128, 64)
48+
49+
upsampled_tensor = torch.nn.functional.interpolate(
50+
input_tensor,
51+
size=output_size,
52+
mode=method,
53+
align_corners=align_corners,
54+
antialias=antialias)
55+
56+
with torchax.default_env():
57+
input_tensor_xla = to_xla_tensor(input_tensor)
58+
input_tensor_xla = torchax.interop.jax_view(input_tensor_xla)
59+
upsampled_tensor_xla = upsample_jit(
60+
input_tensor_xla,
61+
output_size,
62+
align_corners,
63+
antialias=antialias,
64+
method=method)
65+
66+
upsampled_tensor_xla = to_torch_tensor(upsampled_tensor_xla)
67+
abs_err = torch.abs(upsampled_tensor - upsampled_tensor_xla)
68+
69+
assert torch.allclose(
70+
upsampled_tensor, upsampled_tensor_xla, atol=1e-4,
71+
rtol=1e-5), f"{method} upsampling failed with error {abs_err.max()}"
72+
73+
74+
if __name__ == '__main__':
75+
unittest.main()

torchax/torchax/ops/jaten.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5181,17 +5181,16 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
51815181
return output
51825182

51835183

5184-
@op(torch.ops.aten._upsample_bilinear2d_aa)
5185-
def _aten_upsample_bilinear2d_aa(input,
5186-
output_size,
5187-
align_corners,
5188-
scale_factors=None,
5189-
scales_h=None,
5190-
scales_w=None):
5184+
def _aten_upsample(input,
5185+
output_size,
5186+
align_corners,
5187+
antialias,
5188+
method,
5189+
scale_factors=None,
5190+
scales_h=None,
5191+
scales_w=None):
51915192
# input: is of type jaxlib.xla_extension.ArrayImpl
51925193
image = input
5193-
method = "bilinear"
5194-
antialias = True # ignored for upsampling
51955194

51965195
# https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
51975196
# Resize does not distinguish batch, channel size.
@@ -5253,6 +5252,42 @@ def _aten_upsample_bilinear2d_aa(input,
52535252
)
52545253

52555254

5255+
@op(torch.ops.aten._upsample_bilinear2d_aa)
5256+
def _aten_upsample_billinear_aa(input,
5257+
output_size,
5258+
align_corners,
5259+
scale_factors=None,
5260+
scales_h=None,
5261+
scales_w=None):
5262+
return _aten_upsample(
5263+
input,
5264+
output_size,
5265+
align_corners,
5266+
True, # antialias
5267+
"bilinear", # method
5268+
scale_factors,
5269+
scales_h,
5270+
scales_w)
5271+
5272+
5273+
@op(torch.ops.aten._upsample_bicubic2d_aa)
5274+
def _aten_upsample_bicubic2d_aa(input,
5275+
output_size,
5276+
align_corners,
5277+
scale_factors=None,
5278+
scales_h=None,
5279+
scales_w=None):
5280+
return _aten_upsample(
5281+
input,
5282+
output_size,
5283+
align_corners,
5284+
True, # antialias
5285+
"bicubic", # method
5286+
scale_factors,
5287+
scales_h,
5288+
scales_w)
5289+
5290+
52565291
@op(torch.ops.aten.polar)
52575292
def _aten_polar(abs, angle, *, out=None):
52585293
return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))

torchax/torchax/ops/jimage.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
5+
def cubic_kernel(x, a=-0.75):
6+
"""Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)"""
7+
absx = jnp.abs(x)
8+
x2 = absx * absx
9+
x3 = x2 * absx
10+
cond1 = (absx <= 1)
11+
cond2 = (absx > 1) & (absx < 2)
12+
f1 = (a + 2) * x3 - (a + 3) * x2 + 1
13+
f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
14+
return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))
15+
16+
17+
def compute_contribs(in_size,
18+
out_size,
19+
scale,
20+
support=2.0,
21+
align_corners=False,
22+
dtype=None):
23+
if align_corners:
24+
if out_size == 1:
25+
in_coords = jnp.zeros((1,), dtype=dtype)
26+
else:
27+
in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype)
28+
else:
29+
out_coords = jnp.arange(out_size, dtype=dtype) + 0.5
30+
in_coords = out_coords / scale - 0.5
31+
32+
left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
33+
idxs = left_idx[:, None] + jnp.arange(4)
34+
35+
dx = in_coords[:, None] - idxs
36+
37+
weights = cubic_kernel(dx)
38+
39+
weights = weights / jnp.sum(weights, axis=1, keepdims=True)
40+
return idxs, weights
41+
42+
43+
def gather_weights(img, idxs, axis):
44+
"""Safely gather with boundary handling"""
45+
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)
46+
return jnp.take(img, idxs, axis=axis)
47+
48+
49+
def interpolate_along_axis_bchw(img, idxs, weights, axis):
50+
"""
51+
Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
52+
idxs: (out_size, 4) int32 indices
53+
weights: (out_size, 4) float32 weights
54+
"""
55+
assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
56+
out_size = idxs.shape[0]
57+
k = idxs.shape[1] # Typically 4 for cubic
58+
59+
# Clip to input bounds
60+
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4)
61+
62+
def gather_and_weight(i):
63+
idx = idxs[i] # (4,)
64+
w = weights[i] # (4,)
65+
66+
def gather_one(offset):
67+
return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W)
68+
69+
gathered = jnp.stack([gather_one(o) for o in range(k)],
70+
axis=0) # (4, B, C, H, W)
71+
weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W)
72+
return weighted
73+
74+
out = jax.vmap(gather_and_weight)(
75+
jnp.arange(out_size)) # (out_size, B, C, H, W)
76+
77+
# Move the interpolated axis back into place
78+
if axis == 2: # interpolated over H
79+
return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W)
80+
else: # axis == 3, interpolated over W
81+
return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W)
82+
83+
84+
def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
85+
h, w = img.shape[-2:]
86+
if align_corners and out_h > 1:
87+
scale_y = (h - 1) / (out_h - 1)
88+
else:
89+
scale_y = out_h / h
90+
91+
if align_corners and out_w > 1:
92+
scale_x = (w - 1) / (out_w - 1)
93+
else:
94+
scale_x = out_w / w
95+
96+
idxs_y, weights_y = compute_contribs(
97+
h,
98+
out_h,
99+
scale_y,
100+
align_corners=align_corners,
101+
dtype=img.dtype,
102+
)
103+
tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)
104+
105+
idxs_x, weights_x = compute_contribs(
106+
w,
107+
out_w,
108+
scale_x,
109+
align_corners=align_corners,
110+
dtype=img.dtype,
111+
)
112+
out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
113+
return out

torchax/torchax/ops/jtorch.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import math
44
import collections.abc
55
import functools
6-
from typing import Optional, Sequence
6+
from typing import Optional, Sequence, Tuple
77
import numpy as np
88

99
import jax
@@ -13,7 +13,7 @@
1313

1414
import torch
1515
from torchax.ops.ops_registry import register_torch_function_op
16-
from torchax.ops import op_base, mappings, jaten
16+
from torchax.ops import op_base, mappings, jaten, jimage
1717
import torchax.tensor
1818
from torchax.view import View, NarrowInfo
1919
import torch.utils._pytree as pytree
@@ -512,3 +512,48 @@ def functional_linear(self, weights, bias=None):
512512
if bias is not None:
513513
res += bias
514514
return res
515+
516+
517+
@register_function(torch.nn.functional.interpolate)
518+
def functional_interpolate(
519+
input,
520+
size: Tuple[int, int],
521+
scale_factor: Optional[float],
522+
mode: str,
523+
align_corners: bool,
524+
recompute_scale_factor: bool,
525+
antialias: bool,
526+
):
527+
supported_methods = (
528+
"nearest",
529+
"linear",
530+
"bilinear",
531+
"trilinear",
532+
"cubic",
533+
"bicubic",
534+
"tricubic",
535+
"lanczos3",
536+
"lanczos5",
537+
)
538+
is_jax_supported = mode in supported_methods
539+
if not is_jax_supported:
540+
raise torchax.tensor.OperatorNotFound(
541+
f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
542+
)
543+
# None check
544+
antialias = antialias or False
545+
align_corners = align_corners or False
546+
547+
if mode in ('cubic', 'bicubic',
548+
'tricubic') and not antialias and size is not None:
549+
return jimage.interpolate_bicubic_no_aa(
550+
input,
551+
size[0],
552+
size[1],
553+
align_corners,
554+
)
555+
else:
556+
# fallback
557+
raise torchax.tensor.OperatorNotFound(
558+
f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
559+
)

0 commit comments

Comments
 (0)