@@ -4192,32 +4192,32 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
4192
4192
shape [- 1 ] = output_size [- 1 ]
4193
4193
shape [- 2 ] = output_size [- 2 ]
4194
4194
4195
- # align_corners is not supported in resize()
4196
- # https://github.com/jax-ml/jax/issues/11206
4197
- if align_corners :
4198
- return resize_with_aligned_corners2d (image , shape , scale_factors , method , antialias = True )
4199
- return jax .image .resize (image , shape , method , antialias ) # precision=Precision.HIGHEST
4200
-
4201
- # From: https://github.com/jax-ml/jax/issues/11206
4202
- def resize_with_aligned_corners2d (
4203
- image : jax .Array ,
4204
- shape : Tuple [int , ...],
4205
- scale : Tuple [int , ...],
4206
- method : Union [str , jax .image .ResizeMethod ],
4207
- antialias : bool ,
4208
- ):
4209
- """Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's
4210
- interpolation functions."""
4211
-
4195
+ # pytorch upsample_bilinear returns the input as is when the shape is the same as input
4196
+ if shape == list (image .shape ):
4197
+ return image
4212
4198
4213
4199
spatial_dims = (2 ,3 )
4214
4200
if len (shape ) == 3 :
4215
4201
spatial_dims = (1 ,2 )
4216
4202
4217
- scale = jnp .array ([(shape [i ] - 1.0 ) / (image .shape [i ] - 1.0 ) for i in spatial_dims ])
4203
+ scale = list ([shape [i ] / image .shape [i ] for i in spatial_dims ])
4204
+ if scale_factors :
4205
+ scale = scale_factors
4206
+ if scales_h :
4207
+ scale [0 ] = scales_h
4208
+ if scales_w :
4209
+ scale [1 ] = scales_w
4210
+ scale = jnp .array (scale )
4211
+
4212
+ # align_corners is not supported in resize()
4213
+ # https://github.com/jax-ml/jax/issues/11206
4214
+ if align_corners :
4215
+ scale = jnp .array ([(shape [i ] - 1.0 ) / (image .shape [i ] - 1.0 ) for i in spatial_dims ])
4216
+
4217
+ translation = jnp .array ([0 for i in spatial_dims ])
4218
4218
#translation = (scale / 2.0 - 0.5)
4219
- translation = (scale * 0.0 )
4220
4219
4220
+ print (">> image_shape" , image .shape , " spatial_dims" , spatial_dims , " shape" , shape , " scale" , scale , " translation" , translation )
4221
4221
return jax .image .scale_and_translate (
4222
4222
image ,
4223
4223
shape ,
@@ -4226,4 +4226,4 @@ def resize_with_aligned_corners2d(
4226
4226
spatial_dims = spatial_dims ,
4227
4227
translation = translation ,
4228
4228
antialias = antialias ,
4229
- )
4229
+ )
0 commit comments