@@ -146,7 +146,7 @@ def resize_image_to(
146
146
scale_factors = target_image_size / orig_image_size
147
147
out = resize (image , scale_factors = scale_factors , ** kwargs )
148
148
else :
149
- out = F .interpolate (image , target_image_size , mode = 'nearest' , align_corners = False )
149
+ out = F .interpolate (image , target_image_size , mode = 'nearest' )
150
150
151
151
if exists (clamp_range ):
152
152
out = out .clamp (* clamp_range )
@@ -1957,16 +1957,13 @@ class LowresConditioner(nn.Module):
1957
1957
def __init__ (
1958
1958
self ,
1959
1959
downsample_first = True ,
1960
- downsample_mode_nearest = False ,
1961
1960
blur_prob = 0.5 ,
1962
1961
blur_sigma = 0.6 ,
1963
1962
blur_kernel_size = 3 ,
1964
1963
input_image_range = None
1965
1964
):
1966
1965
super ().__init__ ()
1967
1966
self .downsample_first = downsample_first
1968
- self .downsample_mode_nearest = downsample_mode_nearest
1969
-
1970
1967
self .input_image_range = input_image_range
1971
1968
1972
1969
self .blur_prob = blur_prob
@@ -1983,7 +1980,7 @@ def forward(
1983
1980
blur_kernel_size = None
1984
1981
):
1985
1982
if self .downsample_first and exists (downsample_image_size ):
1986
- cond_fmap = resize_image_to (cond_fmap , downsample_image_size , clamp_range = self .input_image_range , nearest = self . downsample_mode_nearest )
1983
+ cond_fmap = resize_image_to (cond_fmap , downsample_image_size , clamp_range = self .input_image_range , nearest = True )
1987
1984
1988
1985
# blur is only applied 50% of the time
1989
1986
# section 3.1 in https://arxiv.org/abs/2106.15282
@@ -2010,7 +2007,7 @@ def forward(
2010
2007
2011
2008
cond_fmap = gaussian_blur2d (cond_fmap , cast_tuple (blur_kernel_size , 2 ), cast_tuple (blur_sigma , 2 ))
2012
2009
2013
- cond_fmap = resize_image_to (cond_fmap , target_image_size , clamp_range = self .input_image_range )
2010
+ cond_fmap = resize_image_to (cond_fmap , target_image_size , clamp_range = self .input_image_range , nearest = True )
2014
2011
return cond_fmap
2015
2012
2016
2013
class Decoder (nn .Module ):
@@ -2033,7 +2030,6 @@ def __init__(
2033
2030
image_sizes = None , # for cascading ddpm, image size at each stage
2034
2031
random_crop_sizes = None , # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
2035
2032
lowres_downsample_first = True , # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
2036
- lowres_downsample_mode_nearest = False , # cascading ddpm - whether to use nearest mode downsampling for lower resolution
2037
2033
blur_prob = 0.5 , # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
2038
2034
blur_sigma = 0.6 , # cascading ddpm - blur sigma
2039
2035
blur_kernel_size = 3 , # cascading ddpm - blur kernel size
@@ -2183,11 +2179,8 @@ def __init__(
2183
2179
lowres_conditions = tuple (map (lambda t : t .lowres_cond , self .unets ))
2184
2180
assert lowres_conditions == (False , * ((True ,) * (len (self .unets ) - 1 ))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
2185
2181
2186
- self .lowres_downsample_mode_nearest = lowres_downsample_mode_nearest
2187
-
2188
2182
self .to_lowres_cond = LowresConditioner (
2189
2183
downsample_first = lowres_downsample_first ,
2190
- downsample_mode_nearest = lowres_downsample_mode_nearest ,
2191
2184
blur_prob = blur_prob ,
2192
2185
blur_sigma = blur_sigma ,
2193
2186
blur_kernel_size = blur_kernel_size ,
@@ -2510,7 +2503,7 @@ def sample(
2510
2503
shape = (batch_size , channel , image_size , image_size )
2511
2504
2512
2505
if unet .lowres_cond :
2513
- lowres_cond_img = resize_image_to (img , target_image_size = image_size , clamp_range = self .input_image_range , nearest = self . lowres_downsample_mode_nearest )
2506
+ lowres_cond_img = resize_image_to (img , target_image_size = image_size , clamp_range = self .input_image_range , nearest = True )
2514
2507
2515
2508
is_latent_diffusion = isinstance (vae , VQGanVAE )
2516
2509
image_size = vae .get_encoded_fmap_size (image_size )
@@ -2580,7 +2573,7 @@ def forward(
2580
2573
assert not (not self .condition_on_text_encodings and exists (text_encodings )), 'decoder specified not to be conditioned on text, yet it is presented'
2581
2574
2582
2575
lowres_cond_img = self .to_lowres_cond (image , target_image_size = target_image_size , downsample_image_size = self .image_sizes [unet_index - 1 ]) if unet_number > 1 else None
2583
- image = resize_image_to (image , target_image_size )
2576
+ image = resize_image_to (image , target_image_size , nearest = True )
2584
2577
2585
2578
if exists (random_crop_size ):
2586
2579
aug = K .RandomCrop ((random_crop_size , random_crop_size ), p = 1. )
0 commit comments