@@ -480,6 +480,15 @@ def parse_args(input_args=None):
480
480
action = "store_true" ,
481
481
help = "debug loss for each image, if filenames are available in the dataset" ,
482
482
)
483
+ parser .add_argument (
484
+ "--image_interpolation_mode" ,
485
+ type = str ,
486
+ default = "lanczos" ,
487
+ choices = [
488
+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
489
+ ],
490
+ help = "The image interpolation method to use for resizing images." ,
491
+ )
483
492
484
493
if input_args is not None :
485
494
args = parser .parse_args (input_args )
@@ -913,8 +922,14 @@ def tokenize_captions(examples, is_train=True):
913
922
tokens_two = tokenize_prompt (tokenizer_two , captions )
914
923
return tokens_one , tokens_two
915
924
925
+ # Get the specified interpolation method from the args
926
+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
927
+
928
+ # Raise an error if the interpolation method is invalid
929
+ if interpolation is None :
930
+ raise ValueError (f"Unsupported interpolation mode { args .image_interpolation_mode } ." )
916
931
# Preprocessing the datasets.
917
- train_resize = transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR )
932
+ train_resize = transforms .Resize (args .resolution , interpolation = interpolation ) # Use dynamic interpolation method
918
933
train_crop = transforms .CenterCrop (args .resolution ) if args .center_crop else transforms .RandomCrop (args .resolution )
919
934
train_flip = transforms .RandomHorizontalFlip (p = 1.0 )
920
935
train_transforms = transforms .Compose (
0 commit comments