-
Notifications
You must be signed in to change notification settings - Fork 534
[torchax] Added support for bicubic and billinear resampling #9222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from absl.testing import parameterized | ||
import unittest | ||
from typing import Tuple | ||
import itertools | ||
from functools import partial | ||
import jax | ||
import torch | ||
|
||
import torchax | ||
import torchax.interop | ||
|
||
|
||
def to_xla_tensor(tensorstree): | ||
return torchax.interop.torch_view(torchax.tensor.t2j(tensorstree)) | ||
|
||
|
||
def to_torch_tensor(tensorstree): | ||
return torchax.tensor.j2t(torchax.interop.jax_view(tensorstree)) | ||
|
||
|
||
@partial(jax.jit, static_argnums=(1, 2, 3, 4)) | ||
def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool, | ||
antialias: bool, method: str): | ||
tensor = torchax.interop.torch_view(tensor) | ||
tensor = torch.nn.functional.interpolate( | ||
tensor, | ||
size=output_size, | ||
mode=method, | ||
align_corners=align_corners, | ||
antialias=antialias) | ||
return torchax.interop.jax_view(tensor) | ||
|
||
|
||
class TestResampling(parameterized.TestCase): | ||
|
||
@parameterized.product( | ||
antialias=[ | ||
True, | ||
False, | ||
], align_corners=[ | ||
False, | ||
True, | ||
]) | ||
def test_resampling_combinations_bicubic(self, antialias, align_corners): | ||
method = 'bicubic' | ||
input_tensor = torch.rand((1, 1, 256, 512), dtype=torch.float32) | ||
output_size = (128, 64) | ||
|
||
upsampled_tensor = torch.nn.functional.interpolate( | ||
input_tensor, | ||
size=output_size, | ||
mode=method, | ||
align_corners=align_corners, | ||
antialias=antialias) | ||
|
||
with torchax.default_env(): | ||
input_tensor_xla = to_xla_tensor(input_tensor) | ||
input_tensor_xla = torchax.interop.jax_view(input_tensor_xla) | ||
upsampled_tensor_xla = upsample_jit( | ||
input_tensor_xla, | ||
output_size, | ||
align_corners, | ||
antialias=antialias, | ||
method=method) | ||
|
||
upsampled_tensor_xla = to_torch_tensor(upsampled_tensor_xla) | ||
abs_err = torch.abs(upsampled_tensor - upsampled_tensor_xla) | ||
|
||
assert torch.allclose( | ||
upsampled_tensor, upsampled_tensor_xla, atol=1e-4, | ||
rtol=1e-5), f"{method} upsampling failed with error {abs_err.max()}" | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
|
||
def cubic_kernel(x, a=-0.75): | ||
"""Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)""" | ||
absx = jnp.abs(x) | ||
x2 = absx * absx | ||
x3 = x2 * absx | ||
cond1 = (absx <= 1) | ||
cond2 = (absx > 1) & (absx < 2) | ||
f1 = (a + 2) * x3 - (a + 3) * x2 + 1 | ||
f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a | ||
return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0)) | ||
|
||
|
||
def compute_contribs(in_size, | ||
out_size, | ||
scale, | ||
support=2.0, | ||
align_corners=False, | ||
dtype=None): | ||
if align_corners: | ||
if out_size == 1: | ||
in_coords = jnp.zeros((1,), dtype=dtype) | ||
else: | ||
in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype) | ||
else: | ||
out_coords = jnp.arange(out_size, dtype=dtype) + 0.5 | ||
in_coords = out_coords / scale - 0.5 | ||
|
||
left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1 | ||
idxs = left_idx[:, None] + jnp.arange(4) | ||
|
||
dx = in_coords[:, None] - idxs | ||
|
||
weights = cubic_kernel(dx) | ||
|
||
weights = weights / jnp.sum(weights, axis=1, keepdims=True) | ||
return idxs, weights | ||
|
||
|
||
def gather_weights(img, idxs, axis): | ||
"""Safely gather with boundary handling""" | ||
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) | ||
return jnp.take(img, idxs, axis=axis) | ||
|
||
|
||
def interpolate_along_axis_bchw(img, idxs, weights, axis): | ||
""" | ||
Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W). | ||
idxs: (out_size, 4) int32 indices | ||
weights: (out_size, 4) float32 weights | ||
""" | ||
assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)" | ||
out_size = idxs.shape[0] | ||
k = idxs.shape[1] # Typically 4 for cubic | ||
|
||
# Clip to input bounds | ||
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4) | ||
|
||
def gather_and_weight(i): | ||
idx = idxs[i] # (4,) | ||
w = weights[i] # (4,) | ||
|
||
def gather_one(offset): | ||
return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W) | ||
|
||
gathered = jnp.stack([gather_one(o) for o in range(k)], | ||
axis=0) # (4, B, C, H, W) | ||
weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W) | ||
return weighted | ||
|
||
out = jax.vmap(gather_and_weight)( | ||
jnp.arange(out_size)) # (out_size, B, C, H, W) | ||
|
||
# Move the interpolated axis back into place | ||
if axis == 2: # interpolated over H | ||
return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W) | ||
else: # axis == 3, interpolated over W | ||
return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W) | ||
|
||
|
||
def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False): | ||
h, w = img.shape[-2:] | ||
if align_corners and out_h > 1: | ||
scale_y = (h - 1) / (out_h - 1) | ||
else: | ||
scale_y = out_h / h | ||
|
||
if align_corners and out_w > 1: | ||
scale_x = (w - 1) / (out_w - 1) | ||
else: | ||
scale_x = out_w / w | ||
|
||
idxs_y, weights_y = compute_contribs( | ||
h, | ||
out_h, | ||
scale_y, | ||
align_corners=align_corners, | ||
dtype=img.dtype, | ||
) | ||
tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2) | ||
|
||
idxs_x, weights_x = compute_contribs( | ||
w, | ||
out_w, | ||
scale_x, | ||
align_corners=align_corners, | ||
dtype=img.dtype, | ||
) | ||
out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3) | ||
return out |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.