Skip to content

Add support for custom conditioning image #812

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions modules/dataLoader/mixin/DataLoaderText2ImageMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,19 @@ def _enumerate_input_modules(self, config: TrainConfig, allow_videos: bool = Fal
collect_paths = CollectPaths(
concept_in_name='concept', path_in_name='path', include_subdirectories_in_name='concept.include_subdirectories', enabled_in_name='enabled',
path_out_name='image_path', concept_out_name='concept',
extensions=supported_extensions, include_postfix=None, exclude_postfix=['-masklabel']
extensions=supported_extensions, include_postfix=None, exclude_postfix=['-masklabel','-condlabel']
)

mask_path = ModifyPath(in_name='image_path', out_name='mask_path', postfix='-masklabel', extension='.png')
cond_path = ModifyPath(in_name='image_path', out_name='cond_path', postfix='-condlabel', extension='.png')
sample_prompt_path = ModifyPath(in_name='image_path', out_name='sample_prompt_path', postfix='', extension='.txt')

modules = [collect_paths, sample_prompt_path]

if config.masked_training:
modules.append(mask_path)
if config.custom_conditioning_image:
modules.append(cond_path)

return modules

Expand All @@ -82,6 +85,10 @@ def _load_input_modules(
load_mask = LoadImage(path_in_name='mask_path', image_out_name='mask', range_min=0, range_max=1, channels=1, supported_extensions={".png"}, dtype=train_dtype.torch_dtype())
mask_to_video = ImageToVideo(in_name='mask', out_name='mask')

load_cond_image = LoadImage(path_in_name='cond_path', image_out_name='conditioning_image', range_min=0, range_max=1,
supported_extensions=path_util.supported_image_extensions(),
dtype=train_dtype.torch_dtype())

load_sample_prompts = LoadMultipleTexts(path_in_name='sample_prompt_path', texts_out_name='sample_prompts')
load_concept_prompts = LoadMultipleTexts(path_in_name='concept.text.prompt_path', texts_out_name='concept_prompts')
filename_prompt = GetFilename(path_in_name='image_path', filename_out_name='filename_prompt', include_extension=False)
Expand All @@ -105,6 +112,9 @@ def _load_input_modules(
elif config.model_type.has_mask_input():
modules.append(generate_mask)

if config.custom_conditioning_image:
modules.append(load_cond_image)

if allow_video:
modules.append(mask_to_video)

Expand Down Expand Up @@ -171,6 +181,9 @@ def _crop_modules(self, config: TrainConfig):
if config.model_type.has_depth_input():
inputs.append('depth')

if config.custom_conditioning_image:
inputs.append('conditioning_image')

scale_crop = ScaleCropImage(names=inputs, scale_resolution_in_name='scale_resolution', crop_resolution_in_name='crop_resolution', enable_crop_jitter_in_name='concept.image.enable_crop_jitter', crop_offset_out_name='crop_offset')

modules = [scale_crop]
Expand Down Expand Up @@ -218,7 +231,7 @@ def _inpainting_modules(self, config: TrainConfig):

modules = []

if config.model_type.has_conditioning_image_input():
if config.model_type.has_conditioning_image_input() and not config.custom_conditioning_image:
modules.append(conditioning_image)

return modules
Expand Down
2 changes: 1 addition & 1 deletion modules/ui/CaptionUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def load_directory(self, include_subdirectories: bool = False):
def scan_directory(self, include_subdirectories: bool = False):
def __is_supported_image_extension(filename):
name, ext = os.path.splitext(filename)
return path_util.is_supported_image_extension(ext) and not name.endswith("-masklabel")
return path_util.is_supported_image_extension(ext) and not name.endswith("-masklabel") and not name.endswith("-condlabel")

self.image_rel_paths = []

Expand Down
2 changes: 1 addition & 1 deletion modules/ui/ConceptTab.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __get_preview_image(self):
for path in pathlib.Path(self.concept.path).glob(glob_pattern):
extension = os.path.splitext(path)[1]
if path.is_file() and path_util.is_supported_image_extension(extension) \
and not path.name.endswith("-masklabel.png"):
and not path.name.endswith("-masklabel.png") and not path.name.endswith("-condlabel.png"):
preview_path = path_util.canonical_join(self.concept.path, path)
break

Expand Down
4 changes: 2 additions & 2 deletions modules/ui/ConceptWindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def __concept_stats_tab(self, master):
# Basic image/video stats: the count of each file type in the concept.
# The "\n" at the start of the label text gives it better vertical spacing relative to the other rows.
self.image_count_label = components.label(frame, 3, 0, "\nTotal Images", pad=0,
tooltip="Total number of image files, any of the extensions " + str(path_util.SUPPORTED_IMAGE_EXTENSIONS) + ", excluding '-masklabel.png'")
tooltip="Total number of image files, any of the extensions " + str(path_util.SUPPORTED_IMAGE_EXTENSIONS) + ", excluding '-masklabel.png and -condlabel.png'")
self.image_count_label.configure(font=ctk.CTkFont(underline=True))
self.image_count_preview = components.label(frame, 4, 0, pad=0, text="-")
self.video_count_label = components.label(frame, 3, 1, "\nTotal Videos", pad=0,
Expand Down Expand Up @@ -545,7 +545,7 @@ def __get_preview_image(self):
for path in pathlib.Path(self.concept.path).glob(glob_pattern):
extension = os.path.splitext(path)[1]
if path.is_file() and path_util.is_supported_image_extension(extension) \
and not path.name.endswith("-masklabel.png"):
and not path.name.endswith("-masklabel.png") and not path.name.endswith("-condlabel.png"):
preview_image_path = path_util.canonical_join(self.concept.path, path)
file_index += 1
if file_index == self.image_preview_file_index:
Expand Down
5 changes: 5 additions & 0 deletions modules/ui/TrainingTab.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,11 @@ def __create_masked_frame(self, master, row):
tooltip="When masked training is enabled, normalizes the loss for each sample based on the sizes of the masked region")
components.switch(frame, 3, 1, self.ui_state, "normalize_masked_area_loss")

# use custom conditioning image
components.label(frame, 4, 0, "Custom Conditioning Image",
tooltip="When custom conditioning image is enabled, will use png postfix with -condlabel instead of automatically generated.It's suitable for special scenarios, such as object removal, allowing the model to learn a certain behavior concept")
components.switch(frame, 4, 1, self.ui_state, "custom_conditioning_image")

def __create_loss_frame(self, master, row, supports_vb_loss: bool = False):
frame = ctk.CTkFrame(master=master, corner_radius=5)
frame.grid(row=row, column=0, padx=5, pady=5, sticky="nsew")
Expand Down
2 changes: 1 addition & 1 deletion modules/util/concept_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def folder_scan(dir, stats_dict : dict, advanced_checks : bool, conceptconfig :

for path in file_list:
basename, extension = os.path.splitext(path)
if extension.lower() in img_extensions_list and not path.name.endswith("-masklabel.png"):
if extension.lower() in img_extensions_list and not path.name.endswith("-masklabel.png") and not path.name.endswith("-condlabel.png"):
stats_dict["image_count"] += 1
stats_dict["file_size"] += path.stat().st_size
if advanced_checks:
Expand Down
4 changes: 4 additions & 0 deletions modules/util/config/TrainConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ class TrainConfig(BaseConfig):
unmasked_weight: float
normalize_masked_area_loss: bool

# custom conditioning image
custom_conditioning_image: bool

# embedding
embedding_learning_rate: float
preserve_embedding_norm: bool
Expand Down Expand Up @@ -907,6 +910,7 @@ def default_values() -> 'TrainConfig':
data.append(("unmasked_probability", 0.1, float, False))
data.append(("unmasked_weight", 0.1, float, False))
data.append(("normalize_masked_area_loss", False, bool, False))
data.append(("custom_conditioning_image", False, bool, False))

# embedding
data.append(("embedding_learning_rate", None, float, True))
Expand Down