Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions library/sdxl_train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,23 +327,26 @@ def diffusers_saver(out_dir):
)


def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
parser.add_argument(
"--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
if support_text_encoder_caching:
parser.add_argument(
"--cache_text_encoder_outputs",
action="store_true",
help="cache text encoder outputs / text encoderの出力をキャッシュする",
)
parser.add_argument(
"--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
)


def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_caching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"

if args.clip_skip is not None:
Expand All @@ -366,7 +369,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

if supportTextEncoderCaching:
if support_text_encoder_caching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
logger.warning(
Expand Down
8 changes: 4 additions & 4 deletions sdxl_train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
from library.device_utils import init_ipex

init_ipex()

from library import sdxl_model_util, sdxl_train_util, train_util
Expand All @@ -19,8 +20,8 @@ def __init__(self):
self.is_sdxl = True

def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
# super().assert_extra_args(args, train_dataset_group) # do not call parent because it checks reso steps with 64
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)

train_dataset_group.verify_bucket_reso_steps(32)

Expand Down Expand Up @@ -122,8 +123,7 @@ def load_weights(self, file):

def setup_parser() -> argparse.ArgumentParser:
parser = train_textual_inversion.setup_parser()
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
# sdxl_train_util.add_sdxl_training_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
return parser


Expand Down