Skip to content

Commit 43b3c05

Browse files
committed
use jax.image.scale_and_translate instead of jax.image.resize for _upsample_bilinear2d_aa
1 parent cc631b9 commit 43b3c05

File tree

1 file changed

+20
-20
lines changed
  • experimental/torch_xla2/torch_xla2/ops

1 file changed

+20
-20
lines changed

experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4192,32 +4192,32 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
41924192
shape[-1] = output_size[-1]
41934193
shape[-2] = output_size[-2]
41944194

4195-
# align_corners is not supported in resize()
4196-
# https://github.com/jax-ml/jax/issues/11206
4197-
if align_corners:
4198-
return resize_with_aligned_corners2d(image, shape, scale_factors, method, antialias=True)
4199-
return jax.image.resize(image, shape, method, antialias) # precision=Precision.HIGHEST
4200-
4201-
# From: https://github.com/jax-ml/jax/issues/11206
4202-
def resize_with_aligned_corners2d(
4203-
image: jax.Array,
4204-
shape: Tuple[int, ...],
4205-
scale: Tuple[int, ...],
4206-
method: Union[str, jax.image.ResizeMethod],
4207-
antialias: bool,
4208-
):
4209-
"""Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's
4210-
interpolation functions."""
4211-
4195+
# pytorch upsample_bilinear returns the input as is when the shape is the same as input
4196+
if shape == list(image.shape):
4197+
return image
42124198

42134199
spatial_dims = (2,3)
42144200
if len(shape) == 3:
42154201
spatial_dims = (1,2)
42164202

4217-
scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])
4203+
scale = list([shape[i] / image.shape[i] for i in spatial_dims])
4204+
if scale_factors:
4205+
scale = scale_factors
4206+
if scales_h:
4207+
scale[0] = scales_h
4208+
if scales_w:
4209+
scale[1] = scales_w
4210+
scale = jnp.array(scale)
4211+
4212+
# align_corners is not supported in resize()
4213+
# https://github.com/jax-ml/jax/issues/11206
4214+
if align_corners:
4215+
scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims])
4216+
4217+
translation = jnp.array([0 for i in spatial_dims])
42184218
#translation = (scale / 2.0 - 0.5)
4219-
translation = (scale * 0.0 )
42204219

4220+
print (">> image_shape", image.shape, " spatial_dims", spatial_dims, " shape", shape, " scale", scale, " translation", translation)
42214221
return jax.image.scale_and_translate(
42224222
image,
42234223
shape,
@@ -4226,4 +4226,4 @@ def resize_with_aligned_corners2d(
42264226
spatial_dims=spatial_dims,
42274227
translation=translation,
42284228
antialias=antialias,
4229-
)
4229+
)

0 commit comments

Comments
 (0)