Skip to content

Commit ed4efbd

Browse files
Update training script for txt to img sdxl with lora supp with new interpolation. (#11496)
* Update training script for txt to img sdxl with lora supp with new interpolation. * ran make style and make quality.
1 parent 9c29e93 commit ed4efbd

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

examples/text_to_image/train_text_to_image_lora_sdxl.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,15 @@ def parse_args(input_args=None):
480480
action="store_true",
481481
help="debug loss for each image, if filenames are available in the dataset",
482482
)
483+
parser.add_argument(
484+
"--image_interpolation_mode",
485+
type=str,
486+
default="lanczos",
487+
choices=[
488+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
489+
],
490+
help="The image interpolation method to use for resizing images.",
491+
)
483492

484493
if input_args is not None:
485494
args = parser.parse_args(input_args)
@@ -913,8 +922,14 @@ def tokenize_captions(examples, is_train=True):
913922
tokens_two = tokenize_prompt(tokenizer_two, captions)
914923
return tokens_one, tokens_two
915924

925+
# Get the specified interpolation method from the args
926+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
927+
928+
# Raise an error if the interpolation method is invalid
929+
if interpolation is None:
930+
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
916931
# Preprocessing the datasets.
917-
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
932+
train_resize = transforms.Resize(args.resolution, interpolation=interpolation) # Use dynamic interpolation method
918933
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
919934
train_flip = transforms.RandomHorizontalFlip(p=1.0)
920935
train_transforms = transforms.Compose(

0 commit comments

Comments
 (0)