Skip to content

Commit

Permalink
use jax.image.scale_and_translate instead of jax.image.resize for _up…
Browse files Browse the repository at this point in the history
…sample_bilinear2d_aa
  • Loading branch information
barney-s committed Oct 2, 2024
1 parent cc631b9 commit 43b3c05
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4192,32 +4192,32 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
shape[-1] = output_size[-1]
shape[-2] = output_size[-2]

# align_corners is not supported in resize()
# https://github.com/jax-ml/jax/issues/11206
if align_corners:
return resize_with_aligned_corners2d(image, shape, scale_factors, method, antialias=True)
return jax.image.resize(image, shape, method, antialias) # precision=Precision.HIGHEST

# From: https://github.com/jax-ml/jax/issues/11206
def resize_with_aligned_corners2d(
image: jax.Array,
shape: Tuple[int, ...],
scale: Tuple[int, ...],
method: Union[str, jax.image.ResizeMethod],
antialias: bool,
):
"""Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's
interpolation functions."""

# pytorch upsample_bilinear returns the input as is when the shape is the same as input
if shape == list(image.shape):
return image

spatial_dims = (2,3)
if len(shape) == 3:
spatial_dims = (1,2)

scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])
scale = list([shape[i] / image.shape[i] for i in spatial_dims])
if scale_factors:
scale = scale_factors
if scales_h:
scale[0] = scales_h
if scales_w:
scale[1] = scales_w
scale = jnp.array(scale)

# align_corners is not supported in resize()
# https://github.com/jax-ml/jax/issues/11206
if align_corners:
scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])

translation = jnp.array([0 for i in spatial_dims])
#translation = (scale / 2.0 - 0.5)
translation = (scale * 0.0 )

print (">> image_shape", image.shape, " spatial_dims", spatial_dims, " shape", shape, " scale", scale, " translation", translation)
return jax.image.scale_and_translate(
image,
shape,
Expand All @@ -4226,4 +4226,4 @@ def resize_with_aligned_corners2d(
spatial_dims=spatial_dims,
translation=translation,
antialias=antialias,
)
)

0 comments on commit 43b3c05

Please sign in to comment.