From 7d779fe63352ac99a409fefc15d5a86f8e8021c7 Mon Sep 17 00:00:00 2001 From: wenyifan Date: Wed, 23 Apr 2025 16:05:49 +0800 Subject: [PATCH 1/2] Add support for custom conditioning image --- .../mixin/DataLoaderText2ImageMixin.py | 17 +++++++++++++++-- modules/ui/CaptionUI.py | 2 +- modules/ui/ConceptTab.py | 2 +- modules/ui/ConceptWindow.py | 4 ++-- modules/ui/TrainingTab.py | 5 +++++ modules/util/concept_stats.py | 2 +- modules/util/config/TrainConfig.py | 4 ++++ 7 files changed, 29 insertions(+), 7 deletions(-) diff --git a/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py b/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py index d85fe7624..13ec879cd 100644 --- a/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py +++ b/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py @@ -55,16 +55,19 @@ def _enumerate_input_modules(self, config: TrainConfig, allow_videos: bool = Fal collect_paths = CollectPaths( concept_in_name='concept', path_in_name='path', include_subdirectories_in_name='concept.include_subdirectories', enabled_in_name='enabled', path_out_name='image_path', concept_out_name='concept', - extensions=supported_extensions, include_postfix=None, exclude_postfix=['-masklabel'] + extensions=supported_extensions, include_postfix=None, exclude_postfix=['-masklabel','-condlabel'] ) mask_path = ModifyPath(in_name='image_path', out_name='mask_path', postfix='-masklabel', extension='.png') + cond_path = ModifyPath(in_name='image_path', out_name='cond_path', postfix='-condlabel', extension='.png') sample_prompt_path = ModifyPath(in_name='image_path', out_name='sample_prompt_path', postfix='', extension='.txt') modules = [collect_paths, sample_prompt_path] if config.masked_training: modules.append(mask_path) + if config.custom_conditioning_image: + modules.append(cond_path) return modules @@ -82,6 +85,10 @@ def _load_input_modules( load_mask = LoadImage(path_in_name='mask_path', image_out_name='mask', range_min=0, range_max=1, channels=1, supported_extensions={".png"}, dtype=train_dtype.torch_dtype()) mask_to_video = ImageToVideo(in_name='mask', out_name='mask') + load_cond_image = LoadImage(path_in_name='cond_path', image_out_name='conditioning_image', range_min=0, range_max=1, + supported_extensions=path_util.supported_image_extensions(), + dtype=train_dtype.torch_dtype()) + load_sample_prompts = LoadMultipleTexts(path_in_name='sample_prompt_path', texts_out_name='sample_prompts') load_concept_prompts = LoadMultipleTexts(path_in_name='concept.text.prompt_path', texts_out_name='concept_prompts') filename_prompt = GetFilename(path_in_name='image_path', filename_out_name='filename_prompt', include_extension=False) @@ -105,6 +112,9 @@ def _load_input_modules( elif config.model_type.has_mask_input(): modules.append(generate_mask) + if config.custom_conditioning_image: + modules.append(load_cond_image) + if allow_video: modules.append(mask_to_video) @@ -171,6 +181,9 @@ def _crop_modules(self, config: TrainConfig): if config.model_type.has_depth_input(): inputs.append('depth') + if config.custom_conditioning_image: + inputs.append('conditioning_image') + scale_crop = ScaleCropImage(names=inputs, scale_resolution_in_name='scale_resolution', crop_resolution_in_name='crop_resolution', enable_crop_jitter_in_name='concept.image.enable_crop_jitter', crop_offset_out_name='crop_offset') modules = [scale_crop] @@ -218,7 +231,7 @@ def _inpainting_modules(self, config: TrainConfig): modules = [] - if config.model_type.has_conditioning_image_input(): + if config.model_type.has_conditioning_image_input() and config.custom_conditioning_image == False: modules.append(conditioning_image) return modules diff --git a/modules/ui/CaptionUI.py b/modules/ui/CaptionUI.py index 9a5282dfd..133df74bd 100644 --- a/modules/ui/CaptionUI.py +++ b/modules/ui/CaptionUI.py @@ -225,7 +225,7 @@ def load_directory(self, include_subdirectories: bool = False): def scan_directory(self, include_subdirectories: bool = False): def __is_supported_image_extension(filename): name, ext = os.path.splitext(filename) - return path_util.is_supported_image_extension(ext) and not name.endswith("-masklabel") + return path_util.is_supported_image_extension(ext) and not name.endswith("-masklabel") and not name.endswith("-condlabel") self.image_rel_paths = [] diff --git a/modules/ui/ConceptTab.py b/modules/ui/ConceptTab.py index 9a0417b7d..e6249ead6 100644 --- a/modules/ui/ConceptTab.py +++ b/modules/ui/ConceptTab.py @@ -128,7 +128,7 @@ def __get_preview_image(self): for path in pathlib.Path(self.concept.path).glob(glob_pattern): extension = os.path.splitext(path)[1] if path.is_file() and path_util.is_supported_image_extension(extension) \ - and not path.name.endswith("-masklabel.png"): + and not path.name.endswith("-masklabel.png") and not path.name.endswith("-condlabel.png"): preview_path = path_util.canonical_join(self.concept.path, path) break diff --git a/modules/ui/ConceptWindow.py b/modules/ui/ConceptWindow.py index 32b1e6e42..e0e760d5e 100644 --- a/modules/ui/ConceptWindow.py +++ b/modules/ui/ConceptWindow.py @@ -377,7 +377,7 @@ def __concept_stats_tab(self, master): #basic img/vid stats - count of each type in the concept #the \n at the start of the label gives it better vertical spacing with other rows self.image_count_label = components.label(frame, 3, 0, "\nTotal Images", pad=0, - tooltip="Total number of image files, any of the extensions " + str(path_util.SUPPORTED_IMAGE_EXTENSIONS) + ", excluding '-masklabel.png'") + tooltip="Total number of image files, any of the extensions " + str(path_util.SUPPORTED_IMAGE_EXTENSIONS) + ", excluding '-masklabel.png and -condlabel.png'") self.image_count_label.configure(font=ctk.CTkFont(underline=True)) self.image_count_preview = components.label(frame, 4, 0, pad=0, text="-") self.video_count_label = components.label(frame, 3, 1, "\nTotal Videos", pad=0, @@ -545,7 +545,7 @@ def __get_preview_image(self): for path in pathlib.Path(self.concept.path).glob(glob_pattern): extension = os.path.splitext(path)[1] if path.is_file() and path_util.is_supported_image_extension(extension) \ - and not path.name.endswith("-masklabel.png"): + and not path.name.endswith("-masklabel.png") and not path.name.endswith("-condlabel.png"): preview_image_path = path_util.canonical_join(self.concept.path, path) file_index += 1 if file_index == self.image_preview_file_index: diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py index b49acce59..ba28006ec 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -682,6 +682,11 @@ def __create_masked_frame(self, master, row): tooltip="When masked training is enabled, normalizes the loss for each sample based on the sizes of the masked region") components.switch(frame, 3, 1, self.ui_state, "normalize_masked_area_loss") + # use custom conditioning image + components.label(frame, 4, 0, "Custom Conditioning Image", + tooltip="When custom conditioning image is enabled, will use png postfix with -condlabel instead of automatically generated.It's suitable for special scenarios, such as object removal, allowing the model to learn a certain behavior concept") + components.switch(frame, 4, 1, self.ui_state, "custom_conditioning_image") + def __create_loss_frame(self, master, row, supports_vb_loss: bool = False): frame = ctk.CTkFrame(master=master, corner_radius=5) frame.grid(row=row, column=0, padx=5, pady=5, sticky="nsew") diff --git a/modules/util/concept_stats.py b/modules/util/concept_stats.py index 6c1a31bb9..88cc149e0 100644 --- a/modules/util/concept_stats.py +++ b/modules/util/concept_stats.py @@ -91,7 +91,7 @@ def folder_scan(dir, stats_dict : dict, advanced_checks : bool, conceptconfig : for path in file_list: basename, extension = os.path.splitext(path) - if extension.lower() in img_extensions_list and not path.name.endswith("-masklabel.png"): + if extension.lower() in img_extensions_list and not path.name.endswith("-masklabel.png") and not path.name.endswith("-condlabel.png"): stats_dict["image_count"] += 1 stats_dict["file_size"] += path.stat().st_size if advanced_checks: diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index 8bde8f1ba..2157e639c 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -388,6 +388,9 @@ class TrainConfig(BaseConfig): unmasked_weight: float normalize_masked_area_loss: bool + # custom conditioning image + custom_conditioning_image: bool + # embedding embedding_learning_rate: float preserve_embedding_norm: bool @@ -907,6 +910,7 @@ def default_values() -> 'TrainConfig': data.append(("unmasked_probability", 0.1, float, False)) data.append(("unmasked_weight", 0.1, float, False)) data.append(("normalize_masked_area_loss", False, bool, False)) + data.append(("custom_conditioning_image", False, bool, False)) # embedding data.append(("embedding_learning_rate", None, float, True)) From e651761c73fd65b48c61581304326bc2cbb28667 Mon Sep 17 00:00:00 2001 From: wenyifan Date: Wed, 23 Apr 2025 17:13:32 +0800 Subject: [PATCH 2/2] fix config.custom_conditioning_image code syntax issues --- modules/dataLoader/mixin/DataLoaderText2ImageMixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py b/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py index 13ec879cd..1bd9c8663 100644 --- a/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py +++ b/modules/dataLoader/mixin/DataLoaderText2ImageMixin.py @@ -231,7 +231,7 @@ def _inpainting_modules(self, config: TrainConfig): modules = [] - if config.model_type.has_conditioning_image_input() and config.custom_conditioning_image == False: + if config.model_type.has_conditioning_image_input() and not config.custom_conditioning_image: modules.append(conditioning_image) return modules