Merged
7 changes: 5 additions & 2 deletions fine_tune.py
@@ -337,6 +337,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
+if len(accelerator.trackers) > 0:
+    # log empty object to commit the sample images to wandb
+    accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -456,7 +459,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

current_loss = loss.detach().item()  # this is an average, so batch size should not matter
-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -469,7 +472,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

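Note on the guard change repeated throughout this PR: `accelerator.trackers` is only populated once `init_trackers()` has run, so `len(accelerator.trackers) > 0` is true exactly when some tracker (tensorboard, wandb, or both) is active, whereas the old `args.logging_dir is not None` test only reflected the tensorboard configuration. A minimal sketch of the pattern, assuming accelerate's standard tracker API (illustrative project name and values, not code from this repository):

    from accelerate import Accelerator

    accelerator = Accelerator(log_with="wandb")  # or "tensorboard", or a list of both
    accelerator.init_trackers("my_project")      # populates accelerator.trackers

    # true whenever any tracker is active, independent of --logging_dir
    if len(accelerator.trackers) > 0:
        accelerator.log({"loss": 0.123}, step=1)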
7 changes: 5 additions & 2 deletions flux_train.py
@@ -629,6 +629,9 @@ def optimizer_hook(parameter: torch.Tensor):

# For --sample_at_first
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
+if len(accelerator.trackers) > 0:
+    # log empty object to commit the sample images to wandb
+    accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
@@ -777,7 +780,7 @@ def optimizer_hook(parameter: torch.Tensor):
)

current_loss = loss.detach().item()  # this is an average, so batch size should not matter
-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)

@@ -791,7 +794,7 @@ def optimizer_hook(parameter: torch.Tensor):
if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

20 changes: 11 additions & 9 deletions library/flux_train_utils.py
@@ -241,17 +241,19 @@ def sample_image_inference(
     img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
     image.save(os.path.join(save_dir, img_filename))

-    # send logs only when wandb is enabled
-    try:
+    # send images to wandb if enabled
+    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
         wandb_tracker = accelerator.get_tracker("wandb")
-        try:
-            import wandb
-        except ImportError:  # checked once beforehand, so this should not raise here
-            raise ImportError("No wandb / wandb がインストールされていないようです")

-        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
-    except:  # when wandb is disabled
-        pass
+        import wandb
+        # not to commit images to avoid inconsistency between training and logging steps
+        wandb_tracker.log(
+            {f"sample_{i}": wandb.Image(
+                image,
+                caption=prompt  # positive prompt as a caption
+            )},
+            commit=False
+        )


def time_shift(mu: float, sigma: float, t: torch.Tensor):
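The deferred-commit trick above, as a standalone sketch (hypothetical file and project names; `commit=False` is forwarded by accelerate's WandBTracker to `wandb`'s own `log`): images are staged without advancing wandb's internal step counter, and the next ordinary `accelerator.log()` call, even with an empty dict, commits them at the intended training step:

    import wandb
    from accelerate import Accelerator

    accelerator = Accelerator(log_with="wandb")
    accelerator.init_trackers("sample_logging_demo")

    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
        wandb_tracker = accelerator.get_tracker("wandb")
        # stage the image; wandb's step counter does not advance yet
        wandb_tracker.log({"sample_0": wandb.Image("sample.png", caption="a prompt")}, commit=False)

    # an empty log call commits the staged images at the chosen step
    accelerator.log({}, step=0)

This is why every trainer in this PR logs an empty object right after the --sample_at_first images are generated.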
20 changes: 11 additions & 9 deletions library/sd3_train_utils.py
@@ -604,17 +604,19 @@ def sample_image_inference(
     img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
     image.save(os.path.join(save_dir, img_filename))

-    # send logs only when wandb is enabled
-    try:
+    # send images to wandb if enabled
+    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
         wandb_tracker = accelerator.get_tracker("wandb")
-        try:
-            import wandb
-        except ImportError:  # checked once beforehand, so this should not raise here
-            raise ImportError("No wandb / wandb がインストールされていないようです")

-        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
-    except:  # when wandb is disabled
-        pass
+        import wandb
+        # not to commit images to avoid inconsistency between training and logging steps
+        wandb_tracker.log(
+            {f"sample_{i}": wandb.Image(
+                image,
+                caption=prompt  # positive prompt as a caption
+            )},
+            commit=False
+        )


# region Diffusers
20 changes: 11 additions & 9 deletions library/train_util.py
@@ -5745,17 +5745,19 @@ def sample_image_inference(
     img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
     image.save(os.path.join(save_dir, img_filename))

-    # send logs only when wandb is enabled
-    try:
+    # send images to wandb if enabled
+    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
         wandb_tracker = accelerator.get_tracker("wandb")
-        try:
-            import wandb
-        except ImportError:  # checked once beforehand, so this should not raise here
-            raise ImportError("No wandb / wandb がインストールされていないようです")

-        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
-    except:  # when wandb is disabled
-        pass
+        import wandb
+        # not to commit images to avoid inconsistency between training and logging steps
+        wandb_tracker.log(
+            {f"sample_{i}": wandb.Image(
+                image,
+                caption=prompt  # positive prompt as a caption
+            )},
+            commit=False
+        )


# endregion
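The deleted inner try/except noted that wandb availability is "checked once beforehand"; that up-front check sits outside this diff. A hedged sketch of what such a startup check can look like (hypothetical function name, assuming the tracker choice is exposed via a --log_with style argument):

    def verify_wandb_available(args):
        # fail fast at startup rather than inside sample_image_inference
        if args.log_with in ("wandb", "all"):
            try:
                import wandb  # noqa: F401
            except ImportError:
                raise ImportError("No wandb / wandb is not installed")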
7 changes: 5 additions & 2 deletions sd3_train.py
@@ -662,6 +662,9 @@ def optimizer_hook(parameter: torch.Tensor):

# For --sample_at_first
sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs)
+if len(accelerator.trackers) > 0:
+    # log empty object to commit the sample images to wandb
+    accelerator.log({}, step=0)

# following function will be moved to sd3_train_utils

@@ -881,7 +884,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
)

current_loss = loss.detach().item()  # this is an average, so batch size should not matter
-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit)

@@ -895,7 +898,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

7 changes: 5 additions & 2 deletions sdxl_train.py
@@ -617,6 +617,9 @@ def optimizer_hook(parameter: torch.Tensor):
sdxl_train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet
)
+if len(accelerator.trackers) > 0:
+    # log empty object to commit the sample images to wandb
+    accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -797,7 +800,7 @@ def optimizer_hook(parameter: torch.Tensor):
)

current_loss = loss.detach().item()  # this is an average, so batch size should not matter
-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
if block_lrs is None:
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
@@ -814,7 +817,7 @@ def optimizer_hook(parameter: torch.Tensor):
if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite.py
@@ -541,14 +541,14 @@ def remove_model(old_ckpt_name):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite_old.py
@@ -480,14 +480,14 @@ def remove_model(old_ckpt_name):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

7 changes: 5 additions & 2 deletions train_controlnet.py
@@ -409,6 +409,9 @@ def remove_model(old_ckpt_name):
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
)
+if len(accelerator.trackers) > 0:
+    # log empty object to commit the sample images to wandb
+    accelerator.log({}, step=0)

# training loop
for epoch in range(num_train_epochs):
@@ -542,14 +545,14 @@ def remove_model(old_ckpt_name):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

7 changes: 5 additions & 2 deletions train_db.py
@@ -315,6 +315,9 @@ def train(args):
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
+if len(accelerator.trackers) > 0:
+    # log empty object to commit the sample images to wandb
+    accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -445,7 +448,7 @@ def train(args):
)

current_loss = loss.detach().item()
-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -458,7 +461,7 @@ def train(args):
if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

7 changes: 5 additions & 2 deletions train_network.py
@@ -1018,6 +1018,9 @@ def remove_model(old_ckpt_name):

# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
+if len(accelerator.trackers) > 0:
+    # log empty object to commit the sample images to wandb
+    accelerator.log({}, step=0)

# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
@@ -1200,7 +1203,7 @@ def remove_model(old_ckpt_name):
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
)
@@ -1209,7 +1212,7 @@ def remove_model(old_ckpt_name):
if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

8 changes: 6 additions & 2 deletions train_textual_inversion.py
@@ -550,6 +550,9 @@ def remove_model(old_ckpt_name):
unet,
prompt_replacement,
)
+if len(accelerator.trackers) > 0:
+    # log empty object to commit the sample images to wandb
+    accelerator.log({}, step=0)

# training loop
for epoch in range(num_train_epochs):
@@ -684,7 +687,7 @@ def remove_model(old_ckpt_name):
remove_model(remove_ckpt_name)

current_loss = loss.detach().item()
-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
@@ -702,7 +705,7 @@ def remove_model(old_ckpt_name):
if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)

@@ -739,6 +742,7 @@ def remove_model(old_ckpt_name):
unet,
prompt_replacement,
)
+accelerator.log({})

# end of epoch

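Unlike the step-0 commits earlier in this PR, the final accelerator.log({}) added above passes no step argument, so the sample images staged after the last epoch are committed at wandb's current internal step. A one-line illustration of the same call:

    # no explicit step: wandb commits any staged images at its current step
    accelerator.log({})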
4 changes: 2 additions & 2 deletions train_textual_inversion_XTI.py
@@ -538,7 +538,7 @@ def remove_model(old_ckpt_name):
remove_model(remove_ckpt_name)

current_loss = loss.detach().item()
-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
@@ -556,7 +556,7 @@ def remove_model(old_ckpt_name):
if global_step >= args.max_train_steps:
break

-if args.logging_dir is not None:
+if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
