Skip to content

[torchax] Added support for bicubic and bilinear resampling #9222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions torchax/test/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from absl.testing import parameterized
import unittest
from typing import Tuple
import itertools
from functools import partial
import jax
import torch

import torchax
import torchax.interop


def to_xla_tensor(tensorstree):
    """Convert ``tensorstree`` with ``t2j`` and wrap the result via ``torch_view``."""
    as_jax = torchax.tensor.t2j(tensorstree)
    return torchax.interop.torch_view(as_jax)


def to_torch_tensor(tensorstree):
    """Unwrap ``tensorstree`` via ``jax_view`` and convert the result with ``j2t``."""
    as_jax = torchax.interop.jax_view(tensorstree)
    return torchax.tensor.j2t(as_jax)


@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool,
                 antialias: bool, method: str):
    """Resize ``tensor`` with ``torch.nn.functional.interpolate`` under ``jax.jit``.

    ``tensor`` is wrapped with ``torchax.interop.torch_view`` before the call
    and the result is unwrapped again with ``torchax.interop.jax_view``.
    Every argument except ``tensor`` is declared static for ``jax.jit``.
    """
    interpolate_kwargs = dict(
        size=output_size,
        mode=method,
        align_corners=align_corners,
        antialias=antialias,
    )
    torch_input = torchax.interop.torch_view(tensor)
    resized = torch.nn.functional.interpolate(torch_input, **interpolate_kwargs)
    return torchax.interop.jax_view(resized)


class TestResampling(parameterized.TestCase):
    """Compares torchax bicubic resizing against eager PyTorch."""

    @parameterized.product(
        antialias=[True, False],
        align_corners=[False, True],
    )
    def test_resampling_combinations_bicubic(self, antialias, align_corners):
        """Resize a random image both ways and require matching results."""
        method = 'bicubic'
        # A locally seeded generator keeps the input reproducible without
        # touching the global torch RNG state.
        generator = torch.Generator().manual_seed(0)
        input_tensor = torch.rand((1, 1, 256, 512),
                                  dtype=torch.float32,
                                  generator=generator)
        # Smaller than the input in both dimensions, so this is a downsample.
        output_size = (128, 64)

        expected = torch.nn.functional.interpolate(
            input_tensor,
            size=output_size,
            mode=method,
            align_corners=align_corners,
            antialias=antialias)

        with torchax.default_env():
            input_jax = torchax.interop.jax_view(to_xla_tensor(input_tensor))
            actual_jax = upsample_jit(
                input_jax,
                output_size,
                align_corners,
                antialias=antialias,
                method=method)

        actual = to_torch_tensor(actual_jax)
        max_abs_err = torch.abs(expected - actual).max()

        # assertTrue (unlike a bare assert) is still checked under `python -O`.
        self.assertTrue(
            torch.allclose(expected, actual, atol=1e-4, rtol=1e-5),
            f"{method} upsampling failed with error {max_abs_err}")


# Run this module's tests when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
53 changes: 44 additions & 9 deletions torchax/torchax/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -5181,17 +5181,16 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
return output


@op(torch.ops.aten._upsample_bilinear2d_aa)
def _aten_upsample_bilinear2d_aa(input,
output_size,
align_corners,
scale_factors=None,
scales_h=None,
scales_w=None):
def _aten_upsample(input,
output_size,
align_corners,
antialias,
method,
scale_factors=None,
scales_h=None,
scales_w=None):
# input: is of type jaxlib.xla_extension.ArrayImpl
image = input
method = "bilinear"
antialias = True # ignored for upsampling

# https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
# Resize does not distinguish batch, channel size.
Expand Down Expand Up @@ -5253,6 +5252,42 @@ def _aten_upsample_bilinear2d_aa(input,
)


@op(torch.ops.aten._upsample_bilinear2d_aa)
def _aten_upsample_billinear_aa(input,
                                output_size,
                                align_corners,
                                scale_factors=None,
                                scales_h=None,
                                scales_w=None):
    """Lowering registered for ``aten._upsample_bilinear2d_aa``.

    Forwards to ``_aten_upsample`` with antialiasing switched on and the
    ``"bilinear"`` method; all other arguments are passed through unchanged.
    """
    return _aten_upsample(
        input,
        output_size,
        align_corners,
        antialias=True,
        method="bilinear",
        scale_factors=scale_factors,
        scales_h=scales_h,
        scales_w=scales_w,
    )


@op(torch.ops.aten._upsample_bicubic2d_aa)
def _aten_upsample_bicubic2d_aa(input,
                                output_size,
                                align_corners,
                                scale_factors=None,
                                scales_h=None,
                                scales_w=None):
    """Lowering registered for ``aten._upsample_bicubic2d_aa``.

    Forwards to ``_aten_upsample`` with antialiasing switched on and the
    ``"bicubic"`` method; all other arguments are passed through unchanged.
    """
    return _aten_upsample(
        input,
        output_size,
        align_corners,
        antialias=True,
        method="bicubic",
        scale_factors=scale_factors,
        scales_h=scales_h,
        scales_w=scales_w,
    )


@op(torch.ops.aten.polar)
def _aten_polar(abs, angle, *, out=None):
    """Build a complex value from magnitude ``abs`` and phase ``angle``.

    The ``out`` keyword is accepted but not used.
    """
    real_part = abs * jnp.cos(angle)
    imag_part = abs * jnp.sin(angle)
    return jax.lax.complex(real_part, imag_part)
Expand Down
113 changes: 113 additions & 0 deletions torchax/torchax/ops/jimage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import jax
import jax.numpy as jnp


def cubic_kernel(x, a=-0.75):
    """Evaluate the piecewise cubic convolution kernel at ``x``.

    With ``d = |x|`` the result is
    ``(a + 2) d^3 - (a + 3) d^2 + 1`` for ``d <= 1``,
    ``a d^3 - 5 a d^2 + 8 a d - 4 a`` for ``1 < d < 2``,
    and ``0`` everywhere else.
    """
    dist = jnp.abs(x)
    dist_sq = dist * dist
    dist_cu = dist_sq * dist

    near_poly = (a + 2) * dist_cu - (a + 3) * dist_sq + 1
    far_poly = a * dist_cu - 5 * a * dist_sq + 8 * a * dist - 4 * a

    # Resolve the outer band first, then let the inner band take priority.
    in_far_band = (dist > 1) & (dist < 2)
    outside_near = jnp.where(in_far_band, far_poly, 0.0)
    return jnp.where(dist <= 1, near_poly, outside_near)


def compute_contribs(in_size,
                     out_size,
                     scale,
                     support=2.0,
                     align_corners=False,
                     dtype=None):
    """Compute the 4 source taps and cubic weights for every output position.

    Args:
      in_size: Length of the input axis. Only read when ``align_corners`` is
        true.
      out_size: Length of the output axis.
      scale: Output-to-input size ratio. Only read when ``align_corners`` is
        false, where the source coordinate is ``(i + 0.5) / scale - 0.5``.
      support: Currently unused; kept so existing callers keep working.
      align_corners: If true, output positions are spread evenly over
        ``[0, in_size - 1]`` (or pinned to 0 when ``out_size == 1``).
      dtype: Dtype of the image being resized, or ``None``.

    Returns:
      ``(idxs, weights)``, both of shape ``(out_size, 4)``. ``idxs`` is int32
      and is NOT clipped here, so it can be negative or reach past the end of
      the input axis. ``weights`` is normalised so each row sums to 1 and is
      cast to ``dtype`` when ``dtype`` is a floating type.
    """
    # Coordinates need at least float32: bfloat16/float16 cannot represent
    # every pixel index exactly (bfloat16 already fails above 256), and an
    # integer dtype would truncate the fractional part. Only the final weights
    # go back to the image dtype.
    if dtype is None:
        coord_dtype = None
    else:
        coord_dtype = jnp.promote_types(dtype, jnp.float32)

    if align_corners:
        if out_size == 1:
            in_coords = jnp.zeros((1,), dtype=coord_dtype)
        else:
            in_coords = jnp.linspace(
                0, in_size - 1, out_size, dtype=coord_dtype)
    else:
        out_coords = jnp.arange(out_size, dtype=coord_dtype) + 0.5
        in_coords = out_coords / scale - 0.5

    # Four taps per output sample: floor(coord) - 1 ... floor(coord) + 2.
    left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
    idxs = left_idx[:, None] + jnp.arange(4)

    # Signed distance from each tap to the sample position.
    dx = in_coords[:, None] - idxs
    weights = cubic_kernel(dx)
    weights = weights / jnp.sum(weights, axis=1, keepdims=True)

    if dtype is not None and jnp.issubdtype(dtype, jnp.floating):
        # Keep the weights (and therefore the resized image) in the caller's
        # floating dtype.
        weights = weights.astype(dtype)
    return idxs, weights


def gather_weights(img, idxs, axis):
    """Take entries of ``img`` along ``axis`` at ``idxs``.

    Indices are first clamped into ``[0, img.shape[axis] - 1]`` so that
    out-of-range taps read the nearest edge element.
    """
    last_valid = img.shape[axis] - 1
    safe_idxs = jnp.clip(idxs, 0, last_valid)
    return jnp.take(img, safe_idxs, axis=axis)


def interpolate_along_axis_bchw(img, idxs, weights, axis):
    """Resample a (B, C, H, W) tensor along H (axis=2) or W (axis=3).

    Args:
      img: Input tensor laid out as (B, C, H, W).
      idxs: (out_size, k) integer tap indices along ``axis``. Out-of-range
        values are clamped to the edges of the axis.
      weights: (out_size, k) weights, one per tap.
      axis: 2 to resample H, 3 to resample W.

    Returns:
      A tensor shaped like ``img`` except that ``axis`` has length
      ``out_size``.

    Raises:
      ValueError: If ``axis`` is neither 2 nor 3.
    """
    if axis not in (2, 3):
        # Explicit raise rather than `assert`, which is stripped under -O.
        raise ValueError("Axis must be 2 (H) or 3 (W)")

    out_size = idxs.shape[0]

    # Clamp taps to the valid range (edge replication).
    idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)  # (out_size, k)

    def weighted_slice(i):
        # Gathering all k taps at once keeps `axis` in place with length k:
        # (B, C, k, W) for axis == 2, (B, C, H, k) for axis == 3.
        taps = jnp.take(img, idxs[i], axis=axis)
        # Contract the tap axis against the weights: the result has `axis`
        # removed, i.e. (B, C, W) or (B, C, H).
        return jnp.tensordot(weights[i], taps, axes=(0, axis))

    # vmap stacks the per-position slices on a new leading axis:
    # (out_size, B, C, W) or (out_size, B, C, H).
    out = jax.vmap(weighted_slice)(jnp.arange(out_size))

    # Move the new axis to where the resampled axis belongs.
    return jnp.moveaxis(out, 0, axis)


def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
    """Resize the last two axes of ``img`` to ``(out_h, out_w)``.

    The resize is separable: H (axis 2) is resampled first, then W (axis 3).
    Each pass builds its tap indices and weights with ``compute_contribs``
    and applies them with ``interpolate_along_axis_bchw``.
    """
    in_h, in_w = img.shape[-2:]

    resized = img
    for axis, in_size, out_size in ((2, in_h, out_h), (3, in_w, out_w)):
        if align_corners and out_size > 1:
            scale = (in_size - 1) / (out_size - 1)
        else:
            scale = out_size / in_size

        idxs, weights = compute_contribs(
            in_size,
            out_size,
            scale,
            align_corners=align_corners,
            dtype=img.dtype,
        )
        resized = interpolate_along_axis_bchw(resized, idxs, weights, axis=axis)
    return resized
49 changes: 47 additions & 2 deletions torchax/torchax/ops/jtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import collections.abc
import functools
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple
import numpy as np

import jax
Expand All @@ -13,7 +13,7 @@

import torch
from torchax.ops.ops_registry import register_torch_function_op
from torchax.ops import op_base, mappings, jaten
from torchax.ops import op_base, mappings, jaten, jimage
import torchax.tensor
from torchax.view import View, NarrowInfo
import torch.utils._pytree as pytree
Expand Down Expand Up @@ -512,3 +512,48 @@ def functional_linear(self, weights, bias=None):
if bias is not None:
res += bias
return res


@register_function(torch.nn.functional.interpolate)
def functional_interpolate(
    input,
    size: Optional[Tuple[int, int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
):
    """Override registered for ``torch.nn.functional.interpolate``.

    Only cubic modes without antialiasing and with an explicit ``size`` are
    handled here, via ``jimage.interpolate_bicubic_no_aa``. Every other
    combination raises ``OperatorNotFound``.
    NOTE(review): presumably the dispatcher then falls back to another
    implementation - confirm against the torchax dispatcher.

    Defaults mirror ``torch.nn.functional.interpolate`` so the function can
    also be called with only the arguments a caller actually sets.

    Args:
      input: Tensor to resize; the cubic path treats it as (B, C, H, W).
      size: Target ``(height, width)``. A single int is used for both.
      scale_factor: Not used by the cubic path.
      mode: Interpolation mode name.
      align_corners: ``None`` is treated as ``False``.
      recompute_scale_factor: Not used by the cubic path.
      antialias: ``None`` is treated as ``False``.

    Raises:
      torchax.tensor.OperatorNotFound: If ``mode`` is not in the supported
        list, or if this override has no direct lowering for the request.
    """
    supported_methods = (
        "nearest",
        "linear",
        "bilinear",
        "trilinear",
        "cubic",
        "bicubic",
        "tricubic",
        "lanczos3",
        "lanczos5",
    )
    if mode not in supported_methods:
        raise torchax.tensor.OperatorNotFound(
            f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
        )

    # Callers may pass None for these flags; treat None as False.
    antialias = antialias or False
    align_corners = align_corners or False

    is_cubic = mode in ('cubic', 'bicubic', 'tricubic')
    if is_cubic and not antialias and size is not None:
        if isinstance(size, int):
            # A bare int applies to both spatial dimensions.
            size = (size, size)
        return jimage.interpolate_bicubic_no_aa(
            input,
            size[0],
            size[1],
            align_corners,
        )

    # The mode is in the supported list, but this override has no direct
    # lowering for this combination, so say that instead of "unsupported mode".
    raise torchax.tensor.OperatorNotFound(
        f"No direct JAX lowering in functional_interpolate for mode={mode!r} "
        f"with antialias={antialias} and size={size}.")
Loading